Skip to content

Commit

Permalink
Fix: Remove data type protocol version checks (#60)
Browse files Browse the repository at this point in the history
Checking whether a protocol version supports a particular type is not
correct. Cassandra itself supports UDT, tuple, date, time, etc. while
using protocol V3.
  • Loading branch information
mpenick authored Sep 3, 2024
1 parent 2abea74 commit 605a850
Show file tree
Hide file tree
Showing 18 changed files with 49 additions and 487 deletions.
23 changes: 7 additions & 16 deletions datacodec/date.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,14 +85,10 @@ func (c *dateCodec) DataType() datatype.DataType {
// Note that this relies on the fact that some additions will overflow: this is expected.

func (c *dateCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
if !version.SupportsDataType(c.DataType().Code()) {
err = errDataTypeNotSupported(c.DataType(), version)
} else {
var val int32
var wasNil bool
if val, wasNil, err = convertToInt32Date(source, c.layout); err == nil && !wasNil {
dest = writeInt32(val - math.MinInt32)
}
var val int32
var wasNil bool
if val, wasNil, err = convertToInt32Date(source, c.layout); err == nil && !wasNil {
dest = writeInt32(val - math.MinInt32)
}
if err != nil {
err = errCannotEncode(source, c.DataType(), version, err)
Expand All @@ -101,14 +97,9 @@ func (c *dateCodec) Encode(source interface{}, version primitive.ProtocolVersion
}

func (c *dateCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
if !version.SupportsDataType(c.DataType().Code()) {
wasNull = len(source) == 0
err = errDataTypeNotSupported(c.DataType(), version)
} else {
var val int32
if val, wasNull, err = readInt32(source); err == nil {
err = convertFromInt32Date(val+math.MinInt32, wasNull, c.layout, dest)
}
var val int32
if val, wasNull, err = readInt32(source); err == nil {
err = convertFromInt32Date(val+math.MinInt32, wasNull, c.layout, dest)
}
if err != nil {
err = errCannotDecode(dest, c.DataType(), version, err)
Expand Down
43 changes: 0 additions & 43 deletions datacodec/date_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,26 +102,6 @@ func Test_dateCodec_Encode(t *testing.T) {
}
})
}
for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) {
t.Run(version.String(), func(t *testing.T) {
tests := []struct {
name string
source interface{}
expected []byte
err string
}{
{"nil", nil, nil, "data type date not supported"},
{"non nil", datePos, nil, "data type date not supported"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual, err := Date.Encode(tt.source, version)
assert.Equal(t, tt.expected, actual)
assertErrorMessage(t, tt.err, err)
})
}
})
}
}

func Test_dateCodec_Decode(t *testing.T) {
Expand Down Expand Up @@ -151,29 +131,6 @@ func Test_dateCodec_Decode(t *testing.T) {
}
})
}
for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) {
t.Run(version.String(), func(t *testing.T) {
tests := []struct {
name string
source []byte
dest interface{}
expected interface{}
wasNull bool
err string
}{
{"null", nil, new(int32), new(int32), true, "data type date not supported"},
{"non null", datePosBytes, new(time.Time), new(time.Time), false, "data type date not supported"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wasNull, err := Date.Decode(tt.source, tt.dest, version)
assert.Equal(t, tt.expected, tt.dest)
assert.Equal(t, tt.wasNull, wasNull)
assertErrorMessage(t, tt.err, err)
})
}
})
}
}

