From ab8b489a3f1e898a63d925a2c3e416f55b3aea99 Mon Sep 17 00:00:00 2001 From: thorfour Date: Wed, 29 May 2024 09:30:11 -0500 Subject: [PATCH] Reservoir sampler --- db_test.go | 44 +++++++++++ query/engine.go | 12 +++ query/logicalplan/builder.go | 16 ++++ query/logicalplan/logicalplan.go | 9 +++ query/logicalplan/validate.go | 5 +- query/physicalplan/physicalplan.go | 14 ++++ query/physicalplan/sampler.go | 118 +++++++++++++++++++++++++++++ query/physicalplan/sampler_test.go | 86 +++++++++++++++++++++ 8 files changed, 303 insertions(+), 1 deletion(-) create mode 100644 query/physicalplan/sampler.go create mode 100644 query/physicalplan/sampler_test.go diff --git a/db_test.go b/db_test.go index 7b06f9f45e..01fb865d97 100644 --- a/db_test.go +++ b/db_test.go @@ -3279,3 +3279,47 @@ func (a *AssertBucket) Upload(ctx context.Context, path string, r io.Reader) err } return a.Bucket.Upload(ctx, path, r) } + +func Test_DB_Sample(t *testing.T) { + t.Parallel() + config := NewTableConfig( + dynparquet.SampleDefinition(), + ) + logger := newTestLogger(t) + + c, err := New(WithLogger(logger)) + t.Cleanup(func() { + require.NoError(t, c.Close()) + }) + require.NoError(t, err) + db, err := c.DB(context.Background(), "test") + require.NoError(t, err) + table, err := db.Table("test", config) + require.NoError(t, err) + + ctx := context.Background() + for i := 0; i < 500; i++ { + samples := dynparquet.GenerateTestSamples(10) + r, err := samples.ToRecord() + require.NoError(t, err) + _, err = table.InsertRecord(ctx, r) + require.NoError(t, err) + } + + pool := memory.NewCheckedAllocator(memory.DefaultAllocator) + defer pool.AssertSize(t, 0) + lock := &sync.Mutex{} + rows := int64(0) + sampleSize := int64(13) + engine := query.NewEngine(pool, db.TableProvider()) + err = engine.ScanTable("test"). + Sample(sampleSize). // Sample 13 rows + Execute(context.Background(), func(ctx context.Context, r arrow.Record) error { + lock.Lock() + defer lock.Unlock() + rows += r.NumRows() + return nil + }) + require.NoError(t, err) + require.Equal(t, sampleSize, rows) +} diff --git a/query/engine.go b/query/engine.go index 7ca250d99c..4d0e89886c 100644 --- a/query/engine.go +++ b/query/engine.go @@ -21,6 +21,7 @@ type Builder interface { Limit(expr logicalplan.Expr) Builder Execute(ctx context.Context, callback func(ctx context.Context, r arrow.Record) error) error Explain(ctx context.Context) (string, error) + Sample(size int64) Builder } type LocalEngine struct { @@ -143,6 +144,17 @@ func (b LocalQueryBuilder) Limit( } } +func (b LocalQueryBuilder) Sample( + size int64, +) Builder { + return LocalQueryBuilder{ + pool: b.pool, + tracer: b.tracer, + planBuilder: b.planBuilder.Sample(logicalplan.Literal(size)), + execOpts: b.execOpts, + } +} + func (b LocalQueryBuilder) Execute(ctx context.Context, callback func(ctx context.Context, r arrow.Record) error) error { ctx, span := b.tracer.Start(ctx, "LocalQueryBuilder/Execute") defer span.End() diff --git a/query/logicalplan/builder.go b/query/logicalplan/builder.go index d14bf29f00..d0b99f8dfe 100644 --- a/query/logicalplan/builder.go +++ b/query/logicalplan/builder.go @@ -237,6 +237,22 @@ func resolveAggregation(plan *LogicalPlan, agg *AggregationFunction) ([]*Aggrega } } +func (b Builder) Sample(expr Expr) Builder { + if expr == nil { + return b + } + + return Builder{ + err: b.err, + plan: &LogicalPlan{ + Input: b.plan, + Sample: &Sample{ + Expr: expr, + }, + }, + } +} + func (b Builder) Build() (*LogicalPlan, error) { if b.err != nil { return nil, b.err diff --git a/query/logicalplan/logicalplan.go b/query/logicalplan/logicalplan.go index 1c8bec95b2..81aca31bf4 100644 --- a/query/logicalplan/logicalplan.go +++ b/query/logicalplan/logicalplan.go @@ -25,6 +25,7 @@ type LogicalPlan struct { Projection *Projection Aggregation *Aggregation Limit *Limit + Sample *Sample } // Callback is a function that is called throughout a chain of operators @@ -414,3 +415,11 @@ type Limit struct { func (l *Limit) String() string { return "Limit" + " Expr: " + fmt.Sprint(l.Expr) } + +type Sample struct { + Expr Expr +} + +func (s *Sample) String() string { + return "Sample" + " Expr: " + fmt.Sprint(s.Expr) +} diff --git a/query/logicalplan/validate.go b/query/logicalplan/validate.go index 26e0eb7086..8af96e4878 100644 --- a/query/logicalplan/validate.go +++ b/query/logicalplan/validate.go @@ -122,10 +122,13 @@ func ValidateSingleFieldSet(plan *LogicalPlan) *PlanValidationError { if plan.Limit != nil { fieldsSet = append(fieldsSet, 6) } + if plan.Sample != nil { + fieldsSet = append(fieldsSet, 7) + } if len(fieldsSet) != 1 { fieldsFound := make([]string, 0) - fields := []string{"SchemaScan", "TableScan", "Filter", "Distinct", "Projection", "Aggregation"} + fields := []string{"SchemaScan", "TableScan", "Filter", "Distinct", "Projection", "Aggregation", "Limit", "Sample"} for _, i := range fieldsSet { fieldsFound = append(fieldsFound, fields[i]) } diff --git a/query/physicalplan/physicalplan.go b/query/physicalplan/physicalplan.go index 0b718dd376..c30de1ddd5 100644 --- a/query/physicalplan/physicalplan.go +++ b/query/physicalplan/physicalplan.go @@ -8,6 +8,7 @@ import ( "github.com/apache/arrow/go/v16/arrow" "github.com/apache/arrow/go/v16/arrow/memory" + "github.com/apache/arrow/go/v16/arrow/scalar" "go.opentelemetry.io/otel/attribute" "go.opentelemetry.io/otel/trace" "golang.org/x/sync/errgroup" @@ -482,6 +483,19 @@ func Build( if ordered { oInfo.nodeMaintainsOrdering() } + case plan.Sample != nil: + v := plan.Sample.Expr.(*logicalplan.LiteralExpr).Value.(*scalar.Int64).Value + perSampler := v / int64(len(prev)) + r := v % int64(len(prev)) + for i := range prev { + adjust := int64(0) + if i < int(r) { + adjust = 1 + } + s := NewReservoirSampler(perSampler + adjust) + prev[i].SetNext(s) + prev[i] = s + } default: panic("Unsupported plan") } diff --git a/query/physicalplan/sampler.go b/query/physicalplan/sampler.go new file mode 100644 index 0000000000..38c8d19f42 --- /dev/null +++ b/query/physicalplan/sampler.go @@ -0,0 +1,118 @@ +package physicalplan + +import ( + "context" + "fmt" + "math" + "math/rand" + + "github.com/apache/arrow/go/v16/arrow" +) + +type ReservoirSampler struct { + next PhysicalPlan + + // size is the max number of rows in the reservoir + size int64 + + // reservoir is the set of records that have been sampled. They may vary in schema due to dynamic columns. + reservoir []arrow.Record + + // currentSize is the number of rows in the reservoir in all records. + currentSize int64 + + w float64 // w is the probability of keeping a record + n int64 // n is the number of rows that have been sampled thus far +} + +// NewReservoirSampler will create a new ReservoirSampler operator that will sample up to size rows of all records seen by Callback. +func NewReservoirSampler(size int64) *ReservoirSampler { + return &ReservoirSampler{ + size: size, + reservoir: []arrow.Record{}, + w: math.Exp(math.Log(rand.Float64()) / float64(size)), + } +} + +func (s *ReservoirSampler) SetNext(p PhysicalPlan) { + s.next = p +} + +func (s *ReservoirSampler) Draw() *Diagram { + var child *Diagram + if s.next != nil { + child = s.next.Draw() + } + details := fmt.Sprintf("Reservoir Sampler (%v)", s.size) + return &Diagram{Details: details, Child: child} +} + +func (s *ReservoirSampler) Close() { + for _, r := range s.reservoir { + r.Release() + } + s.next.Close() +} + +// Callback collects all the records to sample. +func (s *ReservoirSampler) Callback(ctx context.Context, r arrow.Record) error { + r = s.fill(r) + if r == nil { + return nil + } + + // Sample the record + s.sample(r) + return nil +} + +// fill will fill the reservoir with the first size records +func (s *ReservoirSampler) fill(r arrow.Record) arrow.Record { + if s.currentSize == s.size { + return r + } + + if s.currentSize+r.NumRows() <= s.size { // The record fits in the reservoir + for i := int64(0); i < r.NumRows(); i++ { // For simplicity of implementation the reservoir is by row; This is probably not optimal + s.reservoir = append(s.reservoir, r.NewSlice(i, i+1)) + } + s.currentSize += r.NumRows() + s.n += r.NumRows() + return nil + } + + // The record partially fits in the reservoir + for i := int64(0); i < s.size-s.currentSize; i++ { + s.reservoir = append(s.reservoir, r.NewSlice(i, i+1)) + } + r = r.NewSlice(s.size-s.currentSize, r.NumRows()) + s.currentSize = s.size + s.n = s.size + return r +} + +// sample implements the reservoir sampling algorithm found https://en.wikipedia.org/wiki/Reservoir_sampling. +func (s *ReservoirSampler) sample(r arrow.Record) { + n := s.n + r.NumRows() + for i := float64(s.n); i < float64(n); { + i += math.Floor(math.Log(rand.Float64())/math.Log(1-s.w)) + 1 + if i <= float64(n) { + // replace a random item of the reservoir with row i + s.reservoir[rand.Intn(int(s.size))] = r.NewSlice(int64(i)-s.n-1, int64(i)-s.n) + s.w = s.w * math.Exp(math.Log(rand.Float64())/float64(s.size)) + } + } + s.n = n +} + +// Finish sends all the records in the reservoir to the next operator. +func (s *ReservoirSampler) Finish(ctx context.Context) error { + // Send all the records in the reservoir to the next operator + for _, r := range s.reservoir { + if err := s.next.Callback(ctx, r); err != nil { + return err + } + } + + return s.next.Finish(ctx) +} diff --git a/query/physicalplan/sampler_test.go b/query/physicalplan/sampler_test.go new file mode 100644 index 0000000000..47dc06715d --- /dev/null +++ b/query/physicalplan/sampler_test.go @@ -0,0 +1,86 @@ +package physicalplan + +import ( + "context" + "testing" + + "github.com/apache/arrow/go/v16/arrow" + "github.com/apache/arrow/go/v16/arrow/array" + "github.com/apache/arrow/go/v16/arrow/memory" + "github.com/stretchr/testify/require" +) + +type TestPlan struct { + finish func() error + callback func(ctx context.Context, r arrow.Record) error +} + +func (t *TestPlan) Callback(ctx context.Context, r arrow.Record) error { + if t.callback != nil { + return t.callback(ctx, r) + } + return nil +} +func (t *TestPlan) Finish(ctx context.Context) error { + if t.finish != nil { + return t.finish() + } + return nil +} +func (t *TestPlan) SetNext(next PhysicalPlan) {} +func (t *TestPlan) Draw() *Diagram { return nil } +func (t *TestPlan) Close() {} + +func Test_Sampler(t *testing.T) { + ctx := context.Background() + tests := map[string]struct { + reservoirSize int64 + numRows int + recordSize int + }{ + "basic single row records": { + reservoirSize: 10, + numRows: 100, + recordSize: 1, + }, + "basic multi row records": { + reservoirSize: 10, + numRows: 100, + recordSize: 10, + }, + } + + for name, test := range tests { + t.Run(name, func(t *testing.T) { + // Create a new sampler + s := NewReservoirSampler(test.reservoirSize) + called := false + total := int64(0) + s.SetNext(&TestPlan{ + callback: func(ctx context.Context, r arrow.Record) error { + called = true + total += r.NumRows() + return nil + }, + }) + + schema := arrow.NewSchema([]arrow.Field{ + {Name: "a", Type: arrow.PrimitiveTypes.Int64}, + }, nil) + bldr := array.NewRecordBuilder(memory.NewGoAllocator(), schema) + + for i := 0; i < test.numRows/test.recordSize; i++ { + for j := 0; j < test.recordSize; j++ { + bldr.Field(0).(*array.Int64Builder).Append(int64((i * test.recordSize) + j)) + } + r := bldr.NewRecord() + t.Cleanup(r.Release) + s.Callback(ctx, r) + } + + s.Finish(ctx) + require.True(t, called) + require.Equal(t, test.reservoirSize, total) + }) + } +}