From 45b81da653a48d971071290f52ec28bd4eb7c701 Mon Sep 17 00:00:00 2001 From: Rueian Date: Sun, 9 Jan 2022 16:59:44 +0800 Subject: [PATCH] feat: support om by redisjson --- README.md | 47 ++++++++---- om/hash_conv.go | 187 +++++++++++------------------------------------- om/hash_repo.go | 170 +++++++++++++++++-------------------------- om/json_repo.go | 137 +++++++++++++++++++++++++++++++++++ om/schema.go | 109 ++++++++++++++++++++++++++++ 5 files changed, 386 insertions(+), 264 deletions(-) create mode 100644 om/json_repo.go create mode 100644 om/schema.go diff --git a/README.md b/README.md index 9c823b26..09d23ebc 100644 --- a/README.md +++ b/README.md @@ -259,7 +259,7 @@ c.DoCache(ctx, c.B().Hmget().Key("myhash").Field("1", "2").Cache(), time.Second* ## Object Mapping -The `NewHashRepository` creates an OM repository backed by redis hash. +The `NewHashRepository` and `NewJSONRepository` creates an OM repository backed by redis hash or RedisJSON. ```golang package main @@ -274,27 +274,26 @@ import ( ) type Example struct { - ID string `redis:"-,pk"` // the pk option indicate that this field is the ULID key - Ver int64 `redis:"_v"` // the _v field is required for optimistic locking to prevent the lost update - MyStr string `redis:"f1"` - MyArr []string `redis:"f2,sep=|"` // the sep= option is required for converting the slice to/from a string + Key string `json:"key" redis:",key"` // the redis:",key" is required to indicate which field is the ULID key + Ver int64 `json:"ver" redis:",ver"` // the redis:",ver" is required to do optimistic locking to prevent lost update + Str string `json:"myStr"` // both NewHashRepository and NewJSONRepository use json tag as field name } func main() { ctx := context.Background() c, _ := rueidis.NewClient(rueidis.ClientOption{InitAddress: []string{"127.0.0.1:6379"}}) - // create the hash repo. + // create the repo with NewHashRepository or NewJSONRepository repo := om.NewHashRepository("my_prefix", Example{}, c) exp := repo.NewEntity().(*Example) - exp.MyArr = []string{"1", "2"} - fmt.Println(exp.ID) // output 01FNH4FCXV9JTB9WTVFAAKGSYB + exp.Str = "mystr" + fmt.Println(exp.Key) // output 01FNH4FCXV9JTB9WTVFAAKGSYB repo.Save(ctx, exp) // success // lookup "my_prefix:01FNH4FCXV9JTB9WTVFAAKGSYB" through client side caching - cache, _ := repo.FetchCache(ctx, exp.ID, time.Second*5) + cache, _ := repo.FetchCache(ctx, exp.Key, time.Second*5) exp2 := cache.(*Example) - fmt.Println(exp2.MyArr) // output [1 2], which equals to exp.MyArray + fmt.Println(exp2.Str) // output "mystr", which equals to exp.Str exp2.Ver = 0 // if someone changes the version during your GET then SET operation, repo.Save(ctx, exp2) // the save will fail with ErrVersionMismatch. @@ -308,12 +307,20 @@ If you have RediSearch, you can create and search the repository against the ind ```golang -repo.CreateIndex(ctx, func(schema om.FtCreateSchema) om.Completed { - return schema.FieldName("f1").Text().Build() // you have full index capability by building the command from om.FtCreateSchema -}) +if _, ok := repo.(*om.HashRepository); ok { + repo.CreateIndex(ctx, func(schema om.FtCreateSchema) om.Completed { + return schema.FieldName("myStr").Text().Build() // Note that the Example.Str field is mapped to myStr on redis by its json tag + }) +} + +if _, ok := repo.(*om.JSONRepository); ok { + repo.CreateIndex(ctx, func(schema om.FtCreateSchema) om.Completed { + return schema.FieldName("$.myStr").Text().Build() // the field name of json index should be a json path syntax + }) +} exp := repo.NewEntity().(*Example) -exp.MyStr = "foo" // Note that MyStr is mapped to "f1" in redis by the `redis:"f1"` tag +exp.Str = "foo" repo.Save(ctx, exp) n, records, _ := repo.Search(ctx, func(search om.FtSearchIndex) om.Completed { @@ -323,10 +330,20 @@ n, records, _ := repo.Search(ctx, func(search om.FtSearchIndex) om.Completed { fmt.Println("total", n) // n is total number of results matched in redis, which is >= len(records) for _, v := range records.([]*Example) { - fmt.Println(v.MyStr) // print "foo" + fmt.Println(v.Str) // print "foo" } ``` +### Object Mapping Limitation + +`NewHashRepository` only accepts these field types: +* string, *string +* int64, *int64 +* bool, *bool +* []byte + +Field projection by RediSearch is not supported. + ## Not Yet Implement The following subjects are not yet implemented. diff --git a/om/hash_conv.go b/om/hash_conv.go index 39de7be8..ce0e10b2 100644 --- a/om/hash_conv.go +++ b/om/hash_conv.go @@ -4,120 +4,42 @@ import ( "fmt" "reflect" "strconv" - "strings" + "unsafe" ) -type HashConverter interface { - ToHash() (id string, fields map[string]string) - FromHash(id string, fields map[string]string) error -} - -const ( - PKOption = "pk" - IgnoreField = "-" - VersionField = "_v" - SliceSepTag = "sep" -) - -func newHashConvFactory(t reflect.Type) *hashConvFactory { - if t.Kind() != reflect.Struct { - panic(fmt.Sprintf("schema %q should be a struct", t)) - } - - v := reflect.New(t) - - fields := make(map[string]field, t.NumField()) - for i := 0; i < t.NumField(); i++ { - f := t.Field(i) - name, options, ok := parseTag(f.Tag) - if !ok { - continue - } - if name == "" { - panic(fmt.Sprintf("schema %q should not contain fields with empty redis tag", t)) - } - if _, ok = fields[name]; ok { - panic(fmt.Sprintf("schema %q should not contain fields with duplicated redis tag", t)) - } - if !v.Elem().Field(i).CanSet() { - panic(fmt.Sprintf("schema %q should not contain private fields with redis tag", t)) - } - if name == IgnoreField { - if _, ok := options[PKOption]; !ok { - panic(fmt.Sprintf("schema %q should non pk fields with redis %q tag", t, "-")) - } - } - if name == VersionField { - if f.Type.Kind() != reflect.Int64 { - panic(fmt.Sprintf("field with tag `redis:%q` in schema %q should be a int64", VersionField, t)) - } - } - if _, ok := options[PKOption]; ok { - if f.Type.Kind() != reflect.String { - panic(fmt.Sprintf("field with tag `redis:\",pk\"` in schema %q should be a string", t)) - } - } +func newHashConvFactory(t reflect.Type, schema schema) *hashConvFactory { + factory := &hashConvFactory{converters: make(map[string]conv, len(schema.fields))} + for name, f := range schema.fields { + var converter converter + var ok bool - var conv converter - switch f.Type.Kind() { + switch f.typ.Kind() { case reflect.Ptr: - conv, ok = converters.ptr[f.Type.Elem().Kind()] + converter, ok = converters.ptr[f.typ.Elem().Kind()] case reflect.Slice: - if builder := converters.slice[f.Type.Elem().Kind()]; builder != nil { - sep := options[SliceSepTag] - if len(sep) == 0 { - panic(fmt.Sprintf("string slice field should have separator in tag `redis:\"%s,sep=\"` in schema %q", name, t)) - } - conv, ok = builder(sep), true - } + converter, ok = converters.slice[f.typ.Elem().Kind()] default: - conv, ok = converters.val[f.Type.Kind()] + converter, ok = converters.val[f.typ.Kind()] } if !ok { - panic(fmt.Sprintf("schema %q should not contain unsupported field type %s.", t, f.Type.Kind())) - } - fields[name] = field{position: i, options: options, converter: conv} - } - - factory := &hashConvFactory{fields: fields, pk: -1} - for _, f := range fields { - if _, ok := f.options[PKOption]; ok { - if factory.pk != -1 { - panic(fmt.Sprintf("schema %q should contain only one field with tag `redis:\",pk\"`", t)) - } - factory.pk = f.position + panic(fmt.Sprintf("schema %q should not contain unsupported field type %s.", t, f.typ.Kind())) } + factory.converters[name] = conv{conv: converter, idx: f.idx} } - if factory.pk == -1 { - panic(fmt.Sprintf("schema %q should contain a string field with tag `redis:\",pk\"` as primary key", t)) - } - if _, ok := fields[VersionField]; !ok { - panic(fmt.Sprintf("schema %q should contain a int64 field with tag `redis:%q` as version tag", VersionField, t)) - } - delete(fields, IgnoreField) - return factory } type hashConvFactory struct { - pk int - fields map[string]field + converters map[string]conv } -type field struct { - position int - converter converter - options map[string]string +type conv struct { + idx int + conv converter } func (f hashConvFactory) NewConverter(entity reflect.Value) hashConv { - if entity.Kind() == reflect.Ptr { - entity = entity.Elem() - } - return hashConv{ - factory: f, - entity: entity, - } + return hashConv{factory: f, entity: entity} } type hashConv struct { @@ -125,50 +47,32 @@ type hashConv struct { entity reflect.Value } -func (r hashConv) ToHash() (id string, fields map[string]string) { - fields = make(map[string]string, len(r.factory.fields)) - for f, field := range r.factory.fields { - ref := r.entity.Field(field.position) - if v, ok := field.converter.ValueToString(ref); ok { +func (r hashConv) ToHash() (fields map[string]string) { + fields = make(map[string]string, len(r.factory.converters)) + for f, converter := range r.factory.converters { + ref := r.entity.Field(converter.idx) + if v, ok := converter.conv.ValueToString(ref); ok { fields[f] = v } } - return r.entity.Field(r.factory.pk).String(), fields + return fields } -func (r hashConv) FromHash(id string, fields map[string]string) error { - r.entity.Field(r.factory.pk).Set(reflect.ValueOf(id)) - for f, field := range r.factory.fields { +func (r hashConv) FromHash(fields map[string]string) error { + for f, field := range r.factory.converters { v, ok := fields[f] if !ok { continue } - val, err := field.converter.StringToValue(v) + val, err := field.conv.StringToValue(v) if err != nil { return err } - r.entity.Field(field.position).Set(val) + r.entity.Field(field.idx).Set(val) } return nil } -func parseTag(tag reflect.StructTag) (name string, options map[string]string, ok bool) { - if name, ok = tag.Lookup("redis"); !ok { - return "", nil, false - } - tokens := strings.Split(name, ",") - options = make(map[string]string, len(tokens)-1) - for _, token := range tokens[1:] { - kv := strings.SplitN(token, "=", 2) - if len(kv) == 2 { - options[kv[0]] = kv[1] - } else { - options[kv[0]] = "" - } - } - return tokens[0], options, true -} - type converter struct { ValueToString func(value reflect.Value) (string, bool) StringToValue func(value string) (reflect.Value, error) @@ -177,7 +81,7 @@ type converter struct { var converters = struct { val map[reflect.Kind]converter ptr map[reflect.Kind]converter - slice map[reflect.Kind]func(sep string) converter + slice map[reflect.Kind]converter }{ ptr: map[reflect.Kind]converter{ reflect.Int64: { @@ -256,28 +160,19 @@ var converters = struct { }, }, }, - slice: map[reflect.Kind]func(sep string) converter{ - reflect.String: func(sep string) converter { - return converter{ - ValueToString: func(value reflect.Value) (string, bool) { - length := value.Len() - if length == 0 { - return "", false - } - sb := strings.Builder{} - for i := 0; i < length; i++ { - sb.WriteString(value.Index(i).String()) - if i != length-1 { - sb.WriteString(sep) - } - } - return sb.String(), true - }, - StringToValue: func(value string) (reflect.Value, error) { - s := strings.Split(value, sep) - return reflect.ValueOf(s), nil - }, - } + slice: map[reflect.Kind]converter{ + reflect.Uint8: { + ValueToString: func(value reflect.Value) (string, bool) { + buf, ok := value.Interface().([]byte) + if !ok { + return "", false + } + return *(*string)(unsafe.Pointer(&buf)), true + }, + StringToValue: func(value string) (reflect.Value, error) { + buf := []byte(value) + return reflect.ValueOf(buf), nil + }, }, }, } diff --git a/om/hash_repo.go b/om/hash_repo.go index 1cc269c1..88c3b4ba 100644 --- a/om/hash_repo.go +++ b/om/hash_repo.go @@ -2,32 +2,23 @@ package om import ( "context" - "errors" "fmt" "reflect" - "strings" + "strconv" "time" "github.com/rueian/rueidis" - "github.com/rueian/rueidis/internal/cmds" ) -type FtCreateSchema = cmds.FtCreateSchema -type FtSearchIndex = cmds.FtSearchIndex -type Completed = cmds.Completed - -var ErrVersionMismatch = errors.New("object version mismatched, please retry") - func NewHashRepository(prefix string, schema interface{}, client rueidis.Client) *HashRepository { repo := &HashRepository{ prefix: prefix, - idx: "idx:" + prefix, + idx: "hashidx:" + prefix, typ: reflect.TypeOf(schema), client: client, } - if _, ok := schema.(HashConverter); !ok { - repo.factory = newHashConvFactory(repo.typ) - } + repo.schema = newSchema(repo.typ) + repo.factory = newHashConvFactory(repo.typ, repo.schema) return repo } @@ -35,69 +26,23 @@ type HashRepository struct { prefix string idx string typ reflect.Type + schema schema factory *hashConvFactory client rueidis.Client } -func (r *HashRepository) key(id string) (key string) { - sb := strings.Builder{} - sb.Grow(len(r.prefix) + len(id) + 1) - sb.WriteString(r.prefix) - sb.WriteString(":") - sb.WriteString(id) - return sb.String() -} - -func (r *HashRepository) converter(v reflect.Value) (conv HashConverter) { - if r.factory != nil { - return r.factory.NewConverter(v) - } - return v.Interface().(HashConverter) -} - func (r *HashRepository) NewEntity() (entity interface{}) { v := reflect.New(r.typ) - _ = r.converter(v).FromHash(id(), nil) + v.Elem().Field(r.schema.keyField.idx).Set(reflect.ValueOf(id())) return v.Interface() } -func (r *HashRepository) fromHash(id string, record map[string]rueidis.RedisMessage) (v reflect.Value, err error) { - fields := make(map[string]string, len(record)) - for k, v := range record { - if s, err := v.ToString(); err == nil { - fields[k] = s - } - } - - v = reflect.New(r.typ) - if err := r.converter(v).FromHash(id, fields); err != nil { - return reflect.Value{}, err - } - return v, nil -} - -func (r *HashRepository) fromArray(id string, record []rueidis.RedisMessage) (v reflect.Value, err error) { - fields := make(map[string]string, len(record)/2) - for i := 0; i < len(record); i += 2 { - k, _ := record[i].ToString() - if s, err := record[i+1].ToString(); err == nil { - fields[k] = s - } - } - - v = reflect.New(r.typ) - if err := r.converter(v).FromHash(id, fields); err != nil { - return reflect.Value{}, err - } - return v, nil -} - func (r *HashRepository) Fetch(ctx context.Context, id string) (v interface{}, err error) { - record, err := r.client.Do(ctx, r.client.B().Hgetall().Key(r.key(id)).Build()).ToMap() + record, err := r.client.Do(ctx, r.client.B().Hgetall().Key(key(r.prefix, id)).Build()).ToMap() if err != nil { return nil, err } - val, err := r.fromHash(id, record) + val, err := r.fromHash(record) if err != nil { return nil, err } @@ -105,11 +50,11 @@ func (r *HashRepository) Fetch(ctx context.Context, id string) (v interface{}, e } func (r *HashRepository) FetchCache(ctx context.Context, id string, ttl time.Duration) (v interface{}, err error) { - record, err := r.client.DoCache(ctx, r.client.B().Hgetall().Key(r.key(id)).Cache(), ttl).ToMap() + record, err := r.client.DoCache(ctx, r.client.B().Hgetall().Key(key(r.prefix, id)).Cache(), ttl).ToMap() if err != nil { return nil, err } - val, err := r.fromHash(id, record) + val, err := r.fromHash(record) if err != nil { return nil, err } @@ -117,42 +62,34 @@ func (r *HashRepository) FetchCache(ctx context.Context, id string, ttl time.Dur } func (r *HashRepository) Save(ctx context.Context, entity interface{}) (err error) { - var conv HashConverter - - if r.factory != nil { - conv = r.factory.NewConverter(reflect.ValueOf(entity)) - } else { - conv = entity.(HashConverter) + val, ok := ptrValueOf(entity, r.typ) + if !ok { + panic(fmt.Sprintf("input entity should be a pointer to %v", r.typ)) } - id, fields := conv.ToHash() - if ver, ok := fields[VersionField]; ok { - args := make([]string, 0, len(fields)*2) - args = append(args, VersionField, ver) - for f, v := range fields { - if f == VersionField { - continue - } - args = append(args, f, v) - } - fields[VersionField], err = saveScript.Exec(ctx, r.client, []string{r.key(id)}, args).ToString() - if rueidis.IsRedisNil(err) { - return ErrVersionMismatch - } - if err != nil { - return err - } - return conv.FromHash(id, fields) + fields := r.factory.NewConverter(val).ToHash() + + keyVal := fields[r.schema.keyField.name] + verVal := fields[r.schema.verField.name] + + args := make([]string, 0, len(fields)*2) + args = append(args, r.schema.verField.name, verVal) // keep the ver field be the first pair for the hashSaveScript + delete(fields, r.schema.verField.name) + for k, v := range fields { + args = append(args, k, v) } - cmd := r.client.B().Hset().Key(r.key(id)).FieldValue() - for f, v := range fields { - cmd = cmd.FieldValue(f, v) + + str, err := hashSaveScript.Exec(ctx, r.client, []string{key(r.prefix, keyVal)}, args).ToString() + if rueidis.IsRedisNil(err) { + return ErrVersionMismatch } - return r.client.Do(ctx, cmd.Build()).Error() + ver, _ := strconv.ParseInt(str, 10, 64) + val.Field(r.schema.verField.idx).SetInt(ver) + return nil } func (r *HashRepository) Remove(ctx context.Context, id string) error { - return r.client.Do(ctx, r.client.B().Del().Key(r.key(id)).Build()).Error() + return r.client.Do(ctx, r.client.B().Del().Key(key(r.prefix, id)).Build()).Error() } func (r *HashRepository) CreateIndex(ctx context.Context, cmdFn func(schema FtCreateSchema) Completed) error { @@ -168,15 +105,12 @@ func (r *HashRepository) Search(ctx context.Context, cmdFn func(search FtSearchI if err != nil { return 0, nil, err } - prefix := r.prefix + ":" n, _ := resp[0].ToInt64() s := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(r.typ)), 0, len(resp[1:])/2) - for i := 1; i < len(resp); i += 2 { - id, _ := resp[i].ToString() - kv, _ := resp[i+1].ToArray() - - v, err := r.fromArray(strings.TrimPrefix(id, prefix), kv) + for i := 2; i < len(resp); i += 2 { + kv, _ := resp[i].ToArray() + v, err := r.fromArray(kv) if err != nil { return 0, nil, err } @@ -185,12 +119,42 @@ func (r *HashRepository) Search(ctx context.Context, cmdFn func(search FtSearchI return n, s.Interface(), nil } -var saveScript = rueidis.NewLuaScript(fmt.Sprintf(` -local v = redis.call('HGET',KEYS[1],'%s') +func (r *HashRepository) fromHash(record map[string]rueidis.RedisMessage) (v reflect.Value, err error) { + fields := make(map[string]string, len(record)) + for k, v := range record { + if s, err := v.ToString(); err == nil { + fields[k] = s + } + } + + v = reflect.New(r.typ) + if err := r.factory.NewConverter(v.Elem()).FromHash(fields); err != nil { + return reflect.Value{}, err + } + return v, nil +} + +func (r *HashRepository) fromArray(record []rueidis.RedisMessage) (v reflect.Value, err error) { + fields := make(map[string]string, len(record)/2) + for i := 0; i < len(record); i += 2 { + k, _ := record[i].ToString() + if s, err := record[i+1].ToString(); err == nil { + fields[k] = s + } + } + v = reflect.New(r.typ) + if err := r.factory.NewConverter(v.Elem()).FromHash(fields); err != nil { + return reflect.Value{}, err + } + return v, nil +} + +var hashSaveScript = rueidis.NewLuaScript(` +local v = redis.call('HGET',KEYS[1],ARGV[1]) if (not v or v == ARGV[2]) then ARGV[2] = tostring(tonumber(ARGV[2])+1) if redis.call('HSET',KEYS[1],unpack(ARGV)) then return ARGV[2] end end return nil -`, VersionField)) +`) diff --git a/om/json_repo.go b/om/json_repo.go new file mode 100644 index 00000000..644a1dff --- /dev/null +++ b/om/json_repo.go @@ -0,0 +1,137 @@ +package om + +import ( + "context" + "encoding/json" + "fmt" + "reflect" + "strconv" + "strings" + "time" + + "github.com/rueian/rueidis" +) + +func NewJSONRepository(prefix string, schema interface{}, client rueidis.Client) *JSONRepository { + repo := &JSONRepository{ + prefix: prefix, + idx: "jsonidx:" + prefix, + typ: reflect.TypeOf(schema), + client: client, + } + repo.schema = newSchema(repo.typ) + return repo +} + +type JSONRepository struct { + prefix string + idx string + typ reflect.Type + schema schema + client rueidis.Client +} + +func (r *JSONRepository) NewEntity() (entity interface{}) { + v := reflect.New(r.typ) + v.Elem().Field(r.schema.keyField.idx).Set(reflect.ValueOf(id())) + return v.Interface() +} + +func (r *JSONRepository) Fetch(ctx context.Context, id string) (interface{}, error) { + record, err := r.client.Do(ctx, r.client.B().JsonGet().Key(key(r.prefix, id)).Build()).ToString() + if err != nil { + return nil, err + } + iface, _, err := r.decode(record) + return iface, err +} + +func (r *JSONRepository) FetchCache(ctx context.Context, id string, ttl time.Duration) (v interface{}, err error) { + record, err := r.client.DoCache(ctx, r.client.B().JsonGet().Key(key(r.prefix, id)).Cache(), ttl).ToString() + if err != nil { + return nil, err + } + iface, _, err := r.decode(record) + return iface, err +} + +func (r *JSONRepository) decode(record string) (interface{}, reflect.Value, error) { + val := reflect.New(r.typ) + iface := val.Interface() + if err := json.NewDecoder(strings.NewReader(record)).Decode(iface); err != nil { + return nil, reflect.Value{}, err + } + return iface, val, nil +} + +func (r *JSONRepository) Save(ctx context.Context, entity interface{}) (err error) { + val, ok := ptrValueOf(entity, r.typ) + if !ok { + panic(fmt.Sprintf("input entity should be a pointer to %v", r.typ)) + } + + keyField := val.Field(r.schema.keyField.idx) + verField := val.Field(r.schema.verField.idx) + + sb := strings.Builder{} + if err = json.NewEncoder(&sb).Encode(entity); err != nil { + return err + } + + str, err := jsonSaveScript.Exec(ctx, r.client, []string{key(r.prefix, keyField.String())}, []string{ + r.schema.verField.name, + strconv.FormatInt(verField.Int(), 10), + sb.String(), + }).ToString() + if rueidis.IsRedisNil(err) { + return ErrVersionMismatch + } + ver, _ := strconv.ParseInt(str, 10, 64) + verField.SetInt(ver) + return nil +} + +func (r *JSONRepository) Remove(ctx context.Context, id string) error { + return r.client.Do(ctx, r.client.B().Del().Key(key(r.prefix, id)).Build()).Error() +} + +func (r *JSONRepository) CreateIndex(ctx context.Context, cmdFn func(schema FtCreateSchema) Completed) error { + return r.client.Do(ctx, cmdFn(r.client.B().FtCreate().Index(r.idx).OnJson().Prefix(1).Prefix(r.prefix+":").Schema())).Error() +} + +func (r *JSONRepository) DropIndex(ctx context.Context) error { + return r.client.Do(ctx, r.client.B().FtDropindex().Index(r.idx).Build()).Error() +} + +func (r *JSONRepository) Search(ctx context.Context, cmdFn func(search FtSearchIndex) Completed) (int64, interface{}, error) { + resp, err := r.client.Do(ctx, cmdFn(r.client.B().FtSearch().Index(r.idx))).ToArray() + if err != nil { + return 0, nil, err + } + + n, _ := resp[0].ToInt64() + s := reflect.MakeSlice(reflect.SliceOf(reflect.PtrTo(r.typ)), 0, len(resp[1:])/2) + for i := 2; i < len(resp); i += 2 { + if kv, _ := resp[i].ToArray(); len(kv) == 2 { + if k, _ := kv[0].ToString(); k == "$" { + record, _ := kv[1].ToString() + _, v, err := r.decode(record) + if err != nil { + return 0, nil, err + } + s = reflect.Append(s, v) + } + } + } + return n, s.Interface(), nil +} + +var jsonSaveScript = rueidis.NewLuaScript(` +local v = redis.call('JSON.GET',KEYS[1],ARGV[1]) +if (not v or v == ARGV[2]) +then + redis.call('JSON.SET',KEYS[1],'$',ARGV[3]) + return redis.call('JSON.NUMINCRBY',KEYS[1],ARGV[1],1) +end +return nil +`) diff --git a/om/schema.go b/om/schema.go new file mode 100644 index 00000000..d741c0cb --- /dev/null +++ b/om/schema.go @@ -0,0 +1,109 @@ +package om + +import ( + "errors" + "fmt" + "reflect" + "strings" + + "github.com/rueian/rueidis/internal/cmds" +) + +const IgnoreField = "-" + +type FtCreateSchema = cmds.FtCreateSchema +type FtSearchIndex = cmds.FtSearchIndex +type Completed = cmds.Completed + +var ErrVersionMismatch = errors.New("object version mismatched, please retry") + +type schema struct { + keyField *field + verField *field + fields map[string]*field +} + +type field struct { + name string + idx int + typ reflect.Type + isKeyField bool + isVerField bool +} + +func newSchema(t reflect.Type) schema { + if t.Kind() != reflect.Struct { + panic(fmt.Sprintf("schema %q should be a struct", t)) + } + + schema := schema{fields: make(map[string]*field, t.NumField())} + + for i := 0; i < t.NumField(); i++ { + f := t.Field(i) + if !f.IsExported() { + continue + } + field := parse(f) + if field.name == IgnoreField { + continue + } + field.idx = i + schema.fields[field.name] = &field + + if field.isKeyField { + if f.Type.Kind() != reflect.String { + panic(fmt.Sprintf("field with tag `redis:\",key\"` in schema %q should be a string", t)) + } + schema.keyField = &field + } + if field.isVerField { + if f.Type.Kind() != reflect.Int64 { + panic(fmt.Sprintf("field with tag `redis:\",ver\"` in schema %q should be a int64", t)) + } + schema.verField = &field + } + } + + if schema.keyField == nil { + panic(fmt.Sprintf("schema %q should have one field with `redis:\",key\"` tag", t)) + } + if schema.verField == nil { + panic(fmt.Sprintf("schema %q should have one field with `redis:\",ver\"` tag", t)) + } + + return schema +} + +func parse(f reflect.StructField) (field field) { + v, _ := f.Tag.Lookup("json") + vs := strings.SplitN(v, ",", 1) + if vs[0] == "" { + field.name = f.Name + } else { + field.name = vs[0] + } + + v, _ = f.Tag.Lookup("redis") + field.isKeyField = strings.Contains(v, ",key") + field.isVerField = strings.Contains(v, ",ver") + field.typ = f.Type + return field +} + +func key(prefix, id string) (key string) { + sb := strings.Builder{} + sb.Grow(len(prefix) + len(id) + 1) + sb.WriteString(prefix) + sb.WriteString(":") + sb.WriteString(id) + return sb.String() +} + +func ptrValueOf(entity interface{}, typ reflect.Type) (reflect.Value, bool) { + val := reflect.ValueOf(entity) + if val.Kind() != reflect.Ptr { + return reflect.Value{}, false + } + val = val.Elem() + return val, val.Type() == typ +}