Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pqarrow/arrowutils: Add EnsureSameSchema for records #806

Merged
merged 1 commit into from
Apr 25, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions pqarrow/arrowutils/schema.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package arrowutils

import (
"fmt"
"sort"

"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/array"
)

// EnsureSameSchema ensures that all the records have the same schema. In cases
// where the schema is not equal, virtual null columns are inserted in the
thorfour marked this conversation as resolved.
Show resolved Hide resolved
// records with the missing column. When we have static schemas in the execution
// engine, steps like these should be unnecessary.
func EnsureSameSchema(records []arrow.Record) ([]arrow.Record, error) {
if len(records) < 2 {
return records, nil
}

lastSchema := records[0].Schema()
needSchemaRecalculation := false
for i := range records {
if !records[i].Schema().Equal(lastSchema) {
needSchemaRecalculation = true
break
}
}
if !needSchemaRecalculation {
return records, nil
}

columns := make(map[string]arrow.Field)
for _, r := range records {
for j := 0; j < r.Schema().NumFields(); j++ {
field := r.Schema().Field(j)
if _, ok := columns[field.Name]; !ok {
columns[field.Name] = field
}
}
}

columnNames := make([]string, 0, len(columns))
for name := range columns {
columnNames = append(columnNames, name)
}
sort.Strings(columnNames)

mergedFields := make([]arrow.Field, 0, len(columnNames))
for _, name := range columnNames {
mergedFields = append(mergedFields, columns[name])
}
mergedSchema := arrow.NewSchema(mergedFields, nil)

mergedRecords := make([]arrow.Record, len(records))
var replacedRecords []arrow.Record

for i := range records {
recordSchema := records[i].Schema()
if mergedSchema.Equal(recordSchema) {
mergedRecords[i] = records[i]
continue
}

mergedColumns := make([]arrow.Array, 0, len(mergedFields))
recordNumRows := records[i].NumRows()
for j := 0; j < mergedSchema.NumFields(); j++ {
field := mergedSchema.Field(j)
if otherFields := recordSchema.FieldIndices(field.Name); otherFields != nil {
if len(otherFields) > 1 {
fieldsFound, _ := recordSchema.FieldsByName(field.Name)
return nil, fmt.Errorf(
"found multiple fields %v for name %s",
fieldsFound,
field.Name,
)
}
mergedColumns = append(mergedColumns, records[i].Column(otherFields[0]))
} else {
// Note that this VirtualNullArray will be read from, but the
// merged output will be a physical null array, so there is no
// virtual->physical conversion necessary before we return data.
mergedColumns = append(mergedColumns, MakeVirtualNullArray(field.Type, int(recordNumRows)))
}
}

replacedRecords = append(replacedRecords, records[i])
mergedRecords[i] = array.NewRecord(mergedSchema, mergedColumns, recordNumRows)
}

for _, r := range replacedRecords {
r.Release()
}

return mergedRecords, nil
}
91 changes: 91 additions & 0 deletions pqarrow/arrowutils/schema_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package arrowutils_test

import (
"testing"

"github.com/apache/arrow/go/v16/arrow"
"github.com/apache/arrow/go/v16/arrow/memory"
"github.com/stretchr/testify/require"

"github.com/polarsignals/frostdb/internal/records"
"github.com/polarsignals/frostdb/pqarrow/arrowutils"
)

func TestEnsureSameSchema(t *testing.T) {
	// Three row types with partially overlapping fields: struct1 and struct2
	// each miss a column that the other has, struct3 is the union of both.
	type struct1 struct {
		Field1 int64 `frostdb:",asc(0)"`
		Field2 int64 `frostdb:",asc(1)"`
	}
	type struct2 struct {
		Field1 int64 `frostdb:",asc(0)"`
		Field3 int64 `frostdb:",asc(1)"`
	}
	type struct3 struct {
		Field1 int64 `frostdb:",asc(0)"`
		Field2 int64 `frostdb:",asc(1)"`
		Field3 int64 `frostdb:",asc(1)"`
	}

	alloc := memory.NewCheckedAllocator(memory.NewGoAllocator())
	defer alloc.AssertSize(t, 0)

	b1 := records.NewBuild[struct1](alloc)
	defer b1.Release()
	require.NoError(t, b1.Append(
		struct1{Field1: 1, Field2: 2},
		struct1{Field1: 1, Field2: 3},
	))

	b2 := records.NewBuild[struct2](alloc)
	defer b2.Release()
	require.NoError(t, b2.Append(
		struct2{Field1: 1, Field3: 2},
		struct2{Field1: 1, Field3: 3},
	))

	b3 := records.NewBuild[struct3](alloc)
	defer b3.Release()
	require.NoError(t, b3.Append(
		struct3{Field1: 1, Field2: 1, Field3: 1},
		struct3{Field1: 2, Field2: 2, Field3: 2},
	))

	recs := []arrow.Record{b1.NewRecord(), b2.NewRecord(), b3.NewRecord()}
	// The closure sees the reassigned slice below, so the merged records are
	// what get released.
	defer func() {
		for _, r := range recs {
			r.Release()
		}
	}()

	var err error
	recs, err = arrowutils.EnsureSameSchema(recs)
	require.NoError(t, err)

	want := []struct3{
		// record1: missing Field3 is null-filled (reads as the zero value).
		{Field1: 1, Field2: 2, Field3: 0},
		{Field1: 1, Field2: 3, Field3: 0},
		// record2: missing Field2 is null-filled (reads as the zero value).
		{Field1: 1, Field2: 0, Field3: 2},
		{Field1: 1, Field2: 0, Field3: 3},
		// record3: all columns already present.
		{Field1: 1, Field2: 1, Field3: 1},
		{Field1: 2, Field2: 2, Field3: 2},
	}

	reader := records.NewReader[struct3](recs...)
	total := reader.NumRows()
	require.Equal(t, int64(len(want)), total)

	got := make([]struct3, total)
	for i := range got {
		got[i] = reader.Value(i)
	}
	require.Equal(t, want, got)
}
Loading