Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reservoir sampler #884

Merged
merged 19 commits into from
Jun 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 44 additions & 0 deletions db_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -3279,3 +3279,47 @@ func (a *AssertBucket) Upload(ctx context.Context, path string, r io.Reader) err
}
return a.Bucket.Upload(ctx, path, r)
}

// Test_DB_Sample inserts 500 records of 10 generated samples each and checks
// that a scan with Sample(13) delivers exactly 13 rows in total to the callback.
func Test_DB_Sample(t *testing.T) {
	t.Parallel()

	ctx := context.Background()
	config := NewTableConfig(
		dynparquet.SampleDefinition(),
	)
	logger := newTestLogger(t)

	c, err := New(WithLogger(logger))
	// Check the error before registering the cleanup: if New fails, the
	// cleanup would otherwise call Close on an unusable store and panic,
	// hiding the original failure.
	require.NoError(t, err)
	t.Cleanup(func() {
		require.NoError(t, c.Close())
	})

	db, err := c.DB(ctx, "test")
	require.NoError(t, err)
	table, err := db.Table("test", config)
	require.NoError(t, err)

	for i := 0; i < 500; i++ {
		samples := dynparquet.GenerateTestSamples(10)
		r, err := samples.ToRecord()
		require.NoError(t, err)
		_, err = table.InsertRecord(ctx, r)
		require.NoError(t, err)
	}

	pool := memory.NewCheckedAllocator(memory.DefaultAllocator)
	defer pool.AssertSize(t, 0)

	// The callback is guarded by a mutex while it accumulates the row count.
	lock := &sync.Mutex{}
	rows := int64(0)
	sampleSize := int64(13)
	engine := query.NewEngine(pool, db.TableProvider())
	err = engine.ScanTable("test").
		Sample(sampleSize). // Sample 13 rows
		Execute(ctx, func(_ context.Context, r arrow.Record) error {
			lock.Lock()
			defer lock.Unlock()
			rows += r.NumRows()
			return nil
		})
	require.NoError(t, err)
	require.Equal(t, sampleSize, rows)
}
12 changes: 12 additions & 0 deletions query/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ type Builder interface {
Limit(expr logicalplan.Expr) Builder
Execute(ctx context.Context, callback func(ctx context.Context, r arrow.Record) error) error
Explain(ctx context.Context) (string, error)
Sample(size int64) Builder
}

type LocalEngine struct {
Expand Down Expand Up @@ -143,6 +144,17 @@ func (b LocalQueryBuilder) Limit(
}
}

// Sample returns a new builder whose logical plan wraps the current plan in a
// sampling step. size is handed to the plan builder as a literal expression;
// the pool, tracer and execution options are carried over unchanged.
func (b LocalQueryBuilder) Sample(size int64) Builder {
	sampledPlan := b.planBuilder.Sample(logicalplan.Literal(size))

	return LocalQueryBuilder{
		pool:        b.pool,
		tracer:      b.tracer,
		planBuilder: sampledPlan,
		execOpts:    b.execOpts,
	}
}

func (b LocalQueryBuilder) Execute(ctx context.Context, callback func(ctx context.Context, r arrow.Record) error) error {
ctx, span := b.tracer.Start(ctx, "LocalQueryBuilder/Execute")
defer span.End()
Expand Down
16 changes: 16 additions & 0 deletions query/logicalplan/builder.go
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,22 @@ func resolveAggregation(plan *LogicalPlan, agg *AggregationFunction) ([]*Aggrega
}
}

func (b Builder) Sample(expr Expr) Builder {
thorfour marked this conversation as resolved.
Show resolved Hide resolved
if expr == nil {
return b
}

return Builder{
err: b.err,
plan: &LogicalPlan{
Input: b.plan,
Sample: &Sample{
Expr: expr,
},
},
}
}

