From fe804c6f9c9799c50520ff5828eef49db3fad00e Mon Sep 17 00:00:00 2001 From: Daniil Aksenov Date: Fri, 27 Dec 2024 12:13:27 +0300 Subject: [PATCH] support custom types in query result scanning --- CHANGELOG.md | 2 ++ internal/value/cast.go | 8 ++++++++ internal/value/cast_test.go | 15 +++++++++++++++ tests/integration/query_range_test.go | 19 +++++++++++++++++-- 4 files changed, 42 insertions(+), 2 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6772bb342..91df24ae8 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,3 +1,5 @@ +* Added support of custom types to row.ScanStruc + ## v3.95.5 * Fixed goroutine leak on failed execute call in query client diff --git a/internal/value/cast.go b/internal/value/cast.go index 7106b62fd..ddde56ace 100644 --- a/internal/value/cast.go +++ b/internal/value/cast.go @@ -10,5 +10,13 @@ func CastTo(v Value, dst interface{}) error { return nil } + if scanner, has := dst.(Scanner); has { + return scanner.UnmarshalYDBValue(v) + } + return v.castTo(dst) } + +type Scanner interface { + UnmarshalYDBValue(value Value) error +} diff --git a/internal/value/cast_test.go b/internal/value/cast_test.go index 863587f59..8886bb4b5 100644 --- a/internal/value/cast_test.go +++ b/internal/value/cast_test.go @@ -32,6 +32,14 @@ func loadLocation(t *testing.T, name string) *time.Location { return loc } +type testStringValueScanner struct { + field string +} + +func (s *testStringValueScanner) UnmarshalYDBValue(v Value) error { + return CastTo(v, &s.field) +} + func TestCastTo(t *testing.T) { testsCases := []struct { name string @@ -428,6 +436,13 @@ func TestCastTo(t *testing.T) { exp: DateValueFromTime(time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)), err: nil, }, + { + name: xtest.CurrentFileLine(), + value: TextValue("text-string"), + dst: ptr[testStringValueScanner](), + exp: testStringValueScanner{field: "text-string"}, + err: nil, + }, } for _, tt := range testsCases { t.Run(tt.name, func(t *testing.T) { diff --git a/tests/integration/query_range_test.go b/tests/integration/query_range_test.go index 03505e30e..76aba5274 100644 --- a/tests/integration/query_range_test.go +++ b/tests/integration/query_range_test.go @@ -17,8 +17,17 @@ import ( "github.com/ydb-platform/ydb-go-sdk/v3/internal/value" "github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest" "github.com/ydb-platform/ydb-go-sdk/v3/query" + "github.com/ydb-platform/ydb-go-sdk/v3/table/types" ) +type testStringValueScanner struct { + field string +} + +func (v *testStringValueScanner) UnmarshalYDBValue(value types.Value) error { + return types.CastTo(value, &v.field) +} + func TestQueryRange(t *testing.T) { ctx, cancel := context.WithCancel(xtest.Context(t)) defer cancel() @@ -84,19 +93,22 @@ func TestQueryRange(t *testing.T) { p1 string p2 uint64 p3 time.Duration + p4 testStringValueScanner ) err := db.Query().Do(ctx, func(ctx context.Context, s query.Session) error { r, err := s.Query(ctx, ` DECLARE $p1 AS Text; DECLARE $p2 AS Uint64; DECLARE $p3 AS Interval; - SELECT $p1, $p2, $p3; + DECLARE $p4 AS Text; + SELECT $p1, $p2, $p3, $p4; `, query.WithParameters( ydb.ParamsBuilder(). Param("$p1").Text("test"). Param("$p2").Uint64(100500000000). Param("$p3").Interval(time.Duration(100500000000)). + Param("$p4").Text("test2"). Build(), ), query.WithSyntax(query.SyntaxYQL), @@ -112,7 +124,7 @@ func TestQueryRange(t *testing.T) { if err != nil { return err } - err = row.Scan(&p1, &p2, &p3) + err = row.Scan(&p1, &p2, &p3, &p4) if err != nil { return err } @@ -126,6 +138,9 @@ func TestQueryRange(t *testing.T) { if p3 != time.Duration(100500000000) { return fmt.Errorf("unexpected p3 value: %v", p3) } + if p4.field != "test2" { + return fmt.Errorf("unexpected p4 value: %v", p4) + } } } return nil