Skip to content

Commit

Permalink
Reservoir sampler
Browse files Browse the repository at this point in the history
  • Loading branch information
thorfour committed May 29, 2024
1 parent d04bda2 commit ab8b489
Show file tree
Hide file tree
Showing 8 changed files with 303 additions and 1 deletion.
44 changes: 44 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3279,3 +3279,47 @@ func (a *AssertBucket) Upload(ctx context.Context, path string, r io.Reader) err
}
return a.Bucket.Upload(ctx, path, r)
}

func Test_DB_Sample(t *testing.T) {
t.Parallel()
config := NewTableConfig(
dynparquet.SampleDefinition(),
)
logger := newTestLogger(t)

c, err := New(WithLogger(logger))
t.Cleanup(func() {
require.NoError(t, c.Close())
})
require.NoError(t, err)
db, err := c.DB(context.Background(), "test")
require.NoError(t, err)
table, err := db.Table("test", config)
require.NoError(t, err)

ctx := context.Background()
for i := 0; i < 500; i++ {
samples := dynparquet.GenerateTestSamples(10)
r, err := samples.ToRecord()
require.NoError(t, err)
_, err = table.InsertRecord(ctx, r)
require.NoError(t, err)
}

pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
defer pool.AssertSize(t, 0)
lock := &sync.Mutex{}
rows := int64(0)
sampleSize := int64(13)
engine := query.NewEngine(pool, db.TableProvider())
err = engine.ScanTable("test").
Sample(sampleSize). // Sample 13 rows
Execute(context.Background(), func(ctx context.Context, r arrow.Record) error {
lock.Lock()
defer lock.Unlock()
rows += r.NumRows()
return nil
})
require.NoError(t, err)
require.Equal(t, sampleSize, rows)
}
12 changes: 12 additions & 0 deletions query/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Builder interface {
Limit(expr logicalplan.Expr) Builder
Execute(ctx context.Context, callback func(ctx context.Context, r arrow.Record) error) error
Explain(ctx context.Context) (string, error)
Sample(size int64) Builder
}

type LocalEngine struct {
Expand Down Expand Up @@ -143,6 +144,17 @@ func (b LocalQueryBuilder) Limit(
}
}

func (b LocalQueryBuilder) Sample(
size int64,
) Builder {
return LocalQueryBuilder{
pool: b.pool,
tracer: b.tracer,
planBuilder: b.planBuilder.Sample(logicalplan.Literal(size)),
execOpts: b.execOpts,
}
}

func (b LocalQueryBuilder) Execute(ctx context.Context, callback func(ctx context.Context, r arrow.Record) error) error {
ctx, span := b.tracer.Start(ctx, "LocalQueryBuilder/Execute")
defer span.End()
Expand Down
16 changes: 16 additions & 0 deletions query/logicalplan/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,22 @@ func resolveAggregation(plan *LogicalPlan, agg *AggregationFunction) ([]*Aggrega
}
}

func (b Builder) Sample(expr Expr) Builder {
if expr == nil {
return b
}

return Builder{
err: b.err,
plan: &LogicalPlan{
Input: b.plan,
Sample: &Sample{
Expr: expr,
},
},
}
}

