From 605a850e203b5b7d8c7914e875e8576cdd722ea0 Mon Sep 17 00:00:00 2001 From: Michael Penick Date: Tue, 3 Sep 2024 10:01:33 -0400 Subject: [PATCH] Fix: Remove data type protocol version checks (#60) Checking whether a protocol version supports a particular type is not correct. Cassandra itself supports UDT, tuple, date, time, etc. while using protocol V3. --- datacodec/date.go | 23 ++++++------------- datacodec/date_test.go | 43 ----------------------------------- datacodec/duration.go | 23 ++++++------------- datacodec/duration_test.go | 43 ----------------------------------- datacodec/smallint.go | 23 ++++++------------- datacodec/smallint_test.go | 43 ----------------------------------- datacodec/time.go | 23 ++++++------------- datacodec/time_test.go | 44 ------------------------------------ datacodec/tinyint.go | 23 ++++++------------- datacodec/tinyint_test.go | 43 ----------------------------------- datacodec/tuple.go | 20 +++++------------ datacodec/tuple_test.go | 15 ------------- datacodec/udt.go | 20 +++++------------ datacodec/udt_test.go | 16 ------------- datatype/tuple_test.go | 46 +------------------------------------- datatype/udt_test.go | 42 ---------------------------------- primitive/constants.go | 44 ------------------------------------ primitive/util.go | 2 +- 18 files changed, 49 insertions(+), 487 deletions(-) diff --git a/datacodec/date.go b/datacodec/date.go index 75b9cb4..6ad74a2 100644 --- a/datacodec/date.go +++ b/datacodec/date.go @@ -85,14 +85,10 @@ func (c *dateCodec) DataType() datatype.DataType { // Note that this relies on the fact that some additions will overflow: this is expected. func (c *dateCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var val int32 - var wasNil bool - if val, wasNil, err = convertToInt32Date(source, c.layout); err == nil && !wasNil { - dest = writeInt32(val - math.MinInt32) - } + var val int32 + var wasNil bool + if val, wasNil, err = convertToInt32Date(source, c.layout); err == nil && !wasNil { + dest = writeInt32(val - math.MinInt32) } if err != nil { err = errCannotEncode(source, c.DataType(), version, err) @@ -101,14 +97,9 @@ func (c *dateCodec) Encode(source interface{}, version primitive.ProtocolVersion } func (c *dateCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - wasNull = len(source) == 0 - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var val int32 - if val, wasNull, err = readInt32(source); err == nil { - err = convertFromInt32Date(val+math.MinInt32, wasNull, c.layout, dest) - } + var val int32 + if val, wasNull, err = readInt32(source); err == nil { + err = convertFromInt32Date(val+math.MinInt32, wasNull, c.layout, dest) } if err != nil { err = errCannotDecode(dest, c.DataType(), version, err) diff --git a/datacodec/date_test.go b/datacodec/date_test.go index 92a6aeb..b4fe677 100644 --- a/datacodec/date_test.go +++ b/datacodec/date_test.go @@ -102,26 +102,6 @@ func Test_dateCodec_Encode(t *testing.T) { } }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) { - t.Run(version.String(), func(t *testing.T) { - tests := []struct { - name string - source interface{} - expected []byte - err string - }{ - {"nil", nil, nil, "data type date not supported"}, - {"non nil", datePos, nil, "data type date not supported"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actual, err := Date.Encode(tt.source, version) - assert.Equal(t, tt.expected, actual) - assertErrorMessage(t, tt.err, err) - }) - } - }) - } } func Test_dateCodec_Decode(t *testing.T) { @@ -151,29 +131,6 @@ func Test_dateCodec_Decode(t *testing.T) { } }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) { - t.Run(version.String(), func(t *testing.T) { - tests := []struct { - name string - source []byte - dest interface{} - expected interface{} - wasNull bool - err string - }{ - {"null", nil, new(int32), new(int32), true, "data type date not supported"}, - {"non null", datePosBytes, new(time.Time), new(time.Time), false, "data type date not supported"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - wasNull, err := Date.Decode(tt.source, tt.dest, version) - assert.Equal(t, tt.expected, tt.dest) - assert.Equal(t, tt.wasNull, wasNull) - assertErrorMessage(t, tt.err, err) - }) - } - }) - } } func Test_convertToInt32Date(t *testing.T) { diff --git a/datacodec/duration.go b/datacodec/duration.go index 2f292a5..b28c577 100644 --- a/datacodec/duration.go +++ b/datacodec/duration.go @@ -46,14 +46,10 @@ func (c *durationCodec) DataType() datatype.DataType { } func (c *durationCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var val CqlDuration - var wasNil bool - if val, wasNil, err = convertToDuration(source); err == nil && !wasNil { - dest = writeDuration(val) - } + var val CqlDuration + var wasNil bool + if val, wasNil, err = convertToDuration(source); err == nil && !wasNil { + dest = writeDuration(val) } if err != nil { err = errCannotEncode(source, c.DataType(), version, err) @@ -62,14 +58,9 @@ func (c *durationCodec) Encode(source interface{}, version primitive.ProtocolVer } func (c *durationCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - wasNull = len(source) == 0 - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var val CqlDuration - if val, wasNull, err = readDuration(source); err == nil { - err = convertFromDuration(val, wasNull, dest) - } + var val CqlDuration + if val, wasNull, err = readDuration(source); err == nil { + err = convertFromDuration(val, wasNull, dest) } if err != nil { err = errCannotDecode(dest, c.DataType(), version, err) diff --git a/datacodec/duration_test.go b/datacodec/duration_test.go index 17a3e1b..f8c1993 100644 --- a/datacodec/duration_test.go +++ b/datacodec/duration_test.go @@ -67,26 +67,6 @@ func Test_durationCodec_Encode(t *testing.T) { } }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion5) { - t.Run(version.String(), func(t *testing.T) { - tests := []struct { - name string - source interface{} - expected []byte - err string - }{ - {"null", nil, nil, "data type duration not supported"}, - {"non null", cqlDurationPos, nil, "data type duration not supported"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actual, err := Duration.Encode(tt.source, version) - assert.Equal(t, tt.expected, actual) - assertErrorMessage(t, tt.err, err) - }) - } - }) - } } func Test_durationCodec_Decode(t *testing.T) { @@ -116,29 +96,6 @@ func Test_durationCodec_Decode(t *testing.T) { } }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion5) { - t.Run(version.String(), func(t *testing.T) { - tests := []struct { - name string - source []byte - dest interface{} - expected interface{} - wasNull bool - err string - }{ - {"null", nil, new(CqlDuration), new(CqlDuration), true, "data type duration not supported"}, - {"non null", cqlDurationPosBytes, new(CqlDuration), new(CqlDuration), false, "data type duration not supported"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - wasNull, err := Duration.Decode(tt.source, tt.dest, version) - assert.Equal(t, tt.expected, tt.dest) - assert.Equal(t, tt.wasNull, wasNull) - assertErrorMessage(t, tt.err, err) - }) - } - }) - } } func Test_convertToDuration(t *testing.T) { diff --git a/datacodec/smallint.go b/datacodec/smallint.go index c3ecfa7..5766ab1 100644 --- a/datacodec/smallint.go +++ b/datacodec/smallint.go @@ -33,14 +33,10 @@ func (c *smallintCodec) DataType() datatype.DataType { } func (c *smallintCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var val int16 - var wasNil bool - if val, wasNil, err = convertToInt16(source); err == nil && !wasNil { - dest = writeInt16(val) - } + var val int16 + var wasNil bool + if val, wasNil, err = convertToInt16(source); err == nil && !wasNil { + dest = writeInt16(val) } if err != nil { err = errCannotEncode(source, c.DataType(), version, err) @@ -49,14 +45,9 @@ func (c *smallintCodec) Encode(source interface{}, version primitive.ProtocolVer } func (c *smallintCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - wasNull = len(source) == 0 - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var val int16 - if val, wasNull, err = readInt16(source); err == nil { - err = convertFromInt16(val, wasNull, dest) - } + var val int16 + if val, wasNull, err = readInt16(source); err == nil { + err = convertFromInt16(val, wasNull, dest) } if err != nil { err = errCannotDecode(dest, c.DataType(), version, err) diff --git a/datacodec/smallint_test.go b/datacodec/smallint_test.go index 75d9b9a..2970e5d 100644 --- a/datacodec/smallint_test.go +++ b/datacodec/smallint_test.go @@ -56,26 +56,6 @@ func Test_smallintCodec_Encode(t *testing.T) { } }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) { - t.Run(version.String(), func(t *testing.T) { - tests := []struct { - name string - source interface{} - expected []byte - err string - }{ - {"nil", int16NilPtr(), nil, "data type smallint not supported"}, - {"non nil", 1, nil, "data type smallint not supported"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actual, err := Smallint.Encode(tt.source, version) - assert.Equal(t, tt.expected, actual) - assertErrorMessage(t, tt.err, err) - }) - } - }) - } } func Test_smallintCodec_Decode(t *testing.T) { @@ -105,29 +85,6 @@ func Test_smallintCodec_Decode(t *testing.T) { } }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) { - t.Run(version.String(), func(t *testing.T) { - tests := []struct { - name string - source []byte - dest interface{} - expected interface{} - wasNull bool - err string - }{ - {"null", nil, new(int16), new(int16), true, "data type smallint not supported"}, - {"non null", smallintOneBytes, new(int16), new(int16), false, "data type smallint not supported"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - wasNull, err := Smallint.Decode(tt.source, tt.dest, version) - assert.Equal(t, tt.expected, tt.dest) - assert.Equal(t, tt.wasNull, wasNull) - assertErrorMessage(t, tt.err, err) - }) - } - }) - } } func Test_convertToInt16(t *testing.T) { diff --git a/datacodec/time.go b/datacodec/time.go index d97c450..3e0db18 100644 --- a/datacodec/time.go +++ b/datacodec/time.go @@ -99,14 +99,10 @@ func (c *timeCodec) DataType() datatype.DataType { } func (c *timeCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var val int64 - var wasNil bool - if val, wasNil, err = convertToInt64Time(source, c.layout); err == nil && !wasNil { - dest = writeInt64(val) - } + var val int64 + var wasNil bool + if val, wasNil, err = convertToInt64Time(source, c.layout); err == nil && !wasNil { + dest = writeInt64(val) } if err != nil { err = errCannotEncode(source, c.DataType(), version, err) @@ -115,14 +111,9 @@ func (c *timeCodec) Encode(source interface{}, version primitive.ProtocolVersion } func (c *timeCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - wasNull = len(source) == 0 - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var val int64 - if val, wasNull, err = readInt64(source); err == nil { - err = convertFromInt64Time(val, wasNull, dest, c.layout) - } + var val int64 + if val, wasNull, err = readInt64(source); err == nil { + err = convertFromInt64Time(val, wasNull, dest, c.layout) } if err != nil { err = errCannotDecode(dest, c.DataType(), version, err) diff --git a/datacodec/time_test.go b/datacodec/time_test.go index 269bb33..4e8807e 100644 --- a/datacodec/time_test.go +++ b/datacodec/time_test.go @@ -147,27 +147,6 @@ func Test_timeCodec_Encode(t *testing.T) { } }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) { - t.Run(version.String(), func(t *testing.T) { - tests := []struct { - name string - source interface{} - expected []byte - err string - }{ - {"nil", nil, nil, "data type time not supported"}, - {"non nil", timeSimple, nil, "data type time not supported"}, - {"conversion failed", TimeMaxDuration + 1, nil, "data type time not supported"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actual, err := Time.Encode(tt.source, version) - assert.Equal(t, tt.expected, actual) - assertErrorMessage(t, tt.err, err) - }) - } - }) - } } func Test_timeCodec_Decode(t *testing.T) { @@ -197,29 +176,6 @@ func Test_timeCodec_Decode(t *testing.T) { } }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) { - t.Run(version.String(), func(t *testing.T) { - tests := []struct { - name string - source []byte - dest interface{} - expected interface{} - wasNull bool - err string - }{ - {"null", nil, new(int64), new(int64), true, "data type time not supported"}, - {"non null", timeSimpleBytes, new(time.Time), new(time.Time), false, "data type time not supported"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - wasNull, err := Time.Decode(tt.source, tt.dest, version) - assert.Equal(t, tt.expected, tt.dest) - assert.Equal(t, tt.wasNull, wasNull) - assertErrorMessage(t, tt.err, err) - }) - } - }) - } } func Test_convertToInt64Time(t *testing.T) { diff --git a/datacodec/tinyint.go b/datacodec/tinyint.go index 598a360..399460e 100644 --- a/datacodec/tinyint.go +++ b/datacodec/tinyint.go @@ -33,14 +33,10 @@ func (c *tinyintCodec) DataType() datatype.DataType { } func (c *tinyintCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var val int8 - var wasNil bool - if val, wasNil, err = convertToInt8(source); err == nil && !wasNil { - dest = writeInt8(val) - } + var val int8 + var wasNil bool + if val, wasNil, err = convertToInt8(source); err == nil && !wasNil { + dest = writeInt8(val) } if err != nil { err = errCannotEncode(source, c.DataType(), version, err) @@ -49,14 +45,9 @@ func (c *tinyintCodec) Encode(source interface{}, version primitive.ProtocolVers } func (c *tinyintCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - wasNull = len(source) == 0 - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var val int8 - if val, wasNull, err = readInt8(source); err == nil { - err = convertFromInt8(val, wasNull, dest) - } + var val int8 + if val, wasNull, err = readInt8(source); err == nil { + err = convertFromInt8(val, wasNull, dest) } if err != nil { err = errCannotDecode(dest, c.DataType(), version, err) diff --git a/datacodec/tinyint_test.go b/datacodec/tinyint_test.go index 9b54753..33da1a5 100644 --- a/datacodec/tinyint_test.go +++ b/datacodec/tinyint_test.go @@ -56,26 +56,6 @@ func Test_tinyintCodec_Encode(t *testing.T) { } }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) { - t.Run(version.String(), func(t *testing.T) { - tests := []struct { - name string - source interface{} - expected []byte - err string - }{ - {"nil", int8NilPtr(), nil, "data type tinyint not supported"}, - {"non nil", 1, nil, "data type tinyint not supported"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - actual, err := Tinyint.Encode(tt.source, version) - assert.Equal(t, tt.expected, actual) - assertErrorMessage(t, tt.err, err) - }) - } - }) - } } func Test_tinyintCodec_Decode(t *testing.T) { @@ -105,29 +85,6 @@ func Test_tinyintCodec_Decode(t *testing.T) { } }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion4) { - t.Run(version.String(), func(t *testing.T) { - tests := []struct { - name string - source []byte - dest interface{} - expected interface{} - wasNull bool - err string - }{ - {"null", nil, new(int8), new(int8), true, "data type tinyint not supported"}, - {"non null", tinyintOneBytes, new(int8), new(int8), false, "data type tinyint not supported"}, - } - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - wasNull, err := Tinyint.Decode(tt.source, tt.dest, version) - assert.Equal(t, tt.expected, tt.dest) - assert.Equal(t, tt.wasNull, wasNull) - assertErrorMessage(t, tt.err, err) - }) - } - }) - } } func Test_convertToInt8(t *testing.T) { diff --git a/datacodec/tuple.go b/datacodec/tuple.go index 211dfad..0e04090 100644 --- a/datacodec/tuple.go +++ b/datacodec/tuple.go @@ -48,13 +48,9 @@ func (c *tupleCodec) DataType() datatype.DataType { } func (c *tupleCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var ext extractor - if ext, err = c.createExtractor(source); err == nil && ext != nil { - dest, err = writeTuple(ext, c.elementCodecs, version) - } + var ext extractor + if ext, err = c.createExtractor(source); err == nil && ext != nil { + dest, err = writeTuple(ext, c.elementCodecs, version) } if err != nil { err = errCannotEncode(source, c.DataType(), version, err) @@ -64,13 +60,9 @@ func (c *tupleCodec) Encode(source interface{}, version primitive.ProtocolVersio func (c *tupleCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { wasNull = len(source) == 0 - if !version.SupportsDataType(c.DataType().Code()) { - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var inj injector - if inj, err = c.createInjector(dest, wasNull); err == nil && inj != nil { - err = readTuple(source, inj, c.elementCodecs, version) - } + var inj injector + if inj, err = c.createInjector(dest, wasNull); err == nil && inj != nil { + err = readTuple(source, inj, c.elementCodecs, version) } if err != nil { err = errCannotDecode(dest, c.DataType(), version, err) diff --git a/datacodec/tuple_test.go b/datacodec/tuple_test.go index adaead9..ff4f6bc 100644 --- a/datacodec/tuple_test.go +++ b/datacodec/tuple_test.go @@ -525,14 +525,6 @@ func Test_tupleCodec_Encode(t *testing.T) { }) }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { - t.Run(version.String(), func(t *testing.T) { - codec, _ := NewTuple(datatype.NewTuple(datatype.Int)) - dest, err := codec.Encode(nil, version) - assert.Nil(t, dest) - assertErrorMessage(t, "data type tuple not supported in "+version.String(), err) - }) - } t.Run("invalid types", func(t *testing.T) { dest, err := tupleCodecSimple.Encode(123, primitive.ProtocolVersion5) assert.Nil(t, dest) @@ -853,13 +845,6 @@ func Test_tupleCodec_Decode(t *testing.T) { }) }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { - t.Run(version.String(), func(t *testing.T) { - codec, _ := NewTuple(datatype.NewTuple(datatype.Int)) - _, err := codec.Decode(nil, nil, version) - assertErrorMessage(t, "data type tuple not supported in "+version.String(), err) - }) - } t.Run("invalid types", func(t *testing.T) { wasNull, err := tupleCodecSimple.Decode([]byte{1, 2, 3}, new(int), primitive.ProtocolVersion5) assert.False(t, wasNull) diff --git a/datacodec/udt.go b/datacodec/udt.go index 718fcea..7a00b1c 100644 --- a/datacodec/udt.go +++ b/datacodec/udt.go @@ -48,13 +48,9 @@ func (c *udtCodec) DataType() datatype.DataType { } func (c *udtCodec) Encode(source interface{}, version primitive.ProtocolVersion) (dest []byte, err error) { - if !version.SupportsDataType(c.DataType().Code()) { - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var ext extractor - if ext, err = c.createExtractor(source); err == nil && ext != nil { - dest, err = writeUdt(ext, c.dataType.FieldNames, c.fieldCodecs, version) - } + var ext extractor + if ext, err = c.createExtractor(source); err == nil && ext != nil { + dest, err = writeUdt(ext, c.dataType.FieldNames, c.fieldCodecs, version) } if err != nil { err = errCannotEncode(source, c.DataType(), version, err) @@ -64,13 +60,9 @@ func (c *udtCodec) Encode(source interface{}, version primitive.ProtocolVersion) func (c *udtCodec) Decode(source []byte, dest interface{}, version primitive.ProtocolVersion) (wasNull bool, err error) { wasNull = len(source) == 0 - if !version.SupportsDataType(c.DataType().Code()) { - err = errDataTypeNotSupported(c.DataType(), version) - } else { - var inj injector - if inj, err = c.createInjector(dest, wasNull); err == nil && inj != nil { - err = readUdt(source, inj, c.dataType.FieldNames, c.fieldCodecs, version) - } + var inj injector + if inj, err = c.createInjector(dest, wasNull); err == nil && inj != nil { + err = readUdt(source, inj, c.dataType.FieldNames, c.fieldCodecs, version) } if err != nil { err = errCannotDecode(dest, c.DataType(), version, err) diff --git a/datacodec/udt_test.go b/datacodec/udt_test.go index 81a40fc..a5a1569 100644 --- a/datacodec/udt_test.go +++ b/datacodec/udt_test.go @@ -16,7 +16,6 @@ package datacodec import ( "errors" - "fmt" "testing" "github.com/stretchr/testify/assert" @@ -415,14 +414,6 @@ func Test_udtCodec_Encode(t *testing.T) { }) }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { - t.Run(version.String(), func(t *testing.T) { - dest, err := udtCodecSimple.Encode(nil, version) - assert.Nil(t, dest) - expectedMessage := fmt.Sprintf("data type %s not supported in %v", udtTypeSimple, version) - assertErrorMessage(t, expectedMessage, err) - }) - } t.Run("invalid types", func(t *testing.T) { dest, err := udtCodecSimple.Encode(123, primitive.ProtocolVersion5) assert.Nil(t, dest) @@ -687,13 +678,6 @@ func Test_udtCodec_Decode(t *testing.T) { }) }) } - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { - t.Run(version.String(), func(t *testing.T) { - _, err := udtCodecSimple.Decode(nil, nil, version) - expectedMessage := fmt.Sprintf("data type %s not supported in %v", udtTypeSimple, version) - assertErrorMessage(t, expectedMessage, err) - }) - } t.Run("invalid types", func(t *testing.T) { wasNull, err := udtCodecSimple.Decode([]byte{1, 2, 3}, new(int), primitive.ProtocolVersion5) assert.False(t, wasNull) diff --git a/datatype/tuple_test.go b/datatype/tuple_test.go index 149a70e..52b4234 100644 --- a/datatype/tuple_test.go +++ b/datatype/tuple_test.go @@ -20,10 +20,8 @@ import ( "fmt" "testing" - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - "github.com/datastax/go-cassandra-native-protocol/primitive" + "github.com/stretchr/testify/assert" ) func TestTupleType(t *testing.T) { @@ -135,29 +133,6 @@ func TestWriteTupleType(t *testing.T) { }) } }) - - t.Run("versions_without_tuple_support", func(t *testing.T) { - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { - t.Run(version.String(), func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var dest = &bytes.Buffer{} - var err error - err = WriteDataType(test.input, dest, version) - actual := dest.Bytes() - require.NotNil(t, err) - if test.err != nil { - assert.Equal(t, test.err, err) - } else { - assert.Contains(t, err.Error(), - fmt.Sprintf("invalid data type code for %s: DataTypeCode Tuple", version)) - } - assert.Equal(t, 0, len(actual)) - }) - } - }) - } - }) } func TestLengthOfTupleType(t *testing.T) { @@ -259,25 +234,6 @@ func TestReadTupleType(t *testing.T) { }) } }) - - t.Run("versions_without_tuple_support", func(t *testing.T) { - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { - t.Run(version.String(), func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var source = bytes.NewBuffer(test.input) - var actual DataType - var err error - actual, err = ReadDataType(source, version) - require.NotNil(t, err) - assert.Contains(t, err.Error(), - fmt.Sprintf("invalid data type code for %s: DataTypeCode Tuple", version)) - assert.Nil(t, actual) - }) - } - }) - } - }) } func Test_tupleType_String(t1 *testing.T) { diff --git a/datatype/udt_test.go b/datatype/udt_test.go index 5d82e22..71ffa16 100644 --- a/datatype/udt_test.go +++ b/datatype/udt_test.go @@ -175,29 +175,6 @@ func TestWriteUserDefinedType(t *testing.T) { }) } }) - - t.Run("versions_without_udt_support", func(t *testing.T) { - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { - t.Run(version.String(), func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var dest = &bytes.Buffer{} - var err error - err = WriteDataType(test.input, dest, version) - actual := dest.Bytes() - require.NotNil(t, err) - if test.err != nil { - assert.Equal(t, test.err, err) - } else { - assert.Contains(t, err.Error(), - fmt.Sprintf("invalid data type code for %s: DataTypeCode Udt", version)) - } - assert.Equal(t, 0, len(actual)) - }) - } - }) - } - }) } func TestLengthOfUserDefinedType(t *testing.T) { @@ -322,25 +299,6 @@ func TestReadUserDefinedType(t *testing.T) { }) } }) - - t.Run("versions_without_udt_support", func(t *testing.T) { - for _, version := range primitive.SupportedProtocolVersionsLesserThan(primitive.ProtocolVersion3) { - t.Run(version.String(), func(t *testing.T) { - for _, test := range tests { - t.Run(test.name, func(t *testing.T) { - var source = bytes.NewBuffer(test.input) - var actual DataType - var err error - actual, err = ReadDataType(source, version) - require.NotNil(t, err) - assert.Contains(t, err.Error(), - fmt.Sprintf("invalid data type code for %s: DataTypeCode Udt", version)) - assert.Nil(t, actual) - }) - } - }) - } - }) } func Test_userDefinedType_String(t1 *testing.T) { diff --git a/primitive/constants.go b/primitive/constants.go index f5ba191..0bf3515 100644 --- a/primitive/constants.go +++ b/primitive/constants.go @@ -156,50 +156,6 @@ func (v ProtocolVersion) SupportsWriteTimeoutContentions() bool { return v >= ProtocolVersion5 && v != ProtocolVersionDse1 && v != ProtocolVersionDse2 } -func (v ProtocolVersion) SupportsDataType(code DataTypeCode) bool { - switch code { - case DataTypeCodeCustom: - case DataTypeCodeAscii: - case DataTypeCodeBigint: - case DataTypeCodeBlob: - case DataTypeCodeBoolean: - case DataTypeCodeCounter: - case DataTypeCodeDecimal: - case DataTypeCodeDouble: - case DataTypeCodeFloat: - case DataTypeCodeInt: - case DataTypeCodeTimestamp: - case DataTypeCodeUuid: - case DataTypeCodeVarchar: - case DataTypeCodeVarint: - case DataTypeCodeTimeuuid: - case DataTypeCodeInet: - case DataTypeCodeList: - case DataTypeCodeMap: - case DataTypeCodeSet: - case DataTypeCodeText: - return v <= ProtocolVersion2 // removed in version 3 - case DataTypeCodeUdt: - return v >= ProtocolVersion3 - case DataTypeCodeTuple: - return v >= ProtocolVersion3 - case DataTypeCodeDate: - return v >= ProtocolVersion4 - case DataTypeCodeTime: - return v >= ProtocolVersion4 - case DataTypeCodeSmallint: - return v >= ProtocolVersion4 - case DataTypeCodeTinyint: - return v >= ProtocolVersion4 - case DataTypeCodeDuration: - return v >= ProtocolVersion5 - default: - // Unknown code - return false - } - return true -} - func (v ProtocolVersion) SupportsSchemaChangeTarget(target SchemaChangeTarget) bool { switch target { case SchemaChangeTargetKeyspace: diff --git a/primitive/util.go b/primitive/util.go index a867c19..6e43442 100644 --- a/primitive/util.go +++ b/primitive/util.go @@ -150,7 +150,7 @@ func CheckValidBatchType(batchType BatchType) error { } func CheckValidDataTypeCode(code DataTypeCode, version ProtocolVersion) error { - if !code.IsValid() || !version.SupportsDataType(code) { + if !code.IsValid() { return fmt.Errorf("invalid data type code for %v: %v", version, code) } return nil