func (b Builder) Build() (*LogicalPlan, error) {
if b.err != nil {
return nil, b.err
Expand Down
16 changes: 16 additions & 0 deletions query/logicalplan/logicalplan.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ type LogicalPlan struct {
Projection *Projection
Aggregation *Aggregation
Limit *Limit
Sample *Sample
}

// Callback is a function that is called throughout a chain of operators
Expand Down Expand Up @@ -159,6 +160,13 @@ func (plan *LogicalPlan) DataTypeForExpr(expr Expr) (arrow.DataType, error) {
return nil, fmt.Errorf("data type for expr %v within Distinct: %w", expr, err)
}

return t, nil
case plan.Sample != nil:
t, err := expr.DataType(plan.Input)
if err != nil {
return nil, fmt.Errorf("data type for expr %v within Sample: %w", expr, err)
}

return t, nil
default:
return nil, fmt.Errorf("unknown logical plan")
Expand Down Expand Up @@ -414,3 +422,11 @@ type Limit struct {
// String renders the limit node as "Limit Expr: <expr>", formatting the
// expression with fmt's default verb.
func (l *Limit) String() string {
	return fmt.Sprintf("Limit Expr: %v", l.Expr)
}

// Sample is the logical plan node for sampling; Expr is the expression the
// node was built with.
type Sample struct {
	Expr Expr
}

// String renders the sample node as "Sample Expr: <expr>", formatting the
// expression with fmt's default verb.
func (s *Sample) String() string {
	return fmt.Sprintf("Sample Expr: %v", s.Expr)
}
5 changes: 4 additions & 1 deletion query/logicalplan/validate.go
Original file line number Diff line number Diff line change
Expand Up @@ -122,10 +122,13 @@ func ValidateSingleFieldSet(plan *LogicalPlan) *PlanValidationError {
if plan.Limit != nil {
fieldsSet = append(fieldsSet, 6)
}
if plan.Sample != nil {
fieldsSet = append(fieldsSet, 7)
}

if len(fieldsSet) != 1 {
fieldsFound := make([]string, 0)
fields := []string{"SchemaScan", "TableScan", "Filter", "Distinct", "Projection", "Aggregation"}
fields := []string{"SchemaScan", "TableScan", "Filter", "Distinct", "Projection", "Aggregation", "Limit", "Sample"}
for _, i := range fieldsSet {
fieldsFound = append(fieldsFound, fields[i])
}
Expand Down
14 changes: 14 additions & 0 deletions query/physicalplan/physicalplan.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (

"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/memory"
"github.com/apache/arrow/go/v16/arrow/scalar"
"go.opentelemetry.io/otel/attribute"
"go.opentelemetry.io/otel/trace"
"golang.org/x/sync/errgroup"
Expand Down Expand Up @@ -471,6 +472,19 @@ func Build(
if ordered {
oInfo.nodeMaintainsOrdering()
}
case plan.Sample != nil:
v := plan.Sample.Expr.(*logicalplan.LiteralExpr).Value.(*scalar.Int64).Value
perSampler := v / int64(len(prev))
r := v % int64(len(prev))
for i := range prev {
adjust := int64(0)
if i < int(r) {
adjust = 1
}
s := NewReservoirSampler(perSampler + adjust)
prev[i].SetNext(s)
prev[i] = s
}
default:
panic("Unsupported plan")
}
Expand Down
153 changes: 153 additions & 0 deletions query/physicalplan/sampler.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package physicalplan

import (
"context"
"fmt"
"math"
"math/rand"

"github.com/apache/arrow/go/v16/arrow"
)

// ReservoirSampler is a physical plan operator that keeps a bounded reservoir
// of rows from the records passed to Callback and forwards only those rows to
// the next operator when Finish is called.
type ReservoirSampler struct {
// next is the downstream operator that receives the sampled rows in Finish.
next PhysicalPlan

// size is the max number of rows in the reservoir
size int64

// reservoir is the set of records that have been sampled. They may vary in schema due to dynamic columns.
// Entries stand for whole records (i == -1) while the reservoir is filling and for single rows once sliceReservoir has run.
reservoir []sample

// The fields below are the state of the skip-based sampling loop in sample.
w float64 // w is the weight from which skip lengths are drawn; it is scaled down after every replacement
n int64 // n is the number of rows seen so far across all records
i float64 // i is the stream-wide index of the next row to move into the reservoir; 0 means it has not been initialised yet
}

// NewReservoirSampler builds a ReservoirSampler operator that samples up to
// size rows out of all the records passed to its Callback. The initial weight
// is drawn at random here, so two samplers of the same size start from
// different states.
func NewReservoirSampler(size int64) *ReservoirSampler {
	initialWeight := math.Exp(math.Log(rand.Float64()) / float64(size))

	sampler := &ReservoirSampler{
		size: size,
		w:    initialWeight,
	}
	return sampler
}

// SetNext sets the downstream operator; it is the one that receives the
// sampled rows in Finish and is closed by Close.
func (s *ReservoirSampler) SetNext(p PhysicalPlan) {
s.next = p
}

// Draw returns the diagram node for this operator, labelled with the
// reservoir size. The downstream operator's diagram is attached as the child
// when a next operator has been set.
func (s *ReservoirSampler) Draw() *Diagram {
	diagram := &Diagram{
		Details: fmt.Sprintf("Reservoir Sampler (%v)", s.size),
	}
	if s.next != nil {
		diagram.Child = s.next.Draw()
	}
	return diagram
}

// Close drops the reference held by every reservoir entry and then closes the
// next operator.
func (s *ReservoirSampler) Close() {
	for idx := range s.reservoir {
		s.reservoir[idx].r.Release()
	}
	s.next.Close()
}

// Callback collects all the records to sample. Rows are stored verbatim until
// the reservoir holds size rows; every row after that goes through sample,
// which may swap it into the reservoir. Nothing is forwarded to the next
// operator from here; that happens in Finish.
func (s *ReservoirSampler) Callback(_ context.Context, r arrow.Record) error {
	// A sampler without capacity keeps nothing. Return before sample runs:
	// it calls rand.Intn(int(s.size)), which panics for a non-positive argument.
	if s.size <= 0 {
		return nil
	}

	rest := s.fill(r)
	if rest == nil { // The record fit in the reservoir
		return nil
	}
	// fill hands back a reference owned by this call. The reservoir takes its
	// own reference for every row it keeps, so this one is always dropped.
	defer rest.Release()

	// An empty record has no rows to sample. Returning here also prevents
	// s.n from still being equal to s.size on the next call, which would run
	// sliceReservoir a second time on an already sliced reservoir.
	if rest.NumRows() == 0 {
		return nil
	}

	if s.n == s.size { // The reservoir just filled up. Slice the reservoir to the correct size so we can easily perform row replacement
		s.sliceReservoir()
	}

	// Sample the record
	s.sample(rest)
	return nil
}

// fill will fill the reservoir with the first size rows.
//
// It returns nil when r was absorbed completely. Otherwise it returns the rows
// that still have to be sampled: all of r when the reservoir was already full,
// or the tail of r when only its head fit. A non-nil result always carries a
// reference owned by the caller, who must Release it.
func (s *ReservoirSampler) fill(r arrow.Record) arrow.Record {
	if s.n >= s.size {
		// Retain so both non-nil return paths follow the same ownership rule.
		r.Retain()
		return r
	}

	free := s.size - s.n
	if r.NumRows() <= free { // The record fits in the reservoir
		r.Retain()
		s.reservoir = append(s.reservoir, sample{r: r, i: -1}) // -1 means the record is not sampled; use the entire record
		s.n += r.NumRows()
		return nil
	}

	// The record partially fits in the reservoir: the head fills the remaining
	// slots and the tail is handed back for sampling. Both slices are new
	// references; the head is owned by the reservoir, the tail by the caller.
	s.reservoir = append(s.reservoir, sample{r: r.NewSlice(0, free), i: -1})
	s.n = s.size
	return r.NewSlice(free, r.NumRows())
}

// sliceReservoir rewrites the reservoir so that every entry stands for exactly
// one row: each stored record is expanded into one entry per row index. Every
// new entry takes its own reference on the record, after which the reference
// held by the original whole-record entry is dropped.
func (s *ReservoirSampler) sliceReservoir() {
	perRow := make([]sample, 0, s.size)
	for idx := range s.reservoir {
		rec := s.reservoir[idx].r
		rowCount := rec.NumRows()
		for row := int64(0); row < rowCount; row++ {
			rec.Retain()
			perRow = append(perRow, sample{r: rec, i: row})
		}
		rec.Release()
	}
	s.reservoir = perRow
}

// sample implements the reservoir sampling algorithm found https://en.wikipedia.org/wiki/Reservoir_sampling.
// It is the skip-based variant: instead of testing every row, it draws how many rows to skip
// before the next replacement. s.i is a cursor over the whole stream of rows, so it carries over
// from one record to the next; s.n is the stream index of the first row of r.
// The reservoir must already hold one entry per row, and s.size must be positive because
// rand.Intn panics for a non-positive argument.
func (s *ReservoirSampler) sample(r arrow.Record) {
// n is the total number of rows seen once r has been consumed.
n := s.n + r.NumRows()
if s.i == 0 {
// First call: start the cursor on the last row seen before r.
s.i = float64(s.n) - 1
} else if s.i < float64(n) {
// The cursor was advanced past the end of an earlier record and now lands inside r:
// perform the replacement that was left pending. int64(s.i) - s.n is the row offset within r.
s.replace(rand.Intn(int(s.size)), sample{r: r, i: int64(s.i) - s.n})
s.w = s.w * math.Exp(math.Log(rand.Float64())/float64(s.size))
}

for s.i < float64(n) {
// Jump ahead by a random skip length derived from the current weight.
s.i += math.Floor(math.Log(rand.Float64())/math.Log(1-s.w)) + 1
if s.i < float64(n) {
// replace a random item of the reservoir with row i
s.replace(rand.Intn(int(s.size)), sample{r: r, i: int64(s.i) - s.n})
s.w = s.w * math.Exp(math.Log(rand.Float64())/float64(s.size))
}
// Otherwise the cursor points beyond r and stays pending for a later record.
}
s.n = n
}

// Finish hands every reservoir entry to the next operator and then finishes
// that operator. Whole-record entries (i == -1) are passed on as they are;
// single-row entries are passed on as a one-row slice of their record. The
// slices are released by deferred calls, so they stay alive until Finish
// itself returns. The first error from the next operator's Callback aborts
// the loop and is returned.
func (s *ReservoirSampler) Finish(ctx context.Context) error {
	for idx := range s.reservoir {
		entry := s.reservoir[idx]

		out := entry.r
		if entry.i != -1 {
			out = entry.r.NewSlice(entry.i, entry.i+1)
			defer out.Release()
		}

		if err := s.next.Callback(ctx, out); err != nil {
			return err
		}
	}

	return s.next.Finish(ctx)
}

// replace overwrites reservoir slot i with newRow. The reservoir takes a
// reference on newRow's record and drops the one held for the evicted entry.
//
// The new record is retained before the old one is released: both entries may
// point at the same record, and releasing first could drop its reference count
// to zero if the reservoir slot held the last reference.
func (s *ReservoirSampler) replace(i int, newRow sample) {
	newRow.r.Retain()
	s.reservoir[i].r.Release()
	s.reservoir[i] = newRow
}

// sample is one reservoir entry: a record together with the row picked from it.
type sample struct {
r arrow.Record // r is the record the entry refers to; the reservoir holds a reference on it until the entry is evicted or the sampler is closed
i int64 // i is the row index within r, or -1 when the entry stands for the whole record
}
Loading
Loading