func Test_convertToInt32Date(t *testing.T) {
Expand Down
23 changes: 7 additions & 16 deletions datacodec/duration.go
Original file line number Diff line number Diff line change
Expand Up @@ -46,14 +46,10 @@ func (c *durationCodec) DataType() datatype.DataType {
}

func (c *durationCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
if !version.SupportsDataType(c.DataType().Code()) {
err = errDataTypeNotSupported(c.DataType(), version)
} else {
var val CqlDuration
var wasNil bool
if val, wasNil, err = convertToDuration(source); err == nil && !wasNil {
dest = writeDuration(val)
}
var val CqlDuration
var wasNil bool
if val, wasNil, err = convertToDuration(source); err == nil && !wasNil {
dest = writeDuration(val)
}
if err != nil {
err = errCannotEncode(source, c.DataType(), version, err)
Expand All @@ -62,14 +58,9 @@ func (c *durationCodec) Encode(source interface{}, version primitive.ProtocolVer
}

func (c *durationCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
if !version.SupportsDataType(c.DataType().Code()) {
wasNull = len(source) == 0
err = errDataTypeNotSupported(c.DataType(), version)
} else {
var val CqlDuration
if val, wasNull, err = readDuration(source); err == nil {
err = convertFromDuration(val, wasNull, dest)
}
var val CqlDuration
if val, wasNull, err = readDuration(source); err == nil {
err = convertFromDuration(val, wasNull, dest)
}
if err != nil {
err = errCannotDecode(dest, c.DataType(), version, err)
Expand Down
43 changes: 0 additions & 43 deletions datacodec/duration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,26 +67,6 @@ func Test_durationCodec_Encode(t *testing.T) {
}
})
}
for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion5) {
t.Run(version.String(), func(t *testing.T) {
tests := []struct {
name string
source interface{}
expected []byte
err string
}{
{"null", nil, nil, "data type duration not supported"},
{"non null", cqlDurationPos, nil, "data type duration not supported"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual, err := Duration.Encode(tt.source, version)
assert.Equal(t, tt.expected, actual)
assertErrorMessage(t, tt.err, err)
})
}
})
}
}

func Test_durationCodec_Decode(t *testing.T) {
Expand Down Expand Up @@ -116,29 +96,6 @@ func Test_durationCodec_Decode(t *testing.T) {
}
})
}
for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion5) {
t.Run(version.String(), func(t *testing.T) {
tests := []struct {
name string
source []byte
dest interface{}
expected interface{}
wasNull bool
err string
}{
{"null", nil, new(CqlDuration), new(CqlDuration), true, "data type duration not supported"},
{"non null", cqlDurationPosBytes, new(CqlDuration), new(CqlDuration), false, "data type duration not supported"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wasNull, err := Duration.Decode(tt.source, tt.dest, version)
assert.Equal(t, tt.expected, tt.dest)
assert.Equal(t, tt.wasNull, wasNull)
assertErrorMessage(t, tt.err, err)
})
}
})
}
}

func Test_convertToDuration(t *testing.T) {
Expand Down
23 changes: 7 additions & 16 deletions datacodec/smallint.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,14 +33,10 @@ func (c *smallintCodec) DataType() datatype.DataType {
}

func (c *smallintCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
if !version.SupportsDataType(c.DataType().Code()) {
err = errDataTypeNotSupported(c.DataType(), version)
} else {
var val int16
var wasNil bool
if val, wasNil, err = convertToInt16(source); err == nil && !wasNil {
dest = writeInt16(val)
}
var val int16
var wasNil bool
if val, wasNil, err = convertToInt16(source); err == nil && !wasNil {
dest = writeInt16(val)
}
if err != nil {
err = errCannotEncode(source, c.DataType(), version, err)
Expand All @@ -49,14 +45,9 @@ func (c *smallintCodec) Encode(source interface{}, version primitive.ProtocolVer
}

func (c *smallintCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
if !version.SupportsDataType(c.DataType().Code()) {
wasNull = len(source) == 0
err = errDataTypeNotSupported(c.DataType(), version)
} else {
var val int16
if val, wasNull, err = readInt16(source); err == nil {
err = convertFromInt16(val, wasNull, dest)
}
var val int16
if val, wasNull, err = readInt16(source); err == nil {
err = convertFromInt16(val, wasNull, dest)
}
if err != nil {
err = errCannotDecode(dest, c.DataType(), version, err)
Expand Down
43 changes: 0 additions & 43 deletions datacodec/smallint_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -56,26 +56,6 @@ func Test_smallintCodec_Encode(t *testing.T) {
}
})
}
for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) {
t.Run(version.String(), func(t *testing.T) {
tests := []struct {
name string
source interface{}
expected []byte
err string
}{
{"nil", int16NilPtr(), nil, "data type smallint not supported"},
{"non nil", 1, nil, "data type smallint not supported"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual, err := Smallint.Encode(tt.source, version)
assert.Equal(t, tt.expected, actual)
assertErrorMessage(t, tt.err, err)
})
}
})
}
}

func Test_smallintCodec_Decode(t *testing.T) {
Expand Down Expand Up @@ -105,29 +85,6 @@ func Test_smallintCodec_Decode(t *testing.T) {
}
})
}
for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) {
t.Run(version.String(), func(t *testing.T) {
tests := []struct {
name string
source []byte
dest interface{}
expected interface{}
wasNull bool
err string
}{
{"null", nil, new(int16), new(int16), true, "data type smallint not supported"},
{"non null", smallintOneBytes, new(int16), new(int16), false, "data type smallint not supported"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wasNull, err := Smallint.Decode(tt.source, tt.dest, version)
assert.Equal(t, tt.expected, tt.dest)
assert.Equal(t, tt.wasNull, wasNull)
assertErrorMessage(t, tt.err, err)
})
}
})
}
}

