Skip to content

Commit

Permalink
support custom types in query result scanning
Browse files Browse the repository at this point in the history
  • Loading branch information
4el0ve4ek committed Dec 27, 2024
1 parent 4068af9 commit fe804c6
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 2 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
* Added support of custom types to row.ScanStruc

## v3.95.5
* Fixed goroutine leak on failed execute call in query client

Expand Down
8 changes: 8 additions & 0 deletions internal/value/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,5 +10,13 @@ func CastTo(v Value, dst interface{}) error {
return nil
}

if scanner, has := dst.(Scanner); has {
return scanner.UnmarshalYDBValue(v)
}

return v.castTo(dst)
}

type Scanner interface {
UnmarshalYDBValue(value Value) error
}
15 changes: 15 additions & 0 deletions internal/value/cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,14 @@ func loadLocation(t *testing.T, name string) *time.Location {
return loc
}

type testStringValueScanner struct {
field string
}

func (s *testStringValueScanner) UnmarshalYDBValue(v Value) error {
return CastTo(v, &s.field)
}

func TestCastTo(t *testing.T) {
testsCases := []struct {
name string
Expand Down Expand Up @@ -428,6 +436,13 @@ func TestCastTo(t *testing.T) {
exp: DateValueFromTime(time.Date(2024, 1, 2, 0, 0, 0, 0, time.UTC)),
err: nil,
},
{
name: xtest.CurrentFileLine(),
value: TextValue("text-string"),
dst: ptr[testStringValueScanner](),
exp: testStringValueScanner{field: "text-string"},
err: nil,
},
}
for _, tt := range testsCases {
t.Run(tt.name, func(t *testing.T) {
Expand Down
19 changes: 17 additions & 2 deletions tests/integration/query_range_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,17 @@ import (
"github.com/ydb-platform/ydb-go-sdk/v3/internal/value"
"github.com/ydb-platform/ydb-go-sdk/v3/internal/xtest"
"github.com/ydb-platform/ydb-go-sdk/v3/query"
"github.com/ydb-platform/ydb-go-sdk/v3/table/types"
)

type testStringValueScanner struct {
field string
}

func (v *testStringValueScanner) UnmarshalYDBValue(value types.Value) error {
return types.CastTo(value, &v.field)
}

func TestQueryRange(t *testing.T) {
ctx, cancel := context.WithCancel(xtest.Context(t))
defer cancel()
Expand Down Expand Up @@ -84,19 +93,22 @@ func TestQueryRange(t *testing.T) {
p1 string
p2 uint64
p3 time.Duration
p4 testStringValueScanner
)
err := db.Query().Do(ctx, func(ctx context.Context, s query.Session) error {
r, err := s.Query(ctx, `
DECLARE $p1 AS Text;
DECLARE $p2 AS Uint64;
DECLARE $p3 AS Interval;
SELECT $p1, $p2, $p3;
DECLARE $p4 AS Text;
SELECT $p1, $p2, $p3, $p4;
`,
query.WithParameters(
ydb.ParamsBuilder().
Param("$p1").Text("test").
Param("$p2").Uint64(100500000000).
Param("$p3").Interval(time.Duration(100500000000)).
Param("$p4").Text("test2").
Build(),
),
query.WithSyntax(query.SyntaxYQL),
Expand All @@ -112,7 +124,7 @@ func TestQueryRange(t *testing.T) {
if err != nil {
return err
}
err = row.Scan(&p1, &p2, &p3)
err = row.Scan(&p1, &p2, &p3, &p4)
if err != nil {
return err
}
Expand All @@ -126,6 +138,9 @@ func TestQueryRange(t *testing.T) {
if p3 != time.Duration(100500000000) {
return fmt.Errorf("unexpected p3 value: %v", p3)
}
if p4.field != "test2" {
return fmt.Errorf("unexpected p4 value: %v", p4)
}
}
}
return nil
Expand Down

0 comments on commit fe804c6

Please sign in to comment.