diff --git a/states/etcd_restore.go b/states/etcd_restore.go index ab72a8e..e520ab2 100644 --- a/states/etcd_restore.go +++ b/states/etcd_restore.go @@ -9,6 +9,7 @@ import ( "fmt" "io" "strconv" + "sync" "time" "github.com/golang/protobuf/proto" @@ -108,9 +109,9 @@ func restoreV2File(rd *bufio.Reader, state *embedEtcdMockState) error { state.defaultMetrics[fmt.Sprintf("%s-%d", session.ServerName, session.ServerID)] = defaultMetrics }) case int32(models.Configurations): - testRestoreConfigurations(rd, ph) + //testRestoreConfigurations(rd, ph) case int32(models.AppMetrics): - testRestoreConfigurations(rd, ph) + //testRestoreConfigurations(rd, ph) } } } @@ -139,59 +140,112 @@ func restoreEtcdFromBackV2(cli clientv3.KV, rd io.Reader, ph models.PartHeader) fmt.Fprintf(progressDisplay, progressFmt, 0, 0, cnt) defer progressDisplay.Stop() - for { - bsRead, err := io.ReadFull(rd, lb) //rd.Read(lb) - // all file read - if err == io.EOF { - return meta["instance"], nil - } - if err != nil { - fmt.Println("failed to read file:", err.Error()) - return "", err - } - if bsRead < 8 { - fmt.Printf("fail to read next length %d instead of 8 read\n", bsRead) - return "", errors.New("invalid file format") - } + batchNum := 10 + ch := make(chan []*commonpb.KeyDataPair, 10) + errCh := make(chan error, 1) - nextBytes = binary.LittleEndian.Uint64(lb) - // stopper found - if nextBytes == 0 { - return meta["instance"], nil - } - bs = make([]byte, nextBytes) + go func() { + defer close(ch) + batch := make([]*commonpb.KeyDataPair, 0, batchNum) + defer func() { + if len(batch) > 0 { + ch <- batch + } + }() + var lastPrint time.Time + for { + bsRead, err := io.ReadFull(rd, lb) //rd.Read(lb) + // all file read + if err == io.EOF { + //return meta["instance"], nil + errCh <- nil + return + } + if err != nil { + fmt.Println("failed to read file:", err.Error()) + errCh <- err + return + } + if bsRead < 8 { + fmt.Printf("fail to read next length %d instead of 8 read\n", bsRead) + errCh <- errors.New("invalid file format") + return + } - // cannot use rd.Read(bs), since proto marshal may generate a stopper - bsRead, err = io.ReadFull(rd, bs) - if err != nil { - fmt.Println("failed to read next kv data", err.Error()) - return "", err - } - if uint64(bsRead) != nextBytes { - fmt.Printf("bytesRead(%d)is not equal to nextBytes(%d)\n", bsRead, nextBytes) - return "", errors.New("bad file format") - } + nextBytes = binary.LittleEndian.Uint64(lb) + // stopper found + if nextBytes == 0 { + errCh <- nil + return + } + bs = make([]byte, nextBytes) - entry := &commonpb.KeyDataPair{} - err = proto.Unmarshal(bs, entry) - if err != nil { - //Skip for now - fmt.Printf("fail to parse line: %s, skip for now\n", err.Error()) - continue - } + // cannot use rd.Read(bs), since proto marshal may generate a stopper + bsRead, err = io.ReadFull(rd, bs) + if err != nil { + fmt.Println("failed to read next kv data", err.Error()) + errCh <- err + return + } + if uint64(bsRead) != nextBytes { + fmt.Printf("bytesRead(%d)is not equal to nextBytes(%d)\n", bsRead, nextBytes) + errCh <- errors.New("bad file format") + return + } - ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) - defer cancel() - _, err = cli.Put(ctx, entry.Key, string(entry.Data)) - if err != nil { - fmt.Println("failed save kv into etcd, ", err.Error()) - continue + entry := &commonpb.KeyDataPair{} + err = proto.Unmarshal(bs, entry) + if err != nil { + //Skip for now + fmt.Printf("fail to parse line: %s, skip for now\n", err.Error()) + continue + } + + batch = append(batch, entry) + if len(batch) >= batchNum { + ch <- batch + batch = make([]*commonpb.KeyDataPair, 0, batchNum) + } + i++ + progress := i * 100 / int(cnt) + + if time.Since(lastPrint) > time.Millisecond*10 || progress == 100 { + fmt.Fprintf(progressDisplay, progressFmt, progress, i, cnt) + lastPrint = time.Now() + } } - i++ - progress := i * 100 / int(cnt) + }() + + var wg sync.WaitGroup + workerNum := 3 + wg.Add(workerNum) + for i := 0; i < 3; i++ { + go func() { + defer wg.Done() + for batch := range ch { + ops := make([]clientv3.Op, 0, len(batch)) + for _, entry := range batch { + ops = append(ops, clientv3.OpPut(entry.Key, string(entry.Data))) + } + func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Second*3) + defer cancel() + _, err := cli.Txn(ctx).If().Then(ops...).Commit() + if err != nil { + fmt.Println(err.Error()) + } + }() + } + }() + } - fmt.Fprintf(progressDisplay, progressFmt, progress, i, cnt) + err = <-errCh + wg.Wait() + if err != nil { + return "", err } + + return meta["instance"], nil } func restoreMetrics(rd io.Reader, ph models.PartHeader, handler func(session *models.Session, metrics, defaultMetrics []byte)) error { diff --git a/states/load_backup.go b/states/load_backup.go index 860adc2..8207864 100644 --- a/states/load_backup.go +++ b/states/load_backup.go @@ -8,6 +8,7 @@ import ( "os" "path" "strings" + "time" "github.com/cockroachdb/errors" "github.com/milvus-io/birdwatcher/configs" @@ -71,8 +72,9 @@ func (s *disconnectState) LoadBackupCommand(ctx context.Context, p *LoadBackupPa return err } fmt.Println("using data dir:", server.Config().Dir) - // TODO + nextState := getEmbedEtcdInstanceV2(server, s.config) + start := time.Now() switch header.Version { case 1: fmt.Printf("Found backup version: %d, instance name :%s\n", header.Version, header.Instance) @@ -95,6 +97,7 @@ func (s *disconnectState) LoadBackupCommand(ctx context.Context, p *LoadBackupPa nextState.Close() return err } + fmt.Println("load backup cost", time.Since(start)) err = nextState.setupWorkDir(server.Config().Dir) if err != nil { fmt.Println("failed to setup workspace for backup file", err.Error())