func Test_convertToInt16(t *testing.T) {
Expand Down
23 changes: 7 additions & 16 deletions datacodec/time.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,14 +99,10 @@ func (c *timeCodec) DataType() datatype.DataType {
}

func (c *timeCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) {
if !version.SupportsDataType(c.DataType().Code()) {
err = errDataTypeNotSupported(c.DataType(), version)
} else {
var val int64
var wasNil bool
if val, wasNil, err = convertToInt64Time(source, c.layout); err == nil && !wasNil {
dest = writeInt64(val)
}
var val int64
var wasNil bool
if val, wasNil, err = convertToInt64Time(source, c.layout); err == nil && !wasNil {
dest = writeInt64(val)
}
if err != nil {
err = errCannotEncode(source, c.DataType(), version, err)
Expand All @@ -115,14 +111,9 @@ func (c *timeCodec) Encode(source interface{}, version primitive.ProtocolVersion
}

func (c *timeCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) {
if !version.SupportsDataType(c.DataType().Code()) {
wasNull = len(source) == 0
err = errDataTypeNotSupported(c.DataType(), version)
} else {
var val int64
if val, wasNull, err = readInt64(source); err == nil {
err = convertFromInt64Time(val, wasNull, dest, c.layout)
}
var val int64
if val, wasNull, err = readInt64(source); err == nil {
err = convertFromInt64Time(val, wasNull, dest, c.layout)
}
if err != nil {
err = errCannotDecode(dest, c.DataType(), version, err)
Expand Down
44 changes: 0 additions & 44 deletions datacodec/time_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -147,27 +147,6 @@ func Test_timeCodec_Encode(t *testing.T) {
}
})
}
for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) {
t.Run(version.String(), func(t *testing.T) {
tests := []struct {
name string
source interface{}
expected []byte
err string
}{
{"nil", nil, nil, "data type time not supported"},
{"non nil", timeSimple, nil, "data type time not supported"},
{"conversion failed", TimeMaxDuration + 1, nil, "data type time not supported"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
actual, err := Time.Encode(tt.source, version)
assert.Equal(t, tt.expected, actual)
assertErrorMessage(t, tt.err, err)
})
}
})
}
}

func Test_timeCodec_Decode(t *testing.T) {
Expand Down Expand Up @@ -197,29 +176,6 @@ func Test_timeCodec_Decode(t *testing.T) {
}
})
}
for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) {
t.Run(version.String(), func(t *testing.T) {
tests := []struct {
name string
source []byte
dest interface{}
expected interface{}
wasNull bool
err string
}{
{"null", nil, new(int64), new(int64), true, "data type time not supported"},
{"non null", timeSimpleBytes, new(time.Time), new(time.Time), false, "data type time not supported"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
wasNull, err := Time.Decode(tt.source, tt.dest, version)
assert.Equal(t, tt.expected, tt.dest)
assert.Equal(t, tt.wasNull, wasNull)
assertErrorMessage(t, tt.err, err)
})
}
})
}
}

func Test_convertToInt64Time(t *testing.T) {
Expand Down
Loading

0 comments on commit 605a850

Please sign in to comment.