From e14dad03285fedd14a7b9682ff64468e297d5b46 Mon Sep 17 00:00:00 2001 From: "mykyta.oleksiienko" Date: Wed, 8 Jan 2025 16:39:15 +0200 Subject: [PATCH] CASSGO-29-refactor-packages-to-separate-fronend-api-to-backend-internal-api --- conn.go | 8 +- control.go | 2 +- frame.go | 70 +--- helpers.go | 39 +- internal/frame.go | 50 +++ internal/helpers.go | 37 ++ internal/marshal.go | 803 ++++++++++++++++++++++++++++++++++++++ internal/session.go | 12 + marshal.go | 929 ++++---------------------------------------- marshal_test.go | 24 +- session.go | 12 +- session_test.go | 2 +- token_test.go | 6 +- 13 files changed, 1013 insertions(+), 981 deletions(-) create mode 100644 internal/frame.go create mode 100644 internal/helpers.go create mode 100644 internal/marshal.go create mode 100644 internal/session.go diff --git a/conn.go b/conn.go index ae02bd71c..e4eab2501 100644 --- a/conn.go +++ b/conn.go @@ -1276,7 +1276,7 @@ func (c *Conn) prepareStatement(ctx context.Context, stmt string, tracer Tracer) flight.preparedStatment = &preparedStatment{ // defensively copy as we will recycle the underlying buffer after we // return. - id: copyBytes(x.preparedID), + id: internal.CopyBytes(x.preparedID), // the type info's should _not_ have a reference to the framers read buffer, // therefore we can just copy them directly. request: x.reqMeta, @@ -1308,7 +1308,7 @@ func marshalQueryValue(typ TypeInfo, value interface{}, dst *queryValues) error value = named.value } - if _, ok := value.(unsetColumn); !ok { + if _, ok := value.(internal.UnsetColumn); !ok { val, err := Marshal(typ, value) if err != nil { return err @@ -1431,7 +1431,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if params.skipMeta { if info != nil { iter.meta = info.response - iter.meta.pagingState = copyBytes(x.meta.pagingState) + iter.meta.pagingState = internal.CopyBytes(x.meta.pagingState) } else { return &Iter{framer: framer, err: errors.New("gocql: did not receive metadata but prepared info is nil")} } @@ -1442,7 +1442,7 @@ func (c *Conn) executeQuery(ctx context.Context, qry *Query) *Iter { if x.meta.morePages() && !qry.disableAutoPage { newQry := new(Query) *newQry = *qry - newQry.pageState = copyBytes(x.meta.pagingState) + newQry.pageState = internal.CopyBytes(x.meta.pagingState) newQry.metrics = &queryMetrics{m: make(map[string]*hostMetrics)} iter.next = &nextIter{ diff --git a/control.go b/control.go index b30b44ea3..8d4416127 100644 --- a/control.go +++ b/control.go @@ -50,7 +50,7 @@ func init() { panic(fmt.Sprintf("unable to seed random number generator: %v", err)) } - randr = rand.New(rand.NewSource(int64(readInt(b)))) + randr = rand.New(rand.NewSource(int64(internal.ReadInt(b)))) } const ( diff --git a/frame.go b/frame.go index d374ae574..888cefe6d 100644 --- a/frame.go +++ b/frame.go @@ -28,6 +28,7 @@ import ( "context" "errors" "fmt" + "github.com/gocql/gocql/internal" "io" "io/ioutil" "net" @@ -36,8 +37,6 @@ import ( "time" ) -type unsetColumn struct{} - // UnsetValue represents a value used in a query binding that will be ignored by Cassandra. // // By setting a field to the unset value Cassandra will ignore the write completely. @@ -45,7 +44,7 @@ type unsetColumn struct{} // want to update some fields, where before you needed to make another prepared statement. // // UnsetValue is only available when using the version 4 of the protocol. -var UnsetValue = unsetColumn{} +var UnsetValue = internal.UnsetColumn{} type namedValue struct { name string @@ -331,10 +330,6 @@ var ( const maxFrameHeaderSize = 9 -func readInt(p []byte) int32 { - return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) -} - type frameHeader struct { version protoVersion flags byte @@ -474,7 +469,7 @@ func readHeader(r io.Reader, p []byte) (head frameHeader, err error) { head.stream = int(int16(p[2])<<8 | int16(p[3])) head.op = frameOp(p[4]) - head.length = int(readInt(p[5:])) + head.length = int(internal.ReadInt(p[5:])) } else { if len(p) != 8 { return frameHeader{}, fmt.Errorf("not enough bytes to read header require 8 got: %d", len(p)) @@ -482,7 +477,7 @@ func readHeader(r io.Reader, p []byte) (head frameHeader, err error) { head.stream = int(int8(p[2])) head.op = frameOp(p[3]) - head.length = int(readInt(p[4:])) + head.length = int(internal.ReadInt(p[4:])) } return head, nil @@ -647,7 +642,7 @@ func (f *framer) parseErrorFrame() frame { stmtId := f.readShortBytes() return &RequestErrUnprepared{ errorFrame: errD, - StatementId: copyBytes(stmtId), // defensively copy + StatementId: internal.CopyBytes(stmtId), // defensively copy } case ErrCodeReadFailure: res := &RequestErrReadFailure{ @@ -969,7 +964,7 @@ func (f *framer) parsePreparedMetadata() preparedMetadata { } if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = copyBytes(f.readBytes()) + meta.pagingState = internal.CopyBytes(f.readBytes()) } if meta.flags&flagNoMetaData == flagNoMetaData { @@ -1057,7 +1052,7 @@ func (f *framer) parseResultMetadata() resultMetadata { meta.actualColCount = meta.colCount if meta.flags&flagHasMorePages == flagHasMorePages { - meta.pagingState = copyBytes(f.readBytes()) + meta.pagingState = internal.CopyBytes(f.readBytes()) } if meta.flags&flagNoMetaData == flagNoMetaData { @@ -1940,49 +1935,6 @@ func (f *framer) writeByte(b byte) { f.buf = append(f.buf, b) } -func appendBytes(p []byte, d []byte) []byte { - if d == nil { - return appendInt(p, -1) - } - p = appendInt(p, int32(len(d))) - p = append(p, d...) - return p -} - -func appendShort(p []byte, n uint16) []byte { - return append(p, - byte(n>>8), - byte(n), - ) -} - -func appendInt(p []byte, n int32) []byte { - return append(p, byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n)) -} - -func appendUint(p []byte, n uint32) []byte { - return append(p, byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n)) -} - -func appendLong(p []byte, n int64) []byte { - return append(p, - byte(n>>56), - byte(n>>48), - byte(n>>40), - byte(n>>32), - byte(n>>24), - byte(n>>16), - byte(n>>8), - byte(n), - ) -} - func (f *framer) writeCustomPayload(customPayload *map[string][]byte) { if len(*customPayload) > 0 { if f.proto < protoVersion4 { @@ -1994,19 +1946,19 @@ func (f *framer) writeCustomPayload(customPayload *map[string][]byte) { // these are protocol level binary types func (f *framer) writeInt(n int32) { - f.buf = appendInt(f.buf, n) + f.buf = internal.AppendInt(f.buf, n) } func (f *framer) writeUint(n uint32) { - f.buf = appendUint(f.buf, n) + f.buf = internal.AppendUint(f.buf, n) } func (f *framer) writeShort(n uint16) { - f.buf = appendShort(f.buf, n) + f.buf = internal.AppendShort(f.buf, n) } func (f *framer) writeLong(n int64) { - f.buf = appendLong(f.buf, n) + f.buf = internal.AppendLong(f.buf, n) } func (f *framer) writeString(s string) { diff --git a/helpers.go b/helpers.go index f2faee9e0..a8a8bf7c2 100644 --- a/helpers.go +++ b/helpers.go @@ -26,6 +26,7 @@ package gocql import ( "fmt" + "github.com/gocql/gocql/internal" "math/big" "net" "reflect" @@ -176,7 +177,7 @@ func getCassandraType(name string, logger StdLogger) TypeInfo { Elem: getCassandraType(strings.TrimPrefix(name[:len(name)-1], "list<"), logger), } } else if strings.HasPrefix(name, "map<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) + names := internal.SplitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "map<")) if len(names) != 2 { logger.Printf("Error parsing map type, it has %d subelements, expecting 2\n", len(names)) return NativeType{ @@ -189,7 +190,7 @@ func getCassandraType(name string, logger StdLogger) TypeInfo { Elem: getCassandraType(names[1], logger), } } else if strings.HasPrefix(name, "tuple<") { - names := splitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) + names := internal.SplitCompositeTypes(strings.TrimPrefix(name[:len(name)-1], "tuple<")) types := make([]TypeInfo, len(names)) for i, name := range names { @@ -207,34 +208,6 @@ func getCassandraType(name string, logger StdLogger) TypeInfo { } } -func splitCompositeTypes(name string) []string { - if !strings.Contains(name, "<") { - return strings.Split(name, ", ") - } - var parts []string - lessCount := 0 - segment := "" - for _, char := range name { - if char == ',' && lessCount == 0 { - if segment != "" { - parts = append(parts, strings.TrimSpace(segment)) - } - segment = "" - continue - } - segment += string(char) - if char == '<' { - lessCount++ - } else if char == '>' { - lessCount-- - } - } - if segment != "" { - parts = append(parts, strings.TrimSpace(segment)) - } - return parts -} - func apacheToCassandraType(t string) string { t = strings.Replace(t, apacheCassandraTypePrefix, "", -1) t = strings.Replace(t, "(", "<", -1) @@ -451,12 +424,6 @@ func (iter *Iter) MapScan(m map[string]interface{}) bool { return false } -func copyBytes(p []byte) []byte { - b := make([]byte, len(p)) - copy(b, p) - return b -} - var failDNS = false func LookupIP(host string) ([]net.IP, error) { diff --git a/internal/frame.go b/internal/frame.go new file mode 100644 index 000000000..458c32af8 --- /dev/null +++ b/internal/frame.go @@ -0,0 +1,50 @@ +package internal + +type UnsetColumn struct{} + +func ReadInt(p []byte) int32 { + return int32(p[0])<<24 | int32(p[1])<<16 | int32(p[2])<<8 | int32(p[3]) +} + +func AppendBytes(p []byte, d []byte) []byte { + if d == nil { + return AppendInt(p, -1) + } + p = AppendInt(p, int32(len(d))) + p = append(p, d...) + return p +} + +func AppendShort(p []byte, n uint16) []byte { + return append(p, + byte(n>>8), + byte(n), + ) +} + +func AppendInt(p []byte, n int32) []byte { + return append(p, byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n)) +} + +func AppendUint(p []byte, n uint32) []byte { + return append(p, byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n)) +} + +func AppendLong(p []byte, n int64) []byte { + return append(p, + byte(n>>56), + byte(n>>48), + byte(n>>40), + byte(n>>32), + byte(n>>24), + byte(n>>16), + byte(n>>8), + byte(n), + ) +} diff --git a/internal/helpers.go b/internal/helpers.go new file mode 100644 index 000000000..84498e12d --- /dev/null +++ b/internal/helpers.go @@ -0,0 +1,37 @@ +package internal + +import "strings" + +func SplitCompositeTypes(name string) []string { + if !strings.Contains(name, "<") { + return strings.Split(name, ", ") + } + var parts []string + lessCount := 0 + segment := "" + for _, char := range name { + if char == ',' && lessCount == 0 { + if segment != "" { + parts = append(parts, strings.TrimSpace(segment)) + } + segment = "" + continue + } + segment += string(char) + if char == '<' { + lessCount++ + } else if char == '>' { + lessCount-- + } + } + if segment != "" { + parts = append(parts, strings.TrimSpace(segment)) + } + return parts +} + +func CopyBytes(p []byte) []byte { + b := make([]byte, len(p)) + copy(b, p) + return b +} diff --git a/internal/marshal.go b/internal/marshal.go new file mode 100644 index 000000000..8ab09fb69 --- /dev/null +++ b/internal/marshal.go @@ -0,0 +1,803 @@ +package internal + +import ( + "encoding/binary" + "errors" + "fmt" + "gopkg.in/inf.v0" + "math" + "math/big" + "math/bits" + "net" + "reflect" + "strconv" + "time" +) + +var ( + bigOne = big.NewInt(1) + EmptyValue reflect.Value +) + +const MillisecondsInADay int64 = 24 * 60 * 60 * 1000 + +func EncInt(x int32) []byte { + return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} +} + +func DecInt(x []byte) int32 { + if len(x) != 4 { + return 0 + } + return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3]) +} + +func EncShort(x int16) []byte { + p := make([]byte, 2) + p[0] = byte(x >> 8) + p[1] = byte(x) + return p +} + +func DecShort(p []byte) int16 { + if len(p) != 2 { + return 0 + } + return int16(p[0])<<8 | int16(p[1]) +} + +func DecTiny(p []byte) int8 { + if len(p) != 1 { + return 0 + } + return int8(p[0]) +} + +func EncBigInt(x int64) []byte { + return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32), + byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} +} + +func BytesToInt64(data []byte) (ret int64) { + for i := range data { + ret |= int64(data[i]) << (8 * uint(len(data)-i-1)) + } + return ret +} + +func BytesToUint64(data []byte) (ret uint64) { + for i := range data { + ret |= uint64(data[i]) << (8 * uint(len(data)-i-1)) + } + return ret +} + +func DecBigInt(data []byte) int64 { + if len(data) != 8 { + return 0 + } + return int64(data[0])<<56 | int64(data[1])<<48 | + int64(data[2])<<40 | int64(data[3])<<32 | + int64(data[4])<<24 | int64(data[5])<<16 | + int64(data[6])<<8 | int64(data[7]) +} + +func EncBool(v bool) []byte { + if v { + return []byte{1} + } + return []byte{0} +} + +func DecBool(v []byte) bool { + if len(v) == 0 { + return false + } + return v[0] != 0 +} + +// decBigInt2C sets the value of n to the big-endian two's complement +// value stored in the given data. If data[0]&80 != 0, the number +// is negative. If data is empty, the result will be 0. +func DecBigInt2C(data []byte, n *big.Int) *big.Int { + if n == nil { + n = new(big.Int) + } + n.SetBytes(data) + if len(data) > 0 && data[0]&0x80 > 0 { + n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8)) + } + return n +} + +// EncBigInt2C returns the big-endian two's complement +// form of n. +func EncBigInt2C(n *big.Int) []byte { + switch n.Sign() { + case 0: + return []byte{0} + case 1: + b := n.Bytes() + if b[0]&0x80 > 0 { + b = append([]byte{0}, b...) + } + return b + case -1: + length := uint(n.BitLen()/8+1) * 8 + b := new(big.Int).Add(n, new(big.Int).Lsh(bigOne, length)).Bytes() + // When the most significant bit is on a byte + // boundary, we can get some extra significant + // bits, so strip them off when that happens. + if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 { + b = b[1:] + } + return b + } + return nil +} + +func DecVints(data []byte) (int32, int32, int64, error) { + month, i, err := DecVint(data, 0) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract month: %s", err.Error()) + } + days, i, err := DecVint(data, i) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract days: %s", err.Error()) + } + nanos, _, err := DecVint(data, i) + if err != nil { + return 0, 0, 0, fmt.Errorf("failed to extract nanoseconds: %s", err.Error()) + } + return int32(month), int32(days), nanos, err +} + +func DecVint(data []byte, start int) (int64, int, error) { + if len(data) <= start { + return 0, 0, errors.New("unexpected eof") + } + firstByte := data[start] + if firstByte&0x80 == 0 { + return decIntZigZag(uint64(firstByte)), start + 1, nil + } + numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 + ret := uint64(firstByte & (0xff >> uint(numBytes))) + if len(data) < start+numBytes+1 { + return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data)) + } + for i := start; i < start+numBytes; i++ { + ret <<= 8 + ret |= uint64(data[i+1] & 0xff) + } + return decIntZigZag(ret), start + numBytes + 1, nil +} + +func decIntZigZag(n uint64) int64 { + return int64((n >> 1) ^ -(n & 1)) +} + +func encIntZigZag(n int64) uint64 { + return uint64((n >> 63) ^ (n << 1)) +} + +func EncVints(months int32, seconds int32, nanos int64) []byte { + buf := append(EncVint(int64(months)), EncVint(int64(seconds))...) + return append(buf, EncVint(nanos)...) +} + +func EncVint(v int64) []byte { + vEnc := encIntZigZag(v) + lead0 := bits.LeadingZeros64(vEnc) + numBytes := (639 - lead0*9) >> 6 + + // It can be 1 or 0 is v ==0 + if numBytes <= 1 { + return []byte{byte(vEnc)} + } + extraBytes := numBytes - 1 + var buf = make([]byte, numBytes) + for i := extraBytes; i >= 0; i-- { + buf[i] = byte(vEnc) + vEnc >>= 8 + } + buf[0] |= byte(^(0xff >> uint(extraBytes))) + return buf +} + +// TODO: move to internal +func ReadBytes(p []byte) ([]byte, []byte) { + // TODO: really should use a framer + size := ReadInt(p) + p = p[4:] + if size < 0 { + return nil, p + } + return p[:size], p[size:] +} + +func MarshalVarchar(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case string: + return []byte(v), nil + case []byte: + return v, nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + t := rv.Type() + k := t.Kind() + switch { + case k == reflect.String: + return []byte(rv.String()), nil + case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8: + return rv.Bytes(), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalBool(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case bool: + return EncBool(v), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Bool: + return EncBool(rv.Bool()), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalTinyInt(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int8: + return []byte{byte(v)}, nil + case uint8: + return []byte{byte(v)}, nil + case int16: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint16: + if v > math.MaxUint8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int32: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case int64: + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint: + if v > math.MaxUint8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint32: + if v > math.MaxUint8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case uint64: + if v > math.MaxUint8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case string: + n, err := strconv.ParseInt(v, 10, 8) + if err != nil { + return nil, fmt.Errorf("can not marshal %T into %s: %v", value, info, err) + } + return []byte{byte(n)}, nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt8 || v < math.MinInt8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxUint8 { + return nil, fmt.Errorf("marshal tinyint: value %d out of range", v) + } + return []byte{byte(v)}, nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalSmallInt(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int16: + return EncShort(v), nil + case uint16: + return EncShort(int16(v)), nil + case int8: + return EncShort(int16(v)), nil + case uint8: + return EncShort(int16(v)), nil + case int: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case int32: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case int64: + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case uint: + if v > math.MaxUint16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case uint32: + if v > math.MaxUint16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case uint64: + if v > math.MaxUint16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case string: + n, err := strconv.ParseInt(v, 10, 16) + if err != nil { + return nil, fmt.Errorf("can not marshal %T into %s: %v", value, info, err) + } + return EncShort(int16(n)), nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt16 || v < math.MinInt16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxUint16 { + return nil, fmt.Errorf("marshal smallint: value %d out of range", v) + } + return EncShort(int16(v)), nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalInt(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int: + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case uint: + if v > math.MaxUint32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case int64: + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case uint64: + if v > math.MaxUint32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case int32: + return EncInt(v), nil + case uint32: + return EncInt(int32(v)), nil + case int16: + return EncInt(int32(v)), nil + case uint16: + return EncInt(int32(v)), nil + case int8: + return EncInt(int32(v)), nil + case uint8: + return EncInt(int32(v)), nil + case string: + i, err := strconv.ParseInt(v, 10, 32) + if err != nil { + return nil, fmt.Errorf("can not marshal string to int: %s", err) + } + return EncInt(int32(i)), nil + } + + if value == nil { + return nil, nil + } + + switch rv := reflect.ValueOf(value); rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + if v > math.MaxInt32 || v < math.MinInt32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxInt32 { + return nil, fmt.Errorf("marshal int: value %d out of range", v) + } + return EncInt(int32(v)), nil + case reflect.Ptr: + if rv.IsNil() { + return nil, nil + } + } + + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalBigInt(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int: + return EncBigInt(int64(v)), nil + case uint: + if uint64(v) > math.MaxInt64 { + return nil, fmt.Errorf("marshal bigint: value %d out of range", v) + } + return EncBigInt(int64(v)), nil + case int64: + return EncBigInt(v), nil + case uint64: + return EncBigInt(int64(v)), nil + case int32: + return EncBigInt(int64(v)), nil + case uint32: + return EncBigInt(int64(v)), nil + case int16: + return EncBigInt(int64(v)), nil + case uint16: + return EncBigInt(int64(v)), nil + case int8: + return EncBigInt(int64(v)), nil + case uint8: + return EncBigInt(int64(v)), nil + case big.Int: + return EncBigInt2C(&v), nil + case string: + i, err := strconv.ParseInt(value.(string), 10, 64) + if err != nil { + return nil, fmt.Errorf("can not marshal string to bigint: %s", err) + } + return EncBigInt(i), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: + v := rv.Int() + return EncBigInt(v), nil + case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: + v := rv.Uint() + if v > math.MaxInt64 { + return nil, fmt.Errorf("marshal bigint: value %d out of range", v) + } + return EncBigInt(int64(v)), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalFloat(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case float32: + return EncInt(int32(math.Float32bits(v))), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Float32: + return EncInt(int32(math.Float32bits(float32(rv.Float())))), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalDouble(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case float64: + return EncBigInt(int64(math.Float64bits(v))), nil + } + if value == nil { + return nil, nil + } + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Float64: + return EncBigInt(int64(math.Float64bits(rv.Float()))), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalDecimal(info, value interface{}) ([]byte, error) { + if value == nil { + return nil, nil + } + + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case inf.Dec: + unscaled := EncBigInt2C(v.UnscaledBig()) + if unscaled == nil { + return nil, fmt.Errorf("can not marshal %T into %s", value, info) + } + + buf := make([]byte, 4+len(unscaled)) + copy(buf[0:4], EncInt(int32(v.Scale()))) + copy(buf[4:], unscaled) + return buf, nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalTime(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int64: + return EncBigInt(v), nil + case time.Duration: + return EncBigInt(v.Nanoseconds()), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return EncBigInt(rv.Int()), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalTimestamp(info, value interface{}) ([]byte, error) { + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int64: + return EncBigInt(v), nil + case time.Time: + if v.IsZero() { + return []byte{}, nil + } + x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + return EncBigInt(x), nil + } + + if value == nil { + return nil, nil + } + + rv := reflect.ValueOf(value) + switch rv.Type().Kind() { + case reflect.Int64: + return EncBigInt(rv.Int()), nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} + +func MarshalVarint(info, value interface{}) ([]byte, error) { + var ( + retBytes []byte + err error + ) + + switch v := value.(type) { + case UnsetColumn: + return nil, nil + case uint64: + if v > uint64(math.MaxInt64) { + retBytes = make([]byte, 9) + binary.BigEndian.PutUint64(retBytes[1:], v) + } else { + retBytes = make([]byte, 8) + binary.BigEndian.PutUint64(retBytes, v) + } + default: + retBytes, err = MarshalBigInt(info, value) + } + + if err == nil { + // trim down to most significant byte + i := 0 + for ; i < len(retBytes)-1; i++ { + b0 := retBytes[i] + if b0 != 0 && b0 != 0xFF { + break + } + + b1 := retBytes[i+1] + if b0 == 0 && b1 != 0 { + if b1&0x80 == 0 { + i++ + } + break + } + + if b0 == 0xFF && b1 != 0xFF { + if b1&0x80 > 0 { + i++ + } + break + } + } + retBytes = retBytes[i:] + } + + return retBytes, err +} + +func MarshalInet(info, value interface{}) ([]byte, error) { + // we return either the 4 or 16 byte representation of an + // ip address here otherwise the db value will be prefixed + // with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1 + switch val := value.(type) { + case UnsetColumn: + return nil, nil + case net.IP: + t := val.To4() + if t == nil { + return val.To16(), nil + } + return t, nil + case string: + b := net.ParseIP(val) + if b != nil { + t := b.To4() + if t == nil { + return b.To16(), nil + } + return t, nil + } + return nil, fmt.Errorf("cannot marshal. invalid ip string %s", val) + } + + if value == nil { + return nil, nil + } + + return nil, fmt.Errorf("cannot marshal %T into %s", value, info) +} + +func MarshalDate(info, value interface{}) ([]byte, error) { + var timestamp int64 + switch v := value.(type) { + //case Marshaler: + // return v.MarshalCQL(info) + case UnsetColumn: + return nil, nil + case int64: + timestamp = v + x := timestamp/MillisecondsInADay + int64(1<<31) + return EncInt(int32(x)), nil + case time.Time: + if v.IsZero() { + return []byte{}, nil + } + timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + x := timestamp/MillisecondsInADay + int64(1<<31) + return EncInt(int32(x)), nil + case *time.Time: + if v.IsZero() { + return []byte{}, nil + } + timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) + x := timestamp/MillisecondsInADay + int64(1<<31) + return EncInt(int32(x)), nil + case string: + if v == "" { + return []byte{}, nil + } + t, err := time.Parse("2006-01-02", v) + if err != nil { + return nil, fmt.Errorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) + } + timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) + x := timestamp/MillisecondsInADay + int64(1<<31) + return EncInt(int32(x)), nil + } + + if value == nil { + return nil, nil + } + return nil, fmt.Errorf("can not marshal %T into %s", value, info) +} diff --git a/internal/session.go b/internal/session.go new file mode 100644 index 000000000..1777fadd3 --- /dev/null +++ b/internal/session.go @@ -0,0 +1,12 @@ +package internal + +import "strings" + +// TODO: move to internal +func IsUseStatement(stmt string) bool { + if len(stmt) < 3 { + return false + } + + return strings.EqualFold(stmt[0:3], "use") +} diff --git a/marshal.go b/marshal.go index 4d0adb923..89ab36e70 100644 --- a/marshal.go +++ b/marshal.go @@ -29,21 +29,15 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/gocql/gocql/internal" + "gopkg.in/inf.v0" "math" "math/big" - "math/bits" "net" "reflect" "strconv" "strings" "time" - - "gopkg.in/inf.v0" -) - -var ( - bigOne = big.NewInt(1) - emptyValue reflect.Value ) var ( @@ -129,29 +123,32 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { return v.MarshalCQL(info) } + // TODO: move to internal + // Notice: a lot of marshal functions could be moved to internal package, + // if the Marshaler case and internal.UnsetColumn cases will be moved to this level switch info.Type() { case TypeVarchar, TypeAscii, TypeBlob, TypeText: - return marshalVarchar(info, value) + return internal.MarshalVarchar(info, value) case TypeBoolean: - return marshalBool(info, value) + return internal.MarshalBool(info, value) case TypeTinyInt: - return marshalTinyInt(info, value) + return internal.MarshalTinyInt(info, value) case TypeSmallInt: - return marshalSmallInt(info, value) + return internal.MarshalSmallInt(info, value) case TypeInt: - return marshalInt(info, value) + return internal.MarshalInt(info, value) case TypeBigInt, TypeCounter: - return marshalBigInt(info, value) + return internal.MarshalBigInt(info, value) case TypeFloat: - return marshalFloat(info, value) + return internal.MarshalFloat(info, value) case TypeDouble: - return marshalDouble(info, value) + return internal.MarshalDouble(info, value) case TypeDecimal: - return marshalDecimal(info, value) + return internal.MarshalDecimal(info, value) case TypeTime: - return marshalTime(info, value) + return internal.MarshalTime(info, value) case TypeTimestamp: - return marshalTimestamp(info, value) + return internal.MarshalTimestamp(info, value) case TypeList, TypeSet: return marshalList(info, value) case TypeMap: @@ -159,15 +156,15 @@ func Marshal(info TypeInfo, value interface{}) ([]byte, error) { case TypeUUID, TypeTimeUUID: return marshalUUID(info, value) case TypeVarint: - return marshalVarint(info, value) + return internal.MarshalVarint(info, value) case TypeInet: - return marshalInet(info, value) + return internal.MarshalInet(info, value) case TypeTuple: return marshalTuple(info, value) case TypeUDT: return marshalUDT(info, value) case TypeDate: - return marshalDate(info, value) + return internal.MarshalDate(info, value) case TypeDuration: return marshalDuration(info, value) } @@ -308,34 +305,6 @@ func unmarshalNullable(info TypeInfo, data []byte, value interface{}) error { return Unmarshal(info, data, newValue.Interface()) } -func marshalVarchar(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case string: - return []byte(v), nil - case []byte: - return v, nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - t := rv.Type() - k := t.Kind() - switch { - case k == reflect.String: - return []byte(rv.String()), nil - case k == reflect.Slice && t.Elem().Kind() == reflect.Uint8: - return rv.Bytes(), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: @@ -375,363 +344,20 @@ func unmarshalVarchar(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalSmallInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int16: - return encShort(v), nil - case uint16: - return encShort(int16(v)), nil - case int8: - return encShort(int16(v)), nil - case uint8: - return encShort(int16(v)), nil - case int: - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case int32: - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case int64: - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case uint: - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case uint32: - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case uint64: - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case string: - n, err := strconv.ParseInt(v, 10, 16) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) - } - return encShort(int16(n)), nil - } - - if value == nil { - return nil, nil - } - - switch rv := reflect.ValueOf(value); rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - if v > math.MaxInt16 || v < math.MinInt16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxUint16 { - return nil, marshalErrorf("marshal smallint: value %d out of range", v) - } - return encShort(int16(v)), nil - case reflect.Ptr: - if rv.IsNil() { - return nil, nil - } - } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func marshalTinyInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int8: - return []byte{byte(v)}, nil - case uint8: - return []byte{byte(v)}, nil - case int16: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint16: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int32: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case int64: - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint32: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case uint64: - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case string: - n, err := strconv.ParseInt(v, 10, 8) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s: %v", value, info, err) - } - return []byte{byte(n)}, nil - } - - if value == nil { - return nil, nil - } - - switch rv := reflect.ValueOf(value); rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - if v > math.MaxInt8 || v < math.MinInt8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxUint8 { - return nil, marshalErrorf("marshal tinyint: value %d out of range", v) - } - return []byte{byte(v)}, nil - case reflect.Ptr: - if rv.IsNil() { - return nil, nil - } - } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func marshalInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int: - if v > math.MaxInt32 || v < math.MinInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case uint: - if v > math.MaxUint32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case int64: - if v > math.MaxInt32 || v < math.MinInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case uint64: - if v > math.MaxUint32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case int32: - return encInt(v), nil - case uint32: - return encInt(int32(v)), nil - case int16: - return encInt(int32(v)), nil - case uint16: - return encInt(int32(v)), nil - case int8: - return encInt(int32(v)), nil - case uint8: - return encInt(int32(v)), nil - case string: - i, err := strconv.ParseInt(v, 10, 32) - if err != nil { - return nil, marshalErrorf("can not marshal string to int: %s", err) - } - return encInt(int32(i)), nil - } - - if value == nil { - return nil, nil - } - - switch rv := reflect.ValueOf(value); rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - if v > math.MaxInt32 || v < math.MinInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxInt32 { - return nil, marshalErrorf("marshal int: value %d out of range", v) - } - return encInt(int32(v)), nil - case reflect.Ptr: - if rv.IsNil() { - return nil, nil - } - } - - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func encInt(x int32) []byte { - return []byte{byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} -} - -func decInt(x []byte) int32 { - if len(x) != 4 { - return 0 - } - return int32(x[0])<<24 | int32(x[1])<<16 | int32(x[2])<<8 | int32(x[3]) -} - -func encShort(x int16) []byte { - p := make([]byte, 2) - p[0] = byte(x >> 8) - p[1] = byte(x) - return p -} - -func decShort(p []byte) int16 { - if len(p) != 2 { - return 0 - } - return int16(p[0])<<8 | int16(p[1]) -} - -func decTiny(p []byte) int8 { - if len(p) != 1 { - return 0 - } - return int8(p[0]) -} - -func marshalBigInt(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int: - return encBigInt(int64(v)), nil - case uint: - if uint64(v) > math.MaxInt64 { - return nil, marshalErrorf("marshal bigint: value %d out of range", v) - } - return encBigInt(int64(v)), nil - case int64: - return encBigInt(v), nil - case uint64: - return encBigInt(int64(v)), nil - case int32: - return encBigInt(int64(v)), nil - case uint32: - return encBigInt(int64(v)), nil - case int16: - return encBigInt(int64(v)), nil - case uint16: - return encBigInt(int64(v)), nil - case int8: - return encBigInt(int64(v)), nil - case uint8: - return encBigInt(int64(v)), nil - case big.Int: - return encBigInt2C(&v), nil - case string: - i, err := strconv.ParseInt(value.(string), 10, 64) - if err != nil { - return nil, marshalErrorf("can not marshal string to bigint: %s", err) - } - return encBigInt(i), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int, reflect.Int64, reflect.Int32, reflect.Int16, reflect.Int8: - v := rv.Int() - return encBigInt(v), nil - case reflect.Uint, reflect.Uint64, reflect.Uint32, reflect.Uint16, reflect.Uint8: - v := rv.Uint() - if v > math.MaxInt64 { - return nil, marshalErrorf("marshal bigint: value %d out of range", v) - } - return encBigInt(int64(v)), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func encBigInt(x int64) []byte { - return []byte{byte(x >> 56), byte(x >> 48), byte(x >> 40), byte(x >> 32), - byte(x >> 24), byte(x >> 16), byte(x >> 8), byte(x)} -} - -func bytesToInt64(data []byte) (ret int64) { - for i := range data { - ret |= int64(data[i]) << (8 * uint(len(data)-i-1)) - } - return ret -} - -func bytesToUint64(data []byte) (ret uint64) { - for i := range data { - ret |= uint64(data[i]) << (8 * uint(len(data)-i-1)) - } - return ret -} - func unmarshalBigInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, decBigInt(data), data, value) + return unmarshalIntlike(info, internal.DecBigInt(data), data, value) } func unmarshalInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decInt(data)), data, value) + return unmarshalIntlike(info, int64(internal.DecInt(data)), data, value) } func unmarshalSmallInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decShort(data)), data, value) + return unmarshalIntlike(info, int64(internal.DecShort(data)), data, value) } func unmarshalTinyInt(info TypeInfo, data []byte, value interface{}) error { - return unmarshalIntlike(info, int64(decTiny(data)), data, value) + return unmarshalIntlike(info, int64(internal.DecTiny(data)), data, value) } func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { @@ -740,7 +366,7 @@ func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { return unmarshalIntlike(info, 0, data, value) case *uint64: if len(data) == 9 && data[0] == 0 { - *v = bytesToUint64(data[1:]) + *v = internal.BytesToUint64(data[1:]) return nil } } @@ -749,64 +375,13 @@ func unmarshalVarint(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("unmarshal int: varint value %v out of range for %T (use big.Int)", data, value) } - int64Val := bytesToInt64(data) + int64Val := internal.BytesToInt64(data) if len(data) > 0 && len(data) < 8 && data[0]&0x80 > 0 { int64Val -= (1 << uint(len(data)*8)) } return unmarshalIntlike(info, int64Val, data, value) } -func marshalVarint(info TypeInfo, value interface{}) ([]byte, error) { - var ( - retBytes []byte - err error - ) - - switch v := value.(type) { - case unsetColumn: - return nil, nil - case uint64: - if v > uint64(math.MaxInt64) { - retBytes = make([]byte, 9) - binary.BigEndian.PutUint64(retBytes[1:], v) - } else { - retBytes = make([]byte, 8) - binary.BigEndian.PutUint64(retBytes, v) - } - default: - retBytes, err = marshalBigInt(info, value) - } - - if err == nil { - // trim down to most significant byte - i := 0 - for ; i < len(retBytes)-1; i++ { - b0 := retBytes[i] - if b0 != 0 && b0 != 0xFF { - break - } - - b1 := retBytes[i+1] - if b0 == 0 && b1 != 0 { - if b1&0x80 == 0 { - i++ - } - break - } - - if b0 == 0xFF && b1 != 0xFF { - if b1&0x80 > 0 { - i++ - } - break - } - } - retBytes = retBytes[i:] - } - - return retBytes, err -} - func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interface{}) error { switch v := value.(type) { case *int: @@ -899,7 +474,7 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac *v = uint8(int64Val) & 0xFF return nil case *big.Int: - decBigInt2C(data, v) + internal.DecBigInt2C(data, v) return nil case *string: *v = strconv.FormatInt(int64Val, 10) @@ -1009,51 +584,12 @@ func unmarshalIntlike(info TypeInfo, int64Val int64, data []byte, value interfac return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func decBigInt(data []byte) int64 { - if len(data) != 8 { - return 0 - } - return int64(data[0])<<56 | int64(data[1])<<48 | - int64(data[2])<<40 | int64(data[3])<<32 | - int64(data[4])<<24 | int64(data[5])<<16 | - int64(data[6])<<8 | int64(data[7]) -} - -func marshalBool(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case bool: - return encBool(v), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Bool: - return encBool(rv.Bool()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func encBool(v bool) []byte { - if v { - return []byte{1} - } - return []byte{0} -} - func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case *bool: - *v = decBool(data) + *v = internal.DecBool(data) return nil } rv := reflect.ValueOf(value) @@ -1063,47 +599,18 @@ func unmarshalBool(info TypeInfo, data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Bool: - rv.SetBool(decBool(data)) + rv.SetBool(internal.DecBool(data)) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func decBool(v []byte) bool { - if len(v) == 0 { - return false - } - return v[0] != 0 -} - -func marshalFloat(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case float32: - return encInt(int32(math.Float32bits(v))), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Float32: - return encInt(int32(math.Float32bits(float32(rv.Float())))), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case *float32: - *v = math.Float32frombits(uint32(decInt(data))) + *v = math.Float32frombits(uint32(internal.DecInt(data))) return nil } rv := reflect.ValueOf(value) @@ -1113,38 +620,18 @@ func unmarshalFloat(info TypeInfo, data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Float32: - rv.SetFloat(float64(math.Float32frombits(uint32(decInt(data))))) + rv.SetFloat(float64(math.Float32frombits(uint32(internal.DecInt(data))))) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalDouble(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case float64: - return encBigInt(int64(math.Float64bits(v))), nil - } - if value == nil { - return nil, nil - } - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Float64: - return encBigInt(int64(math.Float64bits(rv.Float()))), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case *float64: - *v = math.Float64frombits(uint64(decBigInt(data))) + *v = math.Float64frombits(uint64(internal.DecBigInt(data))) return nil } rv := reflect.ValueOf(value) @@ -1154,36 +641,12 @@ func unmarshalDouble(info TypeInfo, data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Float64: - rv.SetFloat(math.Float64frombits(uint64(decBigInt(data)))) + rv.SetFloat(math.Float64frombits(uint64(internal.DecBigInt(data)))) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func marshalDecimal(info TypeInfo, value interface{}) ([]byte, error) { - if value == nil { - return nil, nil - } - - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case inf.Dec: - unscaled := encBigInt2C(v.UnscaledBig()) - if unscaled == nil { - return nil, marshalErrorf("can not marshal %T into %s", value, info) - } - - buf := make([]byte, 4+len(unscaled)) - copy(buf[0:4], encInt(int32(v.Scale()))) - copy(buf[4:], unscaled) - return buf, nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: @@ -1192,115 +655,23 @@ func unmarshalDecimal(info TypeInfo, data []byte, value interface{}) error { if len(data) < 4 { return unmarshalErrorf("inf.Dec needs at least 4 bytes, while value has only %d", len(data)) } - scale := decInt(data[0:4]) - unscaled := decBigInt2C(data[4:], nil) + scale := internal.DecInt(data[0:4]) + unscaled := internal.DecBigInt2C(data[4:], nil) *v = *inf.NewDecBig(unscaled, inf.Scale(scale)) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -// decBigInt2C sets the value of n to the big-endian two's complement -// value stored in the given data. If data[0]&80 != 0, the number -// is negative. If data is empty, the result will be 0. -func decBigInt2C(data []byte, n *big.Int) *big.Int { - if n == nil { - n = new(big.Int) - } - n.SetBytes(data) - if len(data) > 0 && data[0]&0x80 > 0 { - n.Sub(n, new(big.Int).Lsh(bigOne, uint(len(data))*8)) - } - return n -} - -// encBigInt2C returns the big-endian two's complement -// form of n. -func encBigInt2C(n *big.Int) []byte { - switch n.Sign() { - case 0: - return []byte{0} - case 1: - b := n.Bytes() - if b[0]&0x80 > 0 { - b = append([]byte{0}, b...) - } - return b - case -1: - length := uint(n.BitLen()/8+1) * 8 - b := new(big.Int).Add(n, new(big.Int).Lsh(bigOne, length)).Bytes() - // When the most significant bit is on a byte - // boundary, we can get some extra significant - // bits, so strip them off when that happens. - if len(b) >= 2 && b[0] == 0xff && b[1]&0x80 != 0 { - b = b[1:] - } - return b - } - return nil -} - -func marshalTime(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - return encBigInt(v), nil - case time.Duration: - return encBigInt(v.Nanoseconds()), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int64: - return encBigInt(rv.Int()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - -func marshalTimestamp(info TypeInfo, value interface{}) ([]byte, error) { - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - return encBigInt(v), nil - case time.Time: - if v.IsZero() { - return []byte{}, nil - } - x := int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - return encBigInt(x), nil - } - - if value == nil { - return nil, nil - } - - rv := reflect.ValueOf(value) - switch rv.Type().Kind() { - case reflect.Int64: - return encBigInt(rv.Int()), nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: return v.UnmarshalCQL(info, data) case *int64: - *v = decBigInt(data) + *v = internal.DecBigInt(data) return nil case *time.Duration: - *v = time.Duration(decBigInt(data)) + *v = time.Duration(internal.DecBigInt(data)) return nil } @@ -1311,7 +682,7 @@ func unmarshalTime(info TypeInfo, data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Int64: - rv.SetInt(decBigInt(data)) + rv.SetInt(internal.DecBigInt(data)) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) @@ -1322,14 +693,14 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { case Unmarshaler: return v.UnmarshalCQL(info, data) case *int64: - *v = decBigInt(data) + *v = internal.DecBigInt(data) return nil case *time.Time: if len(data) == 0 { *v = time.Time{} return nil } - x := decBigInt(data) + x := internal.DecBigInt(data) sec := x / 1000 nsec := (x - sec*1000) * 1000000 *v = time.Unix(sec, nsec).In(time.UTC) @@ -1343,58 +714,12 @@ func unmarshalTimestamp(info TypeInfo, data []byte, value interface{}) error { rv = rv.Elem() switch rv.Type().Kind() { case reflect.Int64: - rv.SetInt(decBigInt(data)) + rv.SetInt(internal.DecBigInt(data)) return nil } return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -const millisecondsInADay int64 = 24 * 60 * 60 * 1000 - -func marshalDate(info TypeInfo, value interface{}) ([]byte, error) { - var timestamp int64 - switch v := value.(type) { - case Marshaler: - return v.MarshalCQL(info) - case unsetColumn: - return nil, nil - case int64: - timestamp = v - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case time.Time: - if v.IsZero() { - return []byte{}, nil - } - timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case *time.Time: - if v.IsZero() { - return []byte{}, nil - } - timestamp = int64(v.UTC().Unix()*1e3) + int64(v.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - case string: - if v == "" { - return []byte{}, nil - } - t, err := time.Parse("2006-01-02", v) - if err != nil { - return nil, marshalErrorf("can not marshal %T into %s, date layout must be '2006-01-02'", value, info) - } - timestamp = int64(t.UTC().Unix()*1e3) + int64(t.UTC().Nanosecond()/1e6) - x := timestamp/millisecondsInADay + int64(1<<31) - return encInt(int32(x)), nil - } - - if value == nil { - return nil, nil - } - return nil, marshalErrorf("can not marshal %T into %s", value, info) -} - func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: @@ -1406,7 +731,7 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { } var origin uint32 = 1 << 31 var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * millisecondsInADay + timestamp := (int64(current) - int64(origin)) * internal.MillisecondsInADay *v = time.UnixMilli(timestamp).In(time.UTC) return nil case *string: @@ -1416,7 +741,7 @@ func unmarshalDate(info TypeInfo, data []byte, value interface{}) error { } var origin uint32 = 1 << 31 var current uint32 = binary.BigEndian.Uint32(data) - timestamp := (int64(current) - int64(origin)) * millisecondsInADay + timestamp := (int64(current) - int64(origin)) * internal.MillisecondsInADay *v = time.UnixMilli(timestamp).In(time.UTC).Format("2006-01-02") return nil } @@ -1427,20 +752,20 @@ func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) - case unsetColumn: + case internal.UnsetColumn: return nil, nil case int64: - return encVints(0, 0, v), nil + return internal.EncVints(0, 0, v), nil case time.Duration: - return encVints(0, 0, v.Nanoseconds()), nil + return internal.EncVints(0, 0, v.Nanoseconds()), nil case string: d, err := time.ParseDuration(v) if err != nil { return nil, err } - return encVints(0, 0, d.Nanoseconds()), nil + return internal.EncVints(0, 0, d.Nanoseconds()), nil case Duration: - return encVints(v.Months, v.Days, v.Nanoseconds), nil + return internal.EncVints(v.Months, v.Days, v.Nanoseconds), nil } if value == nil { @@ -1450,7 +775,7 @@ func marshalDuration(info TypeInfo, value interface{}) ([]byte, error) { rv := reflect.ValueOf(value) switch rv.Type().Kind() { case reflect.Int64: - return encBigInt(rv.Int()), nil + return internal.EncBigInt(rv.Int()), nil } return nil, marshalErrorf("can not marshal %T into %s", value, info) } @@ -1468,7 +793,7 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { } return nil } - months, days, nanos, err := decVints(data) + months, days, nanos, err := internal.DecVints(data) if err != nil { return unmarshalErrorf("failed to unmarshal %s into %T: %s", info, value, err.Error()) } @@ -1482,74 +807,8 @@ func unmarshalDuration(info TypeInfo, data []byte, value interface{}) error { return unmarshalErrorf("can not unmarshal %s into %T", info, value) } -func decVints(data []byte) (int32, int32, int64, error) { - month, i, err := decVint(data, 0) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract month: %s", err.Error()) - } - days, i, err := decVint(data, i) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract days: %s", err.Error()) - } - nanos, _, err := decVint(data, i) - if err != nil { - return 0, 0, 0, fmt.Errorf("failed to extract nanoseconds: %s", err.Error()) - } - return int32(month), int32(days), nanos, err -} - -func decVint(data []byte, start int) (int64, int, error) { - if len(data) <= start { - return 0, 0, errors.New("unexpected eof") - } - firstByte := data[start] - if firstByte&0x80 == 0 { - return decIntZigZag(uint64(firstByte)), start + 1, nil - } - numBytes := bits.LeadingZeros32(uint32(^firstByte)) - 24 - ret := uint64(firstByte & (0xff >> uint(numBytes))) - if len(data) < start+numBytes+1 { - return 0, 0, fmt.Errorf("data expect to have %d bytes, but it has only %d", start+numBytes+1, len(data)) - } - for i := start; i < start+numBytes; i++ { - ret <<= 8 - ret |= uint64(data[i+1] & 0xff) - } - return decIntZigZag(ret), start + numBytes + 1, nil -} - -func decIntZigZag(n uint64) int64 { - return int64((n >> 1) ^ -(n & 1)) -} - -func encIntZigZag(n int64) uint64 { - return uint64((n >> 63) ^ (n << 1)) -} - -func encVints(months int32, seconds int32, nanos int64) []byte { - buf := append(encVint(int64(months)), encVint(int64(seconds))...) - return append(buf, encVint(nanos)...) -} - -func encVint(v int64) []byte { - vEnc := encIntZigZag(v) - lead0 := bits.LeadingZeros64(vEnc) - numBytes := (639 - lead0*9) >> 6 - - // It can be 1 or 0 is v ==0 - if numBytes <= 1 { - return []byte{byte(vEnc)} - } - extraBytes := numBytes - 1 - var buf = make([]byte, numBytes) - for i := extraBytes; i >= 0; i-- { - buf[i] = byte(vEnc) - vEnc >>= 8 - } - buf[0] |= byte(^(0xff >> uint(extraBytes))) - return buf -} - +// TODO: move to internal +// just pass the CollectionType.proto to this method instead of CollectionType func writeCollectionSize(info CollectionType, n int, buf *bytes.Buffer) error { if info.proto > protoVersion2 { if n > math.MaxInt32 { @@ -1580,7 +839,7 @@ func marshalList(info TypeInfo, value interface{}) ([]byte, error) { if value == nil { return nil, nil - } else if _, ok := value.(unsetColumn); ok { + } else if _, ok := value.(internal.UnsetColumn); ok { return nil, nil } @@ -1717,7 +976,7 @@ func marshalMap(info TypeInfo, value interface{}) ([]byte, error) { if value == nil { return nil, nil - } else if _, ok := value.(unsetColumn); ok { + } else if _, ok := value.(internal.UnsetColumn); ok { return nil, nil } @@ -1847,7 +1106,7 @@ func unmarshalMap(info TypeInfo, data []byte, value interface{}) error { func marshalUUID(info TypeInfo, value interface{}) ([]byte, error) { switch val := value.(type) { - case unsetColumn: + case internal.UnsetColumn: return nil, nil case UUID: return val.Bytes(), nil @@ -1936,38 +1195,6 @@ func unmarshalTimeUUID(info TypeInfo, data []byte, value interface{}) error { } } -func marshalInet(info TypeInfo, value interface{}) ([]byte, error) { - // we return either the 4 or 16 byte representation of an - // ip address here otherwise the db value will be prefixed - // with the remaining byte values e.g. ::ffff:127.0.0.1 and not 127.0.0.1 - switch val := value.(type) { - case unsetColumn: - return nil, nil - case net.IP: - t := val.To4() - if t == nil { - return val.To16(), nil - } - return t, nil - case string: - b := net.ParseIP(val) - if b != nil { - t := b.To4() - if t == nil { - return b.To16(), nil - } - return t, nil - } - return nil, marshalErrorf("cannot marshal. invalid ip string %s", val) - } - - if value == nil { - return nil, nil - } - - return nil, marshalErrorf("cannot marshal %T into %s", value, info) -} - func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { switch v := value.(type) { case Unmarshaler: @@ -1976,7 +1203,7 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { if x := len(data); !(x == 4 || x == 16) { return unmarshalErrorf("cannot unmarshal %s into %T: invalid sized IP: got %d bytes not 4 or 16", info, value, x) } - buf := copyBytes(data) + buf := internal.CopyBytes(data) ip := net.IP(buf) if v4 := ip.To4(); v4 != nil { *v = v4 @@ -2003,7 +1230,7 @@ func unmarshalInet(info TypeInfo, data []byte, value interface{}) error { func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { tuple := info.(TupleTypeInfo) switch v := value.(type) { - case unsetColumn: + case internal.UnsetColumn: return nil, unmarshalErrorf("Invalid request: UnsetValue is unsupported for tuples") case []interface{}: if len(v) != len(tuple.Elems) { @@ -2013,7 +1240,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { var buf []byte for i, elem := range v { if elem == nil { - buf = appendInt(buf, int32(-1)) + buf = internal.AppendInt(buf, int32(-1)) continue } @@ -2023,7 +1250,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { } n := len(data) - buf = appendInt(buf, int32(n)) + buf = internal.AppendInt(buf, int32(n)) buf = append(buf, data...) } @@ -2045,7 +1272,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { field := rv.Field(i) if field.Kind() == reflect.Ptr && field.IsNil() { - buf = appendInt(buf, int32(-1)) + buf = internal.AppendInt(buf, int32(-1)) continue } @@ -2055,7 +1282,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { } n := len(data) - buf = appendInt(buf, int32(n)) + buf = internal.AppendInt(buf, int32(n)) buf = append(buf, data...) } @@ -2071,7 +1298,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { item := rv.Index(i) if item.Kind() == reflect.Ptr && item.IsNil() { - buf = appendInt(buf, int32(-1)) + buf = internal.AppendInt(buf, int32(-1)) continue } @@ -2081,7 +1308,7 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { } n := len(data) - buf = appendInt(buf, int32(n)) + buf = internal.AppendInt(buf, int32(n)) buf = append(buf, data...) } @@ -2091,16 +1318,6 @@ func marshalTuple(info TypeInfo, value interface{}) ([]byte, error) { return nil, marshalErrorf("cannot marshal %T into %s", value, tuple) } -func readBytes(p []byte) ([]byte, []byte) { - // TODO: really should use a framer - size := readInt(p) - p = p[4:] - if size < 0 { - return nil, p - } - return p[:size], p[size:] -} - // currently only support unmarshal into a list of values, this makes it possible // to support tuples without changing the query API. In the future this can be extend // to allow unmarshalling into custom tuple types. @@ -2116,7 +1333,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { // each element inside data is a [bytes] var p []byte if len(data) >= 4 { - p, data = readBytes(data) + p, data = internal.ReadBytes(data) } err := Unmarshal(elem, p, v[i]) if err != nil { @@ -2145,7 +1362,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { for i, elem := range tuple.Elems { var p []byte if len(data) >= 4 { - p, data = readBytes(data) + p, data = internal.ReadBytes(data) } v, err := elem.NewWithError() @@ -2182,7 +1399,7 @@ func unmarshalTuple(info TypeInfo, data []byte, value interface{}) error { for i, elem := range tuple.Elems { var p []byte if len(data) >= 4 { - p, data = readBytes(data) + p, data = internal.ReadBytes(data) } v, err := elem.NewWithError() @@ -2236,7 +1453,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { switch v := value.(type) { case Marshaler: return v.MarshalCQL(info) - case unsetColumn: + case internal.UnsetColumn: return nil, unmarshalErrorf("invalid request: UnsetValue is unsupported for user defined types") case UDTMarshaler: var buf []byte @@ -2246,7 +1463,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { return nil, err } - buf = appendBytes(buf, data) + buf = internal.AppendBytes(buf, data) } return buf, nil @@ -2265,7 +1482,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { } } - buf = appendBytes(buf, data) + buf = internal.AppendBytes(buf, data) } return buf, nil @@ -2309,7 +1526,7 @@ func marshalUDT(info TypeInfo, value interface{}) ([]byte, error) { } } - buf = appendBytes(buf, data) + buf = internal.AppendBytes(buf, data) } return buf, nil @@ -2331,7 +1548,7 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { } var p []byte - p, data = readBytes(data) + p, data = internal.ReadBytes(data) if err := v.UnmarshalUDT(e.Name, e.Type, p); err != nil { return err } @@ -2374,7 +1591,7 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { val := reflect.New(valType) var p []byte - p, data = readBytes(data) + p, data = internal.ReadBytes(data) if err := Unmarshal(e.Type, p, val.Interface()); err != nil { return err @@ -2424,12 +1641,12 @@ func unmarshalUDT(info TypeInfo, data []byte, value interface{}) error { } var p []byte - p, data = readBytes(data) + p, data = internal.ReadBytes(data) f, ok := fields[e.Name] if !ok { f = k.FieldByName(e.Name) - if f == emptyValue { + if f == internal.EmptyValue { // skip fields which exist in the UDT but not in // the struct passed in continue diff --git a/marshal_test.go b/marshal_test.go index 6c139e6bc..c952ad41c 100644 --- a/marshal_test.go +++ b/marshal_test.go @@ -1075,7 +1075,7 @@ var marshalTests = []struct { }, { NativeType{proto: 2, typ: TypeTime}, - encBigInt(1000), + internal.EncBigInt(1000), time.Duration(1000), nil, nil, @@ -1726,7 +1726,7 @@ func TestMarshalPointer(t *testing.T) { func TestMarshalTime(t *testing.T) { durationS := "1h10m10s" duration, _ := time.ParseDuration(durationS) - expectedData := encBigInt(duration.Nanoseconds()) + expectedData := internal.EncBigInt(duration.Nanoseconds()) var marshalTimeTests = []struct { Info TypeInfo Data []byte @@ -1758,7 +1758,7 @@ func TestMarshalTime(t *testing.T) { } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, internal.DecInt(test.Data), data, internal.DecInt(data), test.Value) } } } @@ -1824,7 +1824,7 @@ func TestMarshalTimestamp(t *testing.T) { } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decBigInt(test.Data), data, decBigInt(data), test.Value) + test.Data, internal.DecBigInt(test.Data), data, internal.DecBigInt(data), test.Value) } } } @@ -1961,7 +1961,7 @@ func TestMarshalTuple(t *testing.T) { if !bytes.Equal(data, tc.expected) { t.Errorf("marshalTest: expected %x (%v), got %x (%v)", - tc.expected, decBigInt(tc.expected), data, decBigInt(data)) + tc.expected, internal.DecBigInt(tc.expected), data, internal.DecBigInt(data)) return } @@ -2244,7 +2244,7 @@ func TestUnmarshalDate(t *testing.T) { func TestMarshalDate(t *testing.T) { now := time.Now().UTC() timestamp := now.UnixNano() / int64(time.Millisecond) - expectedData := encInt(int32(timestamp/86400000 + int64(1<<31))) + expectedData := internal.EncInt(int32(timestamp/86400000 + int64(1<<31))) var marshalDateTests = []struct { Info TypeInfo @@ -2282,17 +2282,17 @@ func TestMarshalDate(t *testing.T) { } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, internal.DecInt(test.Data), data, internal.DecInt(data), test.Value) } } } func TestLargeDate(t *testing.T) { farFuture := time.Date(999999, time.December, 31, 0, 0, 0, 0, time.UTC) - expectedFutureData := encInt(int32(farFuture.UnixMilli()/86400000 + int64(1<<31))) + expectedFutureData := internal.EncInt(int32(farFuture.UnixMilli()/86400000 + int64(1<<31))) farPast := time.Date(-999999, time.January, 1, 0, 0, 0, 0, time.UTC) - expectedPastData := encInt(int32(farPast.UnixMilli()/86400000 + int64(1<<31))) + expectedPastData := internal.EncInt(int32(farPast.UnixMilli()/86400000 + int64(1<<31))) var marshalDateTests = []struct { Data []byte @@ -2323,7 +2323,7 @@ func TestLargeDate(t *testing.T) { } if !bytes.Equal(data, test.Data) { t.Errorf("largeDateTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, internal.DecInt(test.Data), data, internal.DecInt(data), test.Value) } var date time.Time @@ -2354,7 +2354,7 @@ func BenchmarkUnmarshalVarchar(b *testing.B) { func TestMarshalDuration(t *testing.T) { durationS := "1h10m10s" duration, _ := time.ParseDuration(durationS) - expectedData := append([]byte{0, 0}, encVint(duration.Nanoseconds())...) + expectedData := append([]byte{0, 0}, internal.EncVint(duration.Nanoseconds())...) var marshalDurationTests = []struct { Info TypeInfo Data []byte @@ -2391,7 +2391,7 @@ func TestMarshalDuration(t *testing.T) { } if !bytes.Equal(data, test.Data) { t.Errorf("marshalTest[%d]: expected %x (%v), got %x (%v) for time %s", i, - test.Data, decInt(test.Data), data, decInt(data), test.Value) + test.Data, internal.DecInt(test.Data), data, internal.DecInt(data), test.Value) } } } diff --git a/session.go b/session.go index d04a13672..a5623ca19 100644 --- a/session.go +++ b/session.go @@ -30,6 +30,7 @@ import ( "encoding/binary" "errors" "fmt" + "github.com/gocql/gocql/internal" "io" "net" "strings" @@ -824,6 +825,7 @@ type queryMetrics struct { totalAttempts int } +// TODO: move to internal (Maybe, if it's posible to move the hostMetrics and queryMetrics) // preFilledQueryMetrics initializes new queryMetrics based on per-host supplied data. func preFilledQueryMetrics(m map[string]*hostMetrics) *queryMetrics { qm := &queryMetrics{m: m} @@ -1303,18 +1305,10 @@ func (q *Query) Exec() error { return q.Iter().Close() } -func isUseStatement(stmt string) bool { - if len(stmt) < 3 { - return false - } - - return strings.EqualFold(stmt[0:3], "use") -} - // Iter executes the query and returns an iterator capable of iterating // over all results. func (q *Query) Iter() *Iter { - if isUseStatement(q.stmt) { + if internal.IsUseStatement(q.stmt) { return &Iter{err: ErrUseStmt} } // if the query was specifically run on a connection then re-use that diff --git a/session_test.go b/session_test.go index 8633f9957..c5d4cdd25 100644 --- a/session_test.go +++ b/session_test.go @@ -341,7 +341,7 @@ func TestIsUseStatement(t *testing.T) { } for _, tc := range testCases { - v := isUseStatement(tc.input) + v := internal.IsUseStatement(tc.input) if v != tc.exp { t.Fatalf("expected %v but got %v for statement %q", tc.exp, v, tc.input) } diff --git a/token_test.go b/token_test.go index 90e0d4fd8..da3e4d3ba 100644 --- a/token_test.go +++ b/token_test.go @@ -48,7 +48,7 @@ func TestMurmur3Partitioner(t *testing.T) { // at least verify that the partitioner // doesn't return nil - pk, _ := marshalInt(nil, 1) + pk, _ := internal.MarshalInt(nil, 1) token = murmur3Partitioner{}.Hash(pk) if token == nil { t.Fatal("token was nil") @@ -73,7 +73,7 @@ func TestOrderedPartitioner(t *testing.T) { // at least verify that the partitioner // doesn't return nil p := orderedPartitioner{} - pk, _ := marshalInt(nil, 1) + pk, _ := internal.MarshalInt(nil, 1) token := p.Hash(pk) if token == nil { t.Fatal("token was nil") @@ -109,7 +109,7 @@ func TestRandomPartitioner(t *testing.T) { // at least verify that the partitioner // doesn't return nil p := randomPartitioner{} - pk, _ := marshalInt(nil, 1) + pk, _ := internal.MarshalInt(nil, 1) token := p.Hash(pk) if token == nil { t.Fatal("token was nil")