Skip to content

Commit

Permalink
Fix all GetCollectionByID usage (#155)
Browse files Browse the repository at this point in the history
Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Jun 15, 2023
1 parent ea86efa commit f3c6057
Show file tree
Hide file tree
Showing 9 changed files with 56 additions and 69 deletions.
5 changes: 4 additions & 1 deletion states/download_pk.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (
"github.com/manifoldco/promptui"
"github.com/milvus-io/birdwatcher/proto/v2.0/datapb"
"github.com/milvus-io/birdwatcher/states/etcd/common"
etcdversion "github.com/milvus-io/birdwatcher/states/etcd/version"
"github.com/minio/minio-go/v7"
"github.com/minio/minio-go/v7/pkg/credentials"
"github.com/spf13/cobra"
Expand All @@ -28,7 +29,9 @@ func getDownloadPKCmd(cli clientv3.KV, basePath string) *cobra.Command {
return err
}

coll, err := common.GetCollectionByID(cli, basePath, collectionID)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
coll, err := common.GetCollectionByIDVersion(ctx, cli, basePath, etcdversion.GetVersion(), collectionID)
if err != nil {
fmt.Println("Collection not found for id", collectionID)
return nil
Expand Down
2 changes: 0 additions & 2 deletions states/etcd/common/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,9 +58,7 @@ func ListChannelWatch(ctx context.Context, cli clientv3.KV, basePath string, ver
return nil, err
}
result = lo.Map(infos, func(info datapbv2.ChannelWatchInfo, idx int) *models.ChannelWatch {
fmt.Println(info.String())
return models.GetChannelWatchInfo[*datapbv2.ChannelWatchInfo, datapbv2.ChannelWatchState, *datapbv2.VchannelInfo, *internalpbv2.MsgPosition](&info, paths[idx])

})
default:
return nil, errors.New("version not supported")
Expand Down
33 changes: 0 additions & 33 deletions states/etcd/common/collection.go
Original file line number Diff line number Diff line change
Expand Up @@ -112,39 +112,6 @@ func ListCollectionsVersion(ctx context.Context, cli clientv3.KV, basePath strin
}
}

// GetCollectionByID returns collection info from etcd with provided id.
func GetCollectionByID(cli clientv3.KV, basePath string, collID int64) (*etcdpb.CollectionInfo, error) {
ctx, cancel := context.WithTimeout(context.Background(), time.Second*3)
defer cancel()
resp, err := cli.Get(ctx, path.Join(basePath, CollectionMetaPrefix, strconv.FormatInt(collID, 10)))

if err != nil {
return nil, err
}

if len(resp.Kvs) != 1 {
return nil, errors.New("invalid collection id")
}

if bytes.Equal(resp.Kvs[0].Value, CollectionTombstone) {
return nil, fmt.Errorf("%w, collection id: %d", ErrCollectionDropped, collID)
}

coll := &etcdpb.CollectionInfo{}

err = proto.Unmarshal(resp.Kvs[0].Value, coll)
if err != nil {
return nil, err
}

err = FillFieldSchemaIfEmpty(cli, basePath, coll)
if err != nil {
return nil, err
}

return coll, nil
}

// GetCollectionByIDVersion retruns collection info from etcd with provided version & id.
func GetCollectionByIDVersion(ctx context.Context, cli clientv3.KV, basePath string, version string, collID int64) (*models.Collection, error) {

Expand Down
9 changes: 6 additions & 3 deletions states/etcd/repair/channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ import (
"github.com/milvus-io/birdwatcher/proto/v2.0/commonpb"
"github.com/milvus-io/birdwatcher/proto/v2.0/datapb"
"github.com/milvus-io/birdwatcher/states/etcd/common"
etcdversion "github.com/milvus-io/birdwatcher/states/etcd/version"
"github.com/spf13/cobra"
clientv3 "go.etcd.io/etcd/client/v3"
"google.golang.org/grpc"
Expand Down Expand Up @@ -36,15 +37,17 @@ func ChannelCommand(cli clientv3.KV, basePath string) *cobra.Command {
return
}

coll, err := common.GetCollectionByID(cli, basePath, collID)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
coll, err := common.GetCollectionByIDVersion(ctx, cli, basePath, etcdversion.GetVersion(), collID)
if err != nil {
fmt.Println("collection not found")
return
}

chans := make(map[string]struct{})
for _, vchan := range coll.GetVirtualChannelNames() {
chans[vchan] = struct{}{}
for _, c := range coll.Channels {
chans[c.VirtualName] = struct{}{}
}

infos, _, err := common.ListChannelWatchV1(cli, basePath)
Expand Down
37 changes: 21 additions & 16 deletions states/etcd/repair/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,13 @@ import (
"github.com/spf13/cobra"
clientv3 "go.etcd.io/etcd/client/v3"

"github.com/milvus-io/birdwatcher/models"
"github.com/milvus-io/birdwatcher/mq"
"github.com/milvus-io/birdwatcher/proto/v2.0/commonpb"
"github.com/milvus-io/birdwatcher/proto/v2.0/datapb"
"github.com/milvus-io/birdwatcher/proto/v2.0/etcdpb"
"github.com/milvus-io/birdwatcher/proto/v2.0/internalpb"
"github.com/milvus-io/birdwatcher/states/etcd/common"
etcdversion "github.com/milvus-io/birdwatcher/states/etcd/version"
"github.com/milvus-io/birdwatcher/utils"
)

Expand Down Expand Up @@ -46,7 +47,11 @@ func CheckpointCommand(cli clientv3.KV, basePath string) *cobra.Command {
return
}

coll, err := common.GetCollectionByID(cli, basePath, collID)
//coll, err := common.GetCollectionByID(cli, basePath, collID)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
coll, err := common.GetCollectionByIDVersion(ctx, cli, basePath, etcdversion.GetVersion(), collID)

if err != nil {
fmt.Println("failed to get collection", err.Error())
return
Expand Down Expand Up @@ -85,20 +90,20 @@ func CheckpointCommand(cli clientv3.KV, basePath string) *cobra.Command {
return cmd
}

func setCheckPointWithLatestMsgID(cli clientv3.KV, basePath string, coll *etcdpb.CollectionInfo, mqType, address, vchannel string) {
for _, ch := range coll.GetVirtualChannelNames() {
if ch == vchannel {
pChannel := ToPhysicalChannel(ch)
func setCheckPointWithLatestMsgID(cli clientv3.KV, basePath string, coll *models.Collection, mqType, address, vchannel string) {
for _, ch := range coll.Channels {
if ch.VirtualName == vchannel {
pChannel := ch.PhysicalName
cp, err := getLatestFromPChannel(mqType, address, vchannel)
if err != nil {
fmt.Printf("vchannel:%s -> pchannel:%s, get latest msgID faile, err:%s\n", ch, pChannel, err.Error())
fmt.Printf("vchannel:%s -> pchannel:%s, get latest msgID faile, err:%s\n", ch.VirtualName, pChannel, err.Error())
return
}

err = saveChannelCheckpoint(cli, basePath, ch, cp)
err = saveChannelCheckpoint(cli, basePath, ch.VirtualName, cp)
t, _ := utils.ParseTS(cp.GetTimestamp())
if err != nil {
fmt.Printf("failed to set latest msgID(ts:%v) for vchannel:%s", t, ch)
fmt.Printf("failed to set latest msgID(ts:%v) for vchannel:%s", t, ch.VirtualName)
return
}
fmt.Printf("vchannel:%s set to latest msgID(ts:%v) finshed\n", vchannel, t)
Expand All @@ -108,7 +113,7 @@ func setCheckPointWithLatestMsgID(cli clientv3.KV, basePath string, coll *etcdpb
fmt.Printf("vchannel:%s doesn't exists in collection: %d\n", vchannel, coll.ID)
}

func setCheckPointWithLatestCheckPoint(cli clientv3.KV, basePath string, coll *etcdpb.CollectionInfo, vchannel string) {
func setCheckPointWithLatestCheckPoint(cli clientv3.KV, basePath string, coll *models.Collection, vchannel string) {
pChannelName2LatestCP, err := getLatestCheckpointFromPChannel(cli, basePath)
if err != nil {
fmt.Println("failed to get latest cp of all pchannel", err.Error())
Expand All @@ -121,19 +126,19 @@ func setCheckPointWithLatestCheckPoint(cli clientv3.KV, basePath string, coll *e
fmt.Printf("pchannel: %s, the lastest checkpoint ts: %v\n", k, t)
}

for _, ch := range coll.GetVirtualChannelNames() {
if ch == vchannel {
pChannel := ToPhysicalChannel(ch)
for _, ch := range coll.Channels {
if ch.VirtualName == vchannel {
pChannel := ch.PhysicalName
cp, ok := pChannelName2LatestCP[pChannel]
if !ok {
fmt.Printf("vchannel:%s -> pchannel:%s, the pchannel doesn't exists\n", ch, pChannel)
fmt.Printf("vchannel:%s -> pchannel:%s, the pchannel doesn't exists\n", ch.VirtualName, pChannel)
return
}

err := saveChannelCheckpoint(cli, basePath, ch, cp)
err := saveChannelCheckpoint(cli, basePath, ch.VirtualName, cp)
t, _ := utils.ParseTS(cp.GetTimestamp())
if err != nil {
fmt.Printf("failed to set latest checkpoint(ts:%v) for vchannel:%s", t, ch)
fmt.Printf("failed to set latest checkpoint(ts:%v) for vchannel:%s", t, ch.VirtualName)
return
}
fmt.Printf("vchannel:%s set to latest checkpoint(ts:%v) finshed\n", vchannel, t)
Expand Down
17 changes: 11 additions & 6 deletions states/etcd/repair/segment.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,12 +6,13 @@ import (
"path"

"github.com/golang/protobuf/proto"
"github.com/milvus-io/birdwatcher/models"
"github.com/milvus-io/birdwatcher/proto/v2.0/commonpb"
"github.com/milvus-io/birdwatcher/proto/v2.0/datapb"
"github.com/milvus-io/birdwatcher/proto/v2.0/etcdpb"
"github.com/milvus-io/birdwatcher/proto/v2.0/indexpb"
"github.com/milvus-io/birdwatcher/proto/v2.0/schemapb"
"github.com/milvus-io/birdwatcher/states/etcd/common"
etcdversion "github.com/milvus-io/birdwatcher/states/etcd/version"
"github.com/spf13/cobra"
clientv3 "go.etcd.io/etcd/client/v3"
)
Expand Down Expand Up @@ -86,7 +87,7 @@ func SegmentCommand(cli clientv3.KV, basePath string) *cobra.Command {
buildID2Info[info.IndexBuildID] = info
}

collections := make(map[int64]*etcdpb.CollectionInfo)
collections := make(map[int64]*models.Collection)

targetOld := make(map[int64]*datapb.SegmentInfo)
target := make(map[int64]*datapb.SegmentInfo)
Expand All @@ -105,7 +106,11 @@ func SegmentCommand(cli clientv3.KV, basePath string) *cobra.Command {

coll, ok := collections[segment.CollectionID]
if !ok {
coll, err = common.GetCollectionByID(cli, basePath, segment.CollectionID)
//coll, err = common.GetCollectionByID(cli, basePath, segment.CollectionID)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
coll, err := common.GetCollectionByIDVersion(ctx, cli, basePath, etcdversion.GetVersion(), collID)

if err != nil {
fmt.Printf("failed to query collection(id=%d) info error: %s", segment.CollectionID, err.Error())
continue
Expand All @@ -117,9 +122,9 @@ func SegmentCommand(cli clientv3.KV, basePath string) *cobra.Command {

for _, segIdx := range segIdxs {
var valid bool
for _, field := range coll.GetSchema().GetFields() {
if field.GetFieldID() == segIdx.GetFieldID() {
if field.GetDataType() == schemapb.DataType_FloatVector || field.GetDataType() == schemapb.DataType_BinaryVector {
for _, field := range coll.Schema.Fields {
if field.FieldID == segIdx.GetFieldID() {
if field.DataType == models.DataTypeFloatVector || field.DataType == models.DataTypeBinaryVector {
valid = true
}
break
Expand Down
13 changes: 7 additions & 6 deletions states/etcd/show/checkpoint.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"github.com/milvus-io/birdwatcher/proto/v2.0/datapb"
"github.com/milvus-io/birdwatcher/proto/v2.0/internalpb"
"github.com/milvus-io/birdwatcher/states/etcd/common"
etcdversion "github.com/milvus-io/birdwatcher/states/etcd/version"
"github.com/milvus-io/birdwatcher/utils"
"github.com/spf13/cobra"
clientv3 "go.etcd.io/etcd/client/v3"
Expand All @@ -28,27 +29,27 @@ func CheckpointCommand(cli clientv3.KV, basePath string) *cobra.Command {
return
}

coll, err := common.GetCollectionByID(cli, basePath, collID)
coll, err := common.GetCollectionByIDVersion(context.Background(), cli, basePath, etcdversion.GetVersion(), collID)
if err != nil {
fmt.Println("failed to get collection", err.Error())
return
}

for _, vchannel := range coll.GetVirtualChannelNames() {
for _, channel := range coll.Channels {
var cp *internalpb.MsgPosition
var segmentID int64
var err error
cp, err = getChannelCheckpoint(cli, basePath, vchannel)
cp, err = getChannelCheckpoint(cli, basePath, channel.VirtualName)

if err != nil {
cp, segmentID, err = getCheckpointFromSegments(cli, basePath, collID, vchannel)
cp, segmentID, err = getCheckpointFromSegments(cli, basePath, collID, channel.VirtualName)
}

if cp == nil {
fmt.Printf("vchannel %s position nil\n", vchannel)
fmt.Printf("vchannel %s position nil\n", channel.VirtualName)
} else {
t, _ := utils.ParseTS(cp.GetTimestamp())
fmt.Printf("vchannel %s seek to %v, cp channel: %s", vchannel, t, cp.ChannelName)
fmt.Printf("vchannel %s seek to %v, cp channel: %s", channel.VirtualName, t, cp.ChannelName)
if segmentID > 0 {
fmt.Printf(", for segment ID:%d\n", segmentID)
} else {
Expand Down
2 changes: 1 addition & 1 deletion states/force_release.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ func getReleaseDroppedCollectionCmd(cli clientv3.KV, basePath string) *cobra.Com

var missing []int64
for _, info := range collectionLoadInfos {
_, err := common.GetCollectionByID(cli, basePath, info.CollectionID)
_, err := common.GetCollectionByIDVersion(ctx, cli, basePath, etcdversion.GetVersion(), info.CollectionID)
if err != nil {
missing = append(missing, info.CollectionID)
}
Expand Down
7 changes: 6 additions & 1 deletion states/inspect_primary_key.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package states

import (
"context"
"fmt"
"os"
"path"
Expand All @@ -9,6 +10,7 @@ import (
"github.com/milvus-io/birdwatcher/proto/v2.0/commonpb"
"github.com/milvus-io/birdwatcher/proto/v2.0/datapb"
"github.com/milvus-io/birdwatcher/states/etcd/common"
etcdversion "github.com/milvus-io/birdwatcher/states/etcd/version"
"github.com/milvus-io/birdwatcher/storage"
"github.com/spf13/cobra"
clientv3 "go.etcd.io/etcd/client/v3"
Expand Down Expand Up @@ -50,7 +52,10 @@ func getInspectPKCmd(cli clientv3.KV, basePath string) *cobra.Command {
}
pkID, has := cachedCollection[segment.CollectionID]
if !has {
coll, err := common.GetCollectionByID(cli, basePath, segment.CollectionID)
ctx, cancel := context.WithCancel(context.Background())
defer cancel()
coll, err := common.GetCollectionByIDVersion(ctx, cli, basePath, etcdversion.GetVersion(), segment.GetCollectionID())

if err != nil {
fmt.Println("Collection not found for id", segment.CollectionID)
return
Expand Down

0 comments on commit f3c6057

Please sign in to comment.