-
Notifications
You must be signed in to change notification settings - Fork 66
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
8 changed files
with
303 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
package physicalplan | ||
|
||
import ( | ||
"context" | ||
"fmt" | ||
"math" | ||
"math/rand" | ||
|
||
"github.com/apache/arrow/go/v16/arrow" | ||
) | ||
|
||
type ReservoirSampler struct { | ||
next PhysicalPlan | ||
|
||
// size is the max number of rows in the reservoir | ||
size int64 | ||
|
||
// reservoir is the set of records that have been sampled. They may vary in schema due to dynamic columns. | ||
reservoir []arrow.Record | ||
|
||
// currentSize is the number of rows in the reservoir in all records. | ||
currentSize int64 | ||
|
||
w float64 // w is the probability of keeping a record | ||
n int64 // n is the number of rows that have been sampled thus far | ||
} | ||
|
||
// NewReservoirSampler will create a new ReservoirSampler operator that will sample up to size rows of all records seen by Callback. | ||
func NewReservoirSampler(size int64) *ReservoirSampler { | ||
return &ReservoirSampler{ | ||
size: size, | ||
reservoir: []arrow.Record{}, | ||
w: math.Exp(math.Log(rand.Float64()) / float64(size)), | ||
} | ||
} | ||
|
||
func (s *ReservoirSampler) SetNext(p PhysicalPlan) { | ||
s.next = p | ||
} | ||
|
||
func (s *ReservoirSampler) Draw() *Diagram { | ||
var child *Diagram | ||
if s.next != nil { | ||
child = s.next.Draw() | ||
} | ||
details := fmt.Sprintf("Reservoir Sampler (%v)", s.size) | ||
return &Diagram{Details: details, Child: child} | ||
} | ||
|
||
func (s *ReservoirSampler) Close() { | ||
for _, r := range s.reservoir { | ||
r.Release() | ||
} | ||
s.next.Close() | ||
} | ||
|
||
// Callback collects all the records to sample. | ||
func (s *ReservoirSampler) Callback(ctx context.Context, r arrow.Record) error { | ||
r = s.fill(r) | ||
if r == nil { | ||
return nil | ||
} | ||
|
||
// Sample the record | ||
s.sample(r) | ||
return nil | ||
} | ||
|
||
// fill will fill the reservoir with the first size records | ||
func (s *ReservoirSampler) fill(r arrow.Record) arrow.Record { | ||
if s.currentSize == s.size { | ||
return r | ||
} | ||
|
||
if s.currentSize+r.NumRows() <= s.size { // The record fits in the reservoir | ||
for i := int64(0); i < r.NumRows(); i++ { // For simplicity of implementation the reservoir is by row; This is probably not optimal | ||
s.reservoir = append(s.reservoir, r.NewSlice(i, i+1)) | ||
} | ||
s.currentSize += r.NumRows() | ||
s.n += r.NumRows() | ||
return nil | ||
} | ||
|
||
// The record partially fits in the reservoir | ||
for i := int64(0); i < s.size-s.currentSize; i++ { | ||
s.reservoir = append(s.reservoir, r.NewSlice(i, i+1)) | ||
} | ||
r = r.NewSlice(s.size-s.currentSize, r.NumRows()) | ||
s.currentSize = s.size | ||
s.n = s.size | ||
return r | ||
} | ||
|
||
// sample implements the reservoir sampling algorithm found https://en.wikipedia.org/wiki/Reservoir_sampling. | ||
func (s *ReservoirSampler) sample(r arrow.Record) { | ||
n := s.n + r.NumRows() | ||
for i := float64(s.n); i < float64(n); { | ||
i += math.Floor(math.Log(rand.Float64())/math.Log(1-s.w)) + 1 | ||
if i <= float64(n) { | ||
// replace a random item of the reservoir with row i | ||
s.reservoir[rand.Intn(int(s.size))] = r.NewSlice(int64(i)-s.n-1, int64(i)-s.n) | ||
s.w = s.w * math.Exp(math.Log(rand.Float64())/float64(s.size)) | ||
} | ||
} | ||
s.n = n | ||
} | ||
|
||
// Finish sends all the records in the reservoir to the next operator. | ||
func (s *ReservoirSampler) Finish(ctx context.Context) error { | ||
// Send all the records in the reservoir to the next operator | ||
for _, r := range s.reservoir { | ||
if err := s.next.Callback(ctx, r); err != nil { | ||
return err | ||
} | ||
} | ||
|
||
return s.next.Finish(ctx) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,86 @@ | ||
package physicalplan | ||
|
||
import ( | ||
"context" | ||
"testing" | ||
|
||
"github.com/apache/arrow/go/v16/arrow" | ||
"github.com/apache/arrow/go/v16/arrow/array" | ||
"github.com/apache/arrow/go/v16/arrow/memory" | ||
"github.com/stretchr/testify/require" | ||
) | ||
|
||
type TestPlan struct { | ||
finish func() error | ||
callback func(ctx context.Context, r arrow.Record) error | ||
} | ||
|
||
func (t *TestPlan) Callback(ctx context.Context, r arrow.Record) error { | ||
if t.callback != nil { | ||
return t.callback(ctx, r) | ||
} | ||
return nil | ||
} | ||
func (t *TestPlan) Finish(ctx context.Context) error { | ||
if t.finish != nil { | ||
return t.finish() | ||
} | ||
return nil | ||
} | ||
func (t *TestPlan) SetNext(next PhysicalPlan) {} | ||
func (t *TestPlan) Draw() *Diagram { return nil } | ||
func (t *TestPlan) Close() {} | ||
|
||
func Test_Sampler(t *testing.T) { | ||
ctx := context.Background() | ||
tests := map[string]struct { | ||
reservoirSize int64 | ||
numRows int | ||
recordSize int | ||
}{ | ||
"basic single row records": { | ||
reservoirSize: 10, | ||
numRows: 100, | ||
recordSize: 1, | ||
}, | ||
"basic multi row records": { | ||
reservoirSize: 10, | ||
numRows: 100, | ||
recordSize: 10, | ||
}, | ||
} | ||
|
||
for name, test := range tests { | ||
t.Run(name, func(t *testing.T) { | ||
// Create a new sampler | ||
s := NewReservoirSampler(test.reservoirSize) | ||
called := false | ||
total := int64(0) | ||
s.SetNext(&TestPlan{ | ||
callback: func(ctx context.Context, r arrow.Record) error { | ||
called = true | ||
total += r.NumRows() | ||
return nil | ||
}, | ||
}) | ||
|
||
schema := arrow.NewSchema([]arrow.Field{ | ||
{Name: "a", Type: arrow.PrimitiveTypes.Int64}, | ||
}, nil) | ||
bldr := array.NewRecordBuilder(memory.NewGoAllocator(), schema) | ||
|
||
for i := 0; i < test.numRows/test.recordSize; i++ { | ||
for j := 0; j < test.recordSize; j++ { | ||
bldr.Field(0).(*array.Int64Builder).Append(int64((i * test.recordSize) + j)) | ||
} | ||
r := bldr.NewRecord() | ||
t.Cleanup(r.Release) | ||
s.Callback(ctx, r) | ||
} | ||
|
||
s.Finish(ctx) | ||
require.True(t, called) | ||
require.Equal(t, test.reservoirSize, total) | ||
}) | ||
} | ||
} |