diff --git a/pqarrow/arrowutils/schema.go b/pqarrow/arrowutils/schema.go
new file mode 100644
index 000000000..f9e9a4df7
--- /dev/null
+++ b/pqarrow/arrowutils/schema.go
@@ -0,0 +1,114 @@
+package arrowutils
+
+import (
+	"fmt"
+	"sort"
+
+	"github.com/apache/arrow/go/v16/arrow"
+	"github.com/apache/arrow/go/v16/arrow/array"
+)
+
+// EnsureSameSchema ensures that all the records have the same schema. In cases
+// where the schema is not equal, virtual null columns are inserted in the
+// records with the missing column. When we have static schemas in the execution
+// engine, steps like these should be unnecessary.
+//
+// On success, any record that had to be rebuilt is released and replaced by the
+// rebuilt copy in the returned slice; the caller owns the returned records. On
+// error, the input records are left untouched and nothing is leaked.
+func EnsureSameSchema(records []arrow.Record) ([]arrow.Record, error) {
+	if len(records) < 2 {
+		return records, nil
+	}
+
+	// Fast path: nothing to do when every record already shares the first
+	// record's schema.
+	firstSchema := records[0].Schema()
+	sameSchema := true
+	for _, r := range records[1:] {
+		if !r.Schema().Equal(firstSchema) {
+			sameSchema = false
+			break
+		}
+	}
+	if sameSchema {
+		return records, nil
+	}
+
+	// Collect the union of all fields across records, keyed by name. The
+	// first occurrence of a field name wins.
+	columns := make(map[string]arrow.Field)
+	for _, r := range records {
+		for j := 0; j < r.Schema().NumFields(); j++ {
+			field := r.Schema().Field(j)
+			if _, ok := columns[field.Name]; !ok {
+				columns[field.Name] = field
+			}
+		}
+	}
+
+	// Sort field names so the merged schema is deterministic.
+	columnNames := make([]string, 0, len(columns))
+	for name := range columns {
+		columnNames = append(columnNames, name)
+	}
+	sort.Strings(columnNames)
+
+	mergedFields := make([]arrow.Field, 0, len(columnNames))
+	for _, name := range columnNames {
+		mergedFields = append(mergedFields, columns[name])
+	}
+	mergedSchema := arrow.NewSchema(mergedFields, nil)
+
+	mergedRecords := make([]arrow.Record, len(records))
+	// replacedRecords tracks originals to release on success; newRecords
+	// tracks rebuilt records to release if we bail out with an error.
+	var replacedRecords, newRecords []arrow.Record
+	releaseOnError := func() {
+		for _, r := range newRecords {
+			r.Release()
+		}
+	}
+
+	for i := range records {
+		recordSchema := records[i].Schema()
+		if mergedSchema.Equal(recordSchema) {
+			mergedRecords[i] = records[i]
+			continue
+		}
+
+		mergedColumns := make([]arrow.Array, 0, len(mergedFields))
+		recordNumRows := records[i].NumRows()
+		for j := 0; j < mergedSchema.NumFields(); j++ {
+			field := mergedSchema.Field(j)
+			if otherFields := recordSchema.FieldIndices(field.Name); otherFields != nil {
+				if len(otherFields) > 1 {
+					fieldsFound, _ := recordSchema.FieldsByName(field.Name)
+					// Avoid leaking records rebuilt in earlier iterations.
+					releaseOnError()
+					return nil, fmt.Errorf(
+						"found multiple fields %v for name %s",
+						fieldsFound,
+						field.Name,
+					)
+				}
+				mergedColumns = append(mergedColumns, records[i].Column(otherFields[0]))
+			} else {
+				// Note that this VirtualNullArray will be read from, but the
+				// merged output will be a physical null array, so there is no
+				// virtual->physical conversion necessary before we return data.
+				mergedColumns = append(mergedColumns, MakeVirtualNullArray(field.Type, int(recordNumRows)))
+			}
+		}
+
+		replacedRecords = append(replacedRecords, records[i])
+		mergedRecords[i] = array.NewRecord(mergedSchema, mergedColumns, recordNumRows)
+		newRecords = append(newRecords, mergedRecords[i])
+	}
+
+	for _, r := range replacedRecords {
+		r.Release()
+	}
+
+	return mergedRecords, nil
+}
diff --git a/pqarrow/arrowutils/schema_test.go b/pqarrow/arrowutils/schema_test.go
new file mode 100644
index 000000000..647a7e58c
--- /dev/null
+++ b/pqarrow/arrowutils/schema_test.go
@@ -0,0 +1,91 @@
+package arrowutils_test
+
+import (
+	"testing"
+
+	"github.com/apache/arrow/go/v16/arrow"
+	"github.com/apache/arrow/go/v16/arrow/memory"
+	"github.com/stretchr/testify/require"
+
+	"github.com/polarsignals/frostdb/internal/records"
+	"github.com/polarsignals/frostdb/pqarrow/arrowutils"
+)
+
+func TestEnsureSameSchema(t *testing.T) {
+	type struct1 struct {
+		Field1 int64 `frostdb:",asc(0)"`
+		Field2 int64 `frostdb:",asc(1)"`
+	}
+	type struct2 struct {
+		Field1 int64 `frostdb:",asc(0)"`
+		Field3 int64 `frostdb:",asc(1)"`
+	}
+	type struct3 struct {
+		Field1 int64 `frostdb:",asc(0)"`
+		Field2 int64 `frostdb:",asc(1)"`
+		Field3 int64 `frostdb:",asc(1)"`
+	}
+
+	mem := memory.NewCheckedAllocator(memory.NewGoAllocator())
+	defer mem.AssertSize(t, 0)
+
+	build1 := records.NewBuild[struct1](mem)
+	defer build1.Release()
+	err := build1.Append([]struct1{
+		{Field1: 1, Field2: 2},
+		{Field1: 1, Field2: 3},
+	}...)
+	require.NoError(t, err)
+
+	build2 := records.NewBuild[struct2](mem)
+	defer build2.Release()
+	err = build2.Append([]struct2{
+		{Field1: 1, Field3: 2},
+		{Field1: 1, Field3: 3},
+	}...)
+	require.NoError(t, err)
+
+	build3 := records.NewBuild[struct3](mem)
+	defer build3.Release()
+	err = build3.Append([]struct3{
+		{Field1: 1, Field2: 1, Field3: 1},
+		{Field1: 2, Field2: 2, Field3: 2},
+	}...)
+	require.NoError(t, err)
+
+	record1 := build1.NewRecord()
+	record2 := build2.NewRecord()
+	record3 := build3.NewRecord()
+
+	recs := []arrow.Record{record1, record2, record3}
+	defer func() {
+		for _, r := range recs {
+			r.Release()
+		}
+	}()
+
+	recs, err = arrowutils.EnsureSameSchema(recs)
+	require.NoError(t, err)
+
+	expected := []struct3{
+		// record1
+		{Field1: 1, Field2: 2, Field3: 0},
+		{Field1: 1, Field2: 3, Field3: 0},
+		// record2
+		{Field1: 1, Field2: 0, Field3: 2},
+		{Field1: 1, Field2: 0, Field3: 3},
+		// record3
+		{Field1: 1, Field2: 1, Field3: 1},
+		{Field1: 2, Field2: 2, Field3: 2},
+	}
+
+	reader := records.NewReader[struct3](recs...)
+	rows := reader.NumRows()
+	require.Equal(t, int64(len(expected)), rows)
+
+	actual := make([]struct3, rows)
+	for i := 0; i < int(rows); i++ {
+		actual[i] = reader.Value(i)
+	}
+	require.Equal(t, expected, actual)
+}