func (b Builder) Build() (*LogicalPlan, error) {
if b.err != nil {
return nil, b.err
Expand Down
9 changes: 9 additions & 0 deletions query/logicalplan/logicalplan.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type LogicalPlan struct {
Projection *Projection
Aggregation *Aggregation
Limit *Limit
Sample *Sample
}

// Callback is a function that is called throughout a chain of operators
Expand Down Expand Up @@ -414,3 +415,11 @@ type Limit struct {
func (l *Limit) String() string {
return "Limit" + " Expr: " + fmt.Sprint(l.Expr)
}

type Sample struct {
Expr Expr
}

func (s *Sample) String() string {
return "Sample" + " Expr: " + fmt.Sprint(s.Expr)
}
5 changes: 4 additions & 1 deletion query/logicalplan/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,13 @@ func ValidateSingleFieldSet(plan *LogicalPlan) *PlanValidationError {
if plan.Limit != nil {
fieldsSet = append(fieldsSet, 6)
}
if plan.Sample != nil {
fieldsSet = append(fieldsSet, 7)
}

if len(fieldsSet) != 1 {
fieldsFound := make([]string, 0)
fields := []string{"SchemaScan", "TableScan", "Filter", "Distinct", "Projection", "Aggregation"}
fields := []string{"SchemaScan", "TableScan", "Filter", "Distinct", "Projection", "Aggregation", "Limit", "Sample"}
for _, i := range fieldsSet {
fieldsFound = append(fieldsFound, fields[i])
}
Expand Down
14 changes: 14 additions & 0 deletions query/physicalplan/physicalplan.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/memory"
"github.com/apache/arrow/go/v16/arrow/scalar"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -482,6 +483,19 @@ func Build(
if ordered {
oInfo.nodeMaintainsOrdering()
}
case plan.Sample != nil:
v := plan.Sample.Expr.(*logicalplan.LiteralExpr).Value.(*scalar.Int64).Value
perSampler := v / int64(len(prev))
r := v % int64(len(prev))
for i := range prev {
adjust := int64(0)
if i < int(r) {
adjust = 1
}
s := NewReservoirSampler(perSampler + adjust)
prev[i].SetNext(s)
prev[i] = s
}
default:
panic("Unsupported plan")
}
Expand Down
118 changes: 118 additions & 0 deletions query/physicalplan/sampler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package physicalplan

import (
"context"
"fmt"
"math"
"math/rand"

"github.com/apache/arrow/go/v16/arrow"
)

type ReservoirSampler struct {
next PhysicalPlan

// size is the max number of rows in the reservoir
size int64

// reservoir is the set of records that have been sampled. They may vary in schema due to dynamic columns.
reservoir []arrow.Record

// currentSize is the number of rows in the reservoir in all records.
currentSize int64

w float64 // w is the probability of keeping a record
n int64 // n is the number of rows that have been sampled thus far
}

// NewReservoirSampler will create a new ReservoirSampler operator that will sample up to size rows of all records seen by Callback.
func NewReservoirSampler(size int64) *ReservoirSampler {
return &ReservoirSampler{
size: size,
reservoir: []arrow.Record{},
w: math.Exp(math.Log(rand.Float64()) / float64(size)),
}
}

func (s *ReservoirSampler) SetNext(p PhysicalPlan) {
s.next = p
}

func (s *ReservoirSampler) Draw() *Diagram {
var child *Diagram
if s.next != nil {
child = s.next.Draw()
}
details := fmt.Sprintf("Reservoir Sampler (%v)", s.size)
return &Diagram{Details: details, Child: child}
}

func (s *ReservoirSampler) Close() {
for _, r := range s.reservoir {
r.Release()
}
s.next.Close()
}

// Callback collects all the records to sample.
func (s *ReservoirSampler) Callback(ctx context.Context, r arrow.Record) error {

Check failure on line 58 in query/physicalplan/sampler.go

View workflow job for this annotation

GitHub Actions / lint

unused-parameter: parameter 'ctx' seems to be unused, consider removing or renaming it as _ (revive)
r = s.fill(r)
if r == nil {
return nil
}

// Sample the record
s.sample(r)
return nil
}

// fill will fill the reservoir with the first size records

Check failure on line 69 in query/physicalplan/sampler.go

View workflow job for this annotation

GitHub Actions / lint

Comment should end in a period (godot)
func (s *ReservoirSampler) fill(r arrow.Record) arrow.Record {
if s.currentSize == s.size {
return r
}

if s.currentSize+r.NumRows() <= s.size { // The record fits in the reservoir
for i := int64(0); i < r.NumRows(); i++ { // For simplicity of implementation the reservoir is by row; This is probably not optimal
s.reservoir = append(s.reservoir, r.NewSlice(i, i+1))
}
s.currentSize += r.NumRows()
s.n += r.NumRows()
return nil
}

// The record partially fits in the reservoir
for i := int64(0); i < s.size-s.currentSize; i++ {
s.reservoir = append(s.reservoir, r.NewSlice(i, i+1))
}
r = r.NewSlice(s.size-s.currentSize, r.NumRows())
s.currentSize = s.size
s.n = s.size
return r
}

// sample implements the reservoir sampling algorithm found https://en.wikipedia.org/wiki/Reservoir_sampling.
func (s *ReservoirSampler) sample(r arrow.Record) {
n := s.n + r.NumRows()
for i := float64(s.n); i < float64(n); {
i += math.Floor(math.Log(rand.Float64())/math.Log(1-s.w)) + 1
if i <= float64(n) {
// replace a random item of the reservoir with row i
s.reservoir[rand.Intn(int(s.size))] = r.NewSlice(int64(i)-s.n-1, int64(i)-s.n)
s.w = s.w * math.Exp(math.Log(rand.Float64())/float64(s.size))
}
}
s.n = n
}

// Finish sends all the records in the reservoir to the next operator.
func (s *ReservoirSampler) Finish(ctx context.Context) error {
// Send all the records in the reservoir to the next operator
for _, r := range s.reservoir {
if err := s.next.Callback(ctx, r); err != nil {
return err
}
}

return s.next.Finish(ctx)
}
86 changes: 86 additions & 0 deletions query/physicalplan/sampler_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
package physicalplan

import (
"context"
"testing"

"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/array"
"github.com/apache/arrow/go/v16/arrow/memory"
"github.com/stretchr/testify/require"
)

type TestPlan struct {
finish func() error
callback func(ctx context.Context, r arrow.Record) error
}

func (t *TestPlan) Callback(ctx context.Context, r arrow.Record) error {
if t.callback != nil {
return t.callback(ctx, r)
}
return nil
}
func (t *TestPlan) Finish(ctx context.Context) error {

Check failure on line 24 in query/physicalplan/sampler_test.go

View workflow job for this annotation

GitHub Actions / lint

unused-parameter: parameter 'ctx' seems to be unused, consider removing or renaming it as _ (revive)
if t.finish != nil {
return t.finish()
}
return nil
}
func (t *TestPlan) SetNext(next PhysicalPlan) {}

Check failure on line 30 in query/physicalplan/sampler_test.go

View workflow job for this annotation

GitHub Actions / lint

unused-parameter: parameter 'next' seems to be unused, consider removing or renaming it as _ (revive)
func (t *TestPlan) Draw() *Diagram { return nil }
func (t *TestPlan) Close() {}

func Test_Sampler(t *testing.T) {
ctx := context.Background()
tests := map[string]struct {
reservoirSize int64
numRows int
recordSize int
}{
"basic single row records": {
reservoirSize: 10,
numRows: 100,
recordSize: 1,
},
"basic multi row records": {
reservoirSize: 10,
numRows: 100,
recordSize: 10,
},
}

for name, test := range tests {
t.Run(name, func(t *testing.T) {
// Create a new sampler
s := NewReservoirSampler(test.reservoirSize)
called := false
total := int64(0)
s.SetNext(&TestPlan{
callback: func(ctx context.Context, r arrow.Record) error {
called = true
total += r.NumRows()
return nil
},
})

schema := arrow.NewSchema([]arrow.Field{
{Name: "a", Type: arrow.PrimitiveTypes.Int64},
}, nil)
bldr := array.NewRecordBuilder(memory.NewGoAllocator(), schema)

for i := 0; i < test.numRows/test.recordSize; i++ {
for j := 0; j < test.recordSize; j++ {
bldr.Field(0).(*array.Int64Builder).Append(int64((i * test.recordSize) + j))
}
r := bldr.NewRecord()
t.Cleanup(r.Release)
s.Callback(ctx, r)

Check failure on line 78 in query/physicalplan/sampler_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `s.Callback` is not checked (errcheck)
}

s.Finish(ctx)

Check failure on line 81 in query/physicalplan/sampler_test.go

View workflow job for this annotation

GitHub Actions / lint

Error return value of `s.Finish` is not checked (errcheck)
require.True(t, called)
require.Equal(t, test.reservoirSize, total)
})
}
}

0 comments on commit ab8b489

Please sign in to comment.