Skip to content

Commit

Permalink
refine look aside balance logic (milvus-io#25837)
Browse files Browse the repository at this point in the history
Signed-off-by: Wei Liu <[email protected]>
  • Loading branch information
weiliu1031 authored Jul 25, 2023
1 parent a669440 commit 302897f
Show file tree
Hide file tree
Showing 11 changed files with 116 additions and 92 deletions.
30 changes: 17 additions & 13 deletions internal/proxy/look_aside_balancer.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ package proxy
import (
"context"
"math"
"math/rand"
"strconv"
"sync"
"time"
Expand All @@ -35,11 +36,6 @@ import (
"go.uber.org/zap"
)

var (
checkQueryNodeHealthInterval = 500 * time.Millisecond
CostMetricsExpireTime = 1000 * time.Millisecond
)

type LookAsideBalancer struct {
clientMgr shardClientMgr

Expand Down Expand Up @@ -88,6 +84,9 @@ func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int
log := log.Ctx(ctx).WithRateGroup("proxy.LookAsideBalancer", 1, 60)
targetNode := int64(-1)
targetScore := float64(math.MaxFloat64)
rand.Shuffle(len(availableNodes), func(i, j int) {
availableNodes[i], availableNodes[j] = availableNodes[j], availableNodes[i]
})
for _, node := range availableNodes {
if b.unreachableQueryNodes.Contain(node) {
log.RatedWarn(5, "query node is unreachable, skip it",
Expand Down Expand Up @@ -117,7 +116,8 @@ func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int

// update executing task cost
totalNQ, _ := b.executingTaskTotalNQ.Get(targetNode)
totalNQ.Add(cost)
nq := totalNQ.Add(cost)
metrics.ProxyExecutingTotalNq.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Set(float64(nq))

return targetNode, nil
}
Expand All @@ -126,28 +126,31 @@ func (b *LookAsideBalancer) SelectNode(ctx context.Context, availableNodes []int
func (b *LookAsideBalancer) CancelWorkload(node int64, nq int64) {
totalNQ, ok := b.executingTaskTotalNQ.Get(node)
if ok {
totalNQ.Sub(nq)
nq := totalNQ.Sub(nq)
metrics.ProxyExecutingTotalNq.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10)).Set(float64(nq))
}
}

// UpdateCostMetrics used for cache some metrics of recent search/query cost
func (b *LookAsideBalancer) UpdateCostMetrics(node int64, cost *internalpb.CostAggregation) {
// cache the latest query node cost metrics for updating the score
b.metricsMap.Insert(node, cost)
if cost != nil {
b.metricsMap.Insert(node, cost)
}
b.metricsUpdateTs.Insert(node, time.Now().UnixMilli())
}

// calculateScore compute the query node's workload score
// https://www.usenix.org/conference/nsdi15/technical-sessions/presentation/suresh
func (b *LookAsideBalancer) calculateScore(node int64, cost *internalpb.CostAggregation, executingNQ int64) float64 {
if cost == nil || cost.ResponseTime == 0 || cost.ServiceTime == 0 {
return math.Pow(float64(1+executingNQ), 3.0)
if cost == nil || cost.GetResponseTime() == 0 {
return math.Pow(float64(executingNQ), 3.0)
}

// for multi-replica cases, when there are no task which waiting in queue,
// the response time will effect the score, to prevent the score based on a too old value
// we expire the cost metrics by second if no task in queue.
if executingNQ == 0 && cost.TotalNQ == 0 && b.isNodeCostMetricsTooOld(node) {
if executingNQ == 0 && b.isNodeCostMetricsTooOld(node) {
return 0
}

Expand All @@ -167,13 +170,14 @@ func (b *LookAsideBalancer) isNodeCostMetricsTooOld(node int64) bool {
return false
}

return time.Now().UnixMilli()-lastUpdateTs > CostMetricsExpireTime.Milliseconds()
return time.Now().UnixMilli()-lastUpdateTs > Params.ProxyCfg.CostMetricsExpireTime.GetAsInt64()
}

func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
log := log.Ctx(ctx).WithRateGroup("proxy.LookAsideBalancer", 1, 60)
defer b.wg.Done()

checkQueryNodeHealthInterval := Params.ProxyCfg.CheckQueryNodeHealthInterval.GetAsDuration(time.Millisecond)
ticker := time.NewTicker(checkQueryNodeHealthInterval)
defer ticker.Stop()
log.Info("Start check query node health loop")
Expand All @@ -190,7 +194,7 @@ func (b *LookAsideBalancer) checkQueryNodeHealthLoop(ctx context.Context) {
b.metricsUpdateTs.Range(func(node int64, lastUpdateTs int64) bool {
if now-lastUpdateTs > checkQueryNodeHealthInterval.Milliseconds() {
futures = append(futures, pool.Submit(func() (any, error) {
checkInterval := paramtable.Get().ProxyCfg.HealthCheckTimetout.GetAsDuration(time.Millisecond)
checkInterval := Params.ProxyCfg.HealthCheckTimetout.GetAsDuration(time.Millisecond)
ctx, cancel := context.WithTimeout(context.Background(), checkInterval)
defer cancel()

Expand Down
8 changes: 4 additions & 4 deletions internal/proxy/look_aside_balancer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -272,7 +272,7 @@ func (suite *LookAsideBalancerSuite) TestSelectNode() {
}

for node, result := range c.result {
suite.Equal(result, counter[node])
suite.True(math.Abs(float64(result-counter[node])) <= float64(1))
}
})
}
Expand Down Expand Up @@ -302,7 +302,7 @@ func (suite *LookAsideBalancerSuite) TestCheckHealthLoop() {
suite.balancer.unreachableQueryNodes.Insert(2)
suite.Eventually(func() bool {
return suite.balancer.unreachableQueryNodes.Contain(1)
}, 2*time.Second, 100*time.Millisecond)
}, 3*time.Second, 100*time.Millisecond)
targetNode, err := suite.balancer.SelectNode(context.Background(), []int64{1}, 1)
suite.ErrorIs(err, merr.ErrServiceUnavailable)
suite.Equal(int64(-1), targetNode)
Expand Down Expand Up @@ -331,11 +331,11 @@ func (suite *LookAsideBalancerSuite) TestNodeRecover() {
suite.balancer.metricsUpdateTs.Insert(3, time.Now().UnixMilli())
suite.Eventually(func() bool {
return suite.balancer.unreachableQueryNodes.Contain(3)
}, 2*time.Second, 100*time.Millisecond)
}, 5*time.Second, 100*time.Millisecond)

suite.Eventually(func() bool {
return !suite.balancer.unreachableQueryNodes.Contain(3)
}, 3*time.Second, 100*time.Millisecond)
}, 5*time.Second, 100*time.Millisecond)
}

func TestLookAsideBalancerSuite(t *testing.T) {
Expand Down
22 changes: 0 additions & 22 deletions internal/querynodev2/handlers.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,6 @@ import (
"github.com/milvus-io/milvus/internal/proto/querypb"
"github.com/milvus-io/milvus/internal/querynodev2/delegator"
"github.com/milvus-io/milvus/internal/querynodev2/segments"
"github.com/milvus-io/milvus/internal/querynodev2/tasks"
"github.com/milvus-io/milvus/internal/util"
"github.com/milvus-io/milvus/pkg/common"
"github.com/milvus-io/milvus/pkg/log"
Expand Down Expand Up @@ -227,27 +226,6 @@ func (node *QueryNode) queryChannel(ctx context.Context, req *querypb.QueryReque
return ret, nil
}

func (node *QueryNode) querySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
if collection == nil {
return nil, merr.WrapErrCollectionNotFound(req.Req.GetCollectionID())
}

// Send task to scheduler and wait until it finished.
task := tasks.NewQueryTask(ctx, collection, node.manager, req)
if err := node.scheduler.Add(task); err != nil {
log.Warn("failed to add query task into scheduler", zap.Error(err))
return nil, err
}
err := task.Wait()
if err != nil {
log.Warn("failed to execute task by node scheduler", zap.Error(err))
return nil, err
}

return task.Result(), nil
}

func (node *QueryNode) optimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, deleg delegator.ShardDelegator) (*querypb.SearchRequest, error) {
// no hook applied, just return
if node.queryHook == nil {
Expand Down
2 changes: 1 addition & 1 deletion internal/querynodev2/local_worker.go
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ func (w *LocalWorker) SearchSegments(ctx context.Context, req *querypb.SearchReq
}

func (w *LocalWorker) QuerySegments(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
return w.node.querySegments(ctx, req)
return w.node.QuerySegments(ctx, req)
}

func (w *LocalWorker) GetStatistics(ctx context.Context, req *querypb.GetStatisticsRequest) (*internalpb.GetStatisticsResponse, error) {
Expand Down
7 changes: 2 additions & 5 deletions internal/querynodev2/segments/result.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,11 +75,8 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult
return nil, err
}

requestCosts := lo.FilterMap(results, func(result *internalpb.SearchResults, _ int) (*internalpb.CostAggregation, bool) {
if result.CostAggregation == nil {
return nil, false
}
return result.CostAggregation, true
requestCosts := lo.Map(results, func(result *internalpb.SearchResults, _ int) *internalpb.CostAggregation {
return result.GetCostAggregation()
})
searchResults.CostAggregation = mergeRequestCost(requestCosts)

Expand Down
56 changes: 32 additions & 24 deletions internal/querynodev2/services.go
Original file line number Diff line number Diff line change
Expand Up @@ -742,10 +742,8 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()

result := task.Result()
if result.CostAggregation != nil {
// update channel's response time
result.CostAggregation.ResponseTime = latency.Milliseconds()
}
result.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
result.GetCostAggregation().TotalNQ = node.scheduler.GetWaitingTaskTotalNQ()
return result, nil
}

Expand All @@ -767,6 +765,8 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
zap.Uint64("guaranteeTimestamp", req.GetReq().GetGuaranteeTimestamp()),
zap.Uint64("timeTravel", req.GetReq().GetTravelTimestamp()))

tr := timerecord.NewTimeRecorderWithTrace(ctx, "SearchRequest")

if !node.lifetime.Add(commonpbutil.IsHealthy) {
msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID())
err := merr.WrapErrServiceNotReady(msg)
Expand Down Expand Up @@ -844,26 +844,25 @@ func (node *QueryNode) Search(ctx context.Context, req *querypb.SearchRequest) (
return failRet, nil
}

tr := timerecord.NewTimeRecorderWithTrace(ctx, "searchRequestReduce")
tr.RecordSpan()
result, err := segments.ReduceSearchResults(ctx, toReduceResults, req.Req.GetNq(), req.Req.GetTopk(), req.Req.GetMetricType())
if err != nil {
log.Warn("failed to reduce search results", zap.Error(err))
failRet.Status.ErrorCode = commonpb.ErrorCode_UnexpectedError
failRet.Status.Reason = err.Error()
return failRet, nil
}
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.SearchLabel, metrics.ReduceShards).
Observe(float64(tr.ElapseSpan().Milliseconds()))
Observe(float64(reduceLatency.Milliseconds()))

collector.Rate.Add(metricsinfo.NQPerSecond, float64(req.GetReq().GetNq()))
collector.Rate.Add(metricsinfo.SearchThroughput, float64(proto.Size(req)))
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel).
Add(float64(proto.Size(req)))

if result.CostAggregation != nil {
// update channel's response time
currentTotalNQ := node.scheduler.GetWaitingTaskTotalNQ()
result.CostAggregation.TotalNQ = currentTotalNQ
if result.GetCostAggregation() != nil {
result.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
}
return result, nil
}
Expand Down Expand Up @@ -904,7 +903,18 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ
defer cancel()

tr := timerecord.NewTimeRecorder("querySegments")
results, err := node.querySegments(queryCtx, req)
collection := node.manager.Collection.Get(req.Req.GetCollectionID())
if collection == nil {
return nil, merr.WrapErrCollectionNotFound(req.Req.GetCollectionID())
}

// Send task to scheduler and wait until it finished.
task := tasks.NewQueryTask(queryCtx, collection, node.manager, req)
if err := node.scheduler.Add(task); err != nil {
log.Warn("failed to add query task into scheduler", zap.Error(err))
return nil, err
}
err := task.Wait()
if err != nil {
log.Warn("failed to query channel", zap.Error(err))
failRet.Status.Reason = err.Error()
Expand All @@ -923,19 +933,17 @@ func (node *QueryNode) QuerySegments(ctx context.Context, req *querypb.QueryRequ
latency := tr.ElapseSpan()
metrics.QueryNodeSQReqLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.FromLeader).Observe(float64(latency.Milliseconds()))
metrics.QueryNodeSQCount.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.SuccessLabel, metrics.FromLeader).Inc()
results.CostAggregation = &internalpb.CostAggregation{
ServiceTime: latency.Milliseconds(),
ResponseTime: latency.Milliseconds(),
TotalNQ: 0,
}
return results, nil
result := task.Result()
result.GetCostAggregation().ResponseTime = latency.Milliseconds()
result.GetCostAggregation().TotalNQ = node.scheduler.GetWaitingTaskTotalNQ()
return result, nil
}

// Query performs replica query tasks.
func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*internalpb.RetrieveResults, error) {
if req.FromShardLeader {
// for compatible with rolling upgrade from version before v2.2.9
return node.querySegments(ctx, req)
return node.QuerySegments(ctx, req)
}

log := log.Ctx(ctx).With(
Expand All @@ -950,6 +958,7 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
zap.Uint64("travelTimestamp", req.GetReq().GetTravelTimestamp()),
zap.Bool("isCount", req.GetReq().GetIsCount()),
)
tr := timerecord.NewTimeRecorderWithTrace(ctx, "QueryRequest")

if !node.lifetime.Add(commonpbutil.IsHealthy) {
msg := fmt.Sprintf("query node %d is not ready", paramtable.GetNodeID())
Expand Down Expand Up @@ -1000,24 +1009,23 @@ func (node *QueryNode) Query(ctx context.Context, req *querypb.QueryRequest) (*i
return WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "failed to query channel", err), nil
}

tr := timerecord.NewTimeRecorderWithTrace(ctx, "queryRequestReduce")
tr.RecordSpan()
reducer := segments.CreateInternalReducer(req, node.manager.Collection.Get(req.GetReq().GetCollectionID()).Schema())
ret, err := reducer.Reduce(ctx, toMergeResults)
if err != nil {
return WrapRetrieveResult(commonpb.ErrorCode_UnexpectedError, "failed to query channel", err), nil
}
reduceLatency := tr.RecordSpan()
metrics.QueryNodeReduceLatency.WithLabelValues(fmt.Sprint(paramtable.GetNodeID()), metrics.QueryLabel, metrics.ReduceShards).
Observe(float64(tr.ElapseSpan().Milliseconds()))
Observe(float64(reduceLatency.Milliseconds()))

if !req.FromShardLeader {
collector.Rate.Add(metricsinfo.NQPerSecond, 1)
metrics.QueryNodeExecuteCounter.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.QueryLabel).Add(float64(proto.Size(req)))
}

if ret.CostAggregation != nil {
// update channel's response time
currentTotalNQ := node.scheduler.GetWaitingTaskTotalNQ()
ret.CostAggregation.TotalNQ = currentTotalNQ
if ret.GetCostAggregation() != nil {
ret.GetCostAggregation().ResponseTime = tr.ElapseSpan().Milliseconds()
}
return ret, nil
}
Expand Down
5 changes: 5 additions & 0 deletions internal/querynodev2/tasks/query_task.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,8 @@ func (t *QueryTask) PreExecute() error {

// Execute the task, only call once.
func (t *QueryTask) Execute() error {
tr := timerecord.NewTimeRecorderWithTrace(t.ctx, "QueryTask")

retrievePlan, err := segments.NewRetrievePlan(
t.collection,
t.req.Req.GetSerializedExprPlan(),
Expand Down Expand Up @@ -124,6 +126,9 @@ func (t *QueryTask) Execute() error {
Status: &commonpb.Status{ErrorCode: commonpb.ErrorCode_Success},
Ids: reducedResult.Ids,
FieldsData: reducedResult.FieldsData,
CostAggregation: &internalpb.CostAggregation{
ServiceTime: tr.ElapseSpan().Milliseconds(),
},
}
return nil
}
Expand Down
Loading

0 comments on commit 302897f

Please sign in to comment.