From 8977454311fe4224ccfe42c2cafd2edfdb3ed0bd Mon Sep 17 00:00:00 2001 From: Gao Date: Wed, 11 Dec 2024 20:40:48 +0800 Subject: [PATCH] enhance: support recall estimation (#38017) issue: #37899 Only `search` api will be supported --------- Signed-off-by: chasingegg --- internal/proto/internal.proto | 2 + internal/proxy/impl.go | 59 ++++--- internal/proxy/task_search.go | 6 + internal/proxy/util.go | 69 ++++++++ internal/proxy/util_test.go | 162 ++++++++++++++++++ internal/querynodev2/segments/result.go | 5 + internal/querynodev2/services.go | 1 + internal/querynodev2/services_test.go | 20 ++- .../util/searchutil/optimizers/query_hook.go | 9 +- .../searchutil/optimizers/query_hook_test.go | 23 +-- pkg/common/common.go | 1 + pkg/metrics/proxy_metrics.go | 15 ++ 12 files changed, 331 insertions(+), 41 deletions(-) diff --git a/internal/proto/internal.proto b/internal/proto/internal.proto index d131d63b2cce7..f61d4f7acbab8 100644 --- a/internal/proto/internal.proto +++ b/internal/proto/internal.proto @@ -128,6 +128,7 @@ message SearchRequest { int64 group_size = 24; int64 field_id = 25; bool is_topk_reduce = 26; + bool is_recall_evaluation = 27; } message SubSearchResults { @@ -164,6 +165,7 @@ message SearchResults { bool is_advanced = 16; int64 all_search_count = 17; bool is_topk_reduce = 18; + bool is_recall_evaluation = 19; } message CostAggregation { diff --git a/internal/proxy/impl.go b/internal/proxy/impl.go index 2a0b41166effe..b51180301bc59 100644 --- a/internal/proxy/impl.go +++ b/internal/proxy/impl.go @@ -3000,12 +3000,13 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) optimizedSearch := true resultSizeInsufficient := false isTopkReduce := false + isRecallEvaluation := false err2 := retry.Handle(ctx, func() (bool, error) { - rsp, resultSizeInsufficient, isTopkReduce, err = node.search(ctx, request, optimizedSearch) + rsp, resultSizeInsufficient, isTopkReduce, isRecallEvaluation, err = node.search(ctx, request, optimizedSearch, false) if merr.Ok(rsp.GetStatus()) && optimizedSearch && resultSizeInsufficient && isTopkReduce && paramtable.Get().AutoIndexConfig.EnableResultLimitCheck.GetAsBool() { // without optimize search optimizedSearch = false - rsp, resultSizeInsufficient, isTopkReduce, err = node.search(ctx, request, optimizedSearch) + rsp, resultSizeInsufficient, isTopkReduce, isRecallEvaluation, err = node.search(ctx, request, optimizedSearch, false) metrics.ProxyRetrySearchCount.WithLabelValues( strconv.FormatInt(paramtable.GetNodeID(), 10), metrics.SearchLabel, @@ -3023,6 +3024,23 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest) if errors.Is(merr.Error(rsp.GetStatus()), merr.ErrInconsistentRequery) { return true, merr.Error(rsp.GetStatus()) } + // search for ground truth and compute recall + if isRecallEvaluation && merr.Ok(rsp.GetStatus()) { + var rspGT *milvuspb.SearchResults + rspGT, _, _, _, err = node.search(ctx, request, false, true) + metrics.ProxyRecallSearchCount.WithLabelValues( + strconv.FormatInt(paramtable.GetNodeID(), 10), + metrics.SearchLabel, + request.GetCollectionName(), + ).Inc() + if merr.Ok(rspGT.GetStatus()) { + return false, computeRecall(rsp.GetResults(), rspGT.GetResults()) + } + if errors.Is(merr.Error(rspGT.GetStatus()), merr.ErrInconsistentRequery) { + return true, merr.Error(rspGT.GetStatus()) + } + return false, merr.Error(rspGT.GetStatus()) + } return false, nil }) if err2 != nil { @@ -3031,13 +3049,11 @@ func (node *Proxy) Search(ctx context.Context, request 
*milvuspb.SearchRequest) return rsp, err } -func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, optimizedSearch bool) (*milvuspb.SearchResults, bool, bool, error) { - receiveSize := proto.Size(request) - metrics.ProxyReceiveBytes.WithLabelValues( - strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.SearchLabel, - request.GetCollectionName(), - ).Add(float64(receiveSize)) +func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, optimizedSearch bool, isRecallEvaluation bool) (*milvuspb.SearchResults, bool, bool, bool, error) { + metrics.GetStats(ctx). + SetNodeID(paramtable.GetNodeID()). + SetInboundLabel(metrics.SearchLabel). + SetCollectionName(request.GetCollectionName()) metrics.ProxyReceivedNQ.WithLabelValues( strconv.FormatInt(paramtable.GetNodeID(), 10), @@ -3048,7 +3064,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.SearchResults{ Status: merr.Status(err), - }, false, false, nil + }, false, false, false, nil } method := "Search" @@ -3069,7 +3085,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, if err != nil { return &milvuspb.SearchResults{ Status: merr.Status(err), - }, false, false, nil + }, false, false, false, nil } request.PlaceholderGroup = placeholderGroupBytes @@ -3083,8 +3099,9 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, commonpbutil.WithMsgType(commonpb.MsgType_Search), commonpbutil.WithSourceID(paramtable.GetNodeID()), ), - ReqID: paramtable.GetNodeID(), - IsTopkReduce: optimizedSearch, + ReqID: paramtable.GetNodeID(), + IsTopkReduce: optimizedSearch, + IsRecallEvaluation: isRecallEvaluation, }, request: request, tr: timerecord.NewTimeRecorder("search"), @@ -3146,7 +3163,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, return &milvuspb.SearchResults{ Status: merr.Status(err), - }, false, false, nil + }, false, false, false, nil } tr.CtxRecord(ctx, "search request enqueue") @@ -3172,7 +3189,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, return &milvuspb.SearchResults{ Status: merr.Status(err), - }, false, false, nil + }, false, false, false, nil } span := tr.CtxRecord(ctx, "wait search result") @@ -3229,7 +3246,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeSearch, dbName, username).Add(float64(v)) } } - return qt.result, qt.resultSizeInsufficient, qt.isTopkReduce, nil + return qt.result, qt.resultSizeInsufficient, qt.isTopkReduce, qt.isRecallEvaluation, nil } func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) { @@ -3272,12 +3289,10 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea } func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest, optimizedSearch bool) (*milvuspb.SearchResults, bool, bool, error) { - receiveSize := proto.Size(request) - metrics.ProxyReceiveBytes.WithLabelValues( - strconv.FormatInt(paramtable.GetNodeID(), 10), - metrics.HybridSearchLabel, - request.GetCollectionName(), - ).Add(float64(receiveSize)) + metrics.GetStats(ctx). + SetNodeID(paramtable.GetNodeID()). + SetInboundLabel(metrics.HybridSearchLabel). 
+ SetCollectionName(request.GetCollectionName()) if err := merr.CheckHealthy(node.GetStateCode()); err != nil { return &milvuspb.SearchResults{ diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index 3dc48cfe9503c..5c94bdebb11e8 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -65,6 +65,7 @@ type searchTask struct { mustUsePartitionKey bool resultSizeInsufficient bool isTopkReduce bool + isRecallEvaluation bool userOutputFields []string userDynamicFields []string @@ -647,10 +648,14 @@ func (t *searchTask) PostExecute(ctx context.Context) error { t.queryChannelsTs = make(map[string]uint64) t.relatedDataSize = 0 isTopkReduce := false + isRecallEvaluation := false for _, r := range toReduceResults { if r.GetIsTopkReduce() { isTopkReduce = true } + if r.GetIsRecallEvaluation() { + isRecallEvaluation = true + } t.relatedDataSize += r.GetCostAggregation().GetTotalRelatedDataSize() for ch, ts := range r.GetChannelsMvcc() { t.queryChannelsTs[ch] = ts @@ -731,6 +736,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error { } t.resultSizeInsufficient = resultSizeInsufficient t.isTopkReduce = isTopkReduce + t.isRecallEvaluation = isRecallEvaluation t.result.CollectionName = t.collectionName t.fillInFieldInfo() diff --git a/internal/proxy/util.go b/internal/proxy/util.go index 15f4f3649096b..e1a318dbf9312 100644 --- a/internal/proxy/util.go +++ b/internal/proxy/util.go @@ -1253,6 +1253,75 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int return pkNames, fieldIDs } +func recallCal[T string | int64](results []T, gts []T) float32 { + hit := 0 + total := 0 + for _, r := range results { + total++ + for _, gt := range gts { + if r == gt { + hit++ + break + } + } + } + return float32(hit) / float32(total) +} + +func computeRecall(results *schemapb.SearchResultData, gts *schemapb.SearchResultData) error { + if results.GetNumQueries() != gts.GetNumQueries() { + return fmt.Errorf("num of queries is inconsistent between search results(%d) and ground truth(%d)", results.GetNumQueries(), gts.GetNumQueries()) + } + + switch results.GetIds().GetIdField().(type) { + case *schemapb.IDs_IntId: + switch gts.GetIds().GetIdField().(type) { + case *schemapb.IDs_IntId: + currentResultIndex := int64(0) + currentGTIndex := int64(0) + recalls := make([]float32, 0, results.GetNumQueries()) + for i := 0; i < int(results.GetNumQueries()); i++ { + currentResultTopk := results.GetTopks()[i] + currentGTTopk := gts.GetTopks()[i] + recalls = append(recalls, recallCal(results.GetIds().GetIntId().GetData()[currentResultIndex:currentResultIndex+currentResultTopk], + gts.GetIds().GetIntId().GetData()[currentGTIndex:currentGTIndex+currentGTTopk])) + currentResultIndex += currentResultTopk + currentGTIndex += currentGTTopk + } + results.Recalls = recalls + return nil + case *schemapb.IDs_StrId: + return fmt.Errorf("pk type is inconsistent between search results(int64) and ground truth(string)") + default: + return fmt.Errorf("unsupported pk type") + } + + case *schemapb.IDs_StrId: + switch gts.GetIds().GetIdField().(type) { + case *schemapb.IDs_StrId: + currentResultIndex := int64(0) + currentGTIndex := int64(0) + recalls := make([]float32, 0, results.GetNumQueries()) + for i := 0; i < int(results.GetNumQueries()); i++ { + currentResultTopk := results.GetTopks()[i] + currentGTTopk := gts.GetTopks()[i] + recalls = append(recalls, recallCal(results.GetIds().GetStrId().GetData()[currentResultIndex:currentResultIndex+currentResultTopk], + 
gts.GetIds().GetStrId().GetData()[currentGTIndex:currentGTIndex+currentGTTopk])) + currentResultIndex += currentResultTopk + currentGTIndex += currentGTTopk + } + results.Recalls = recalls + return nil + case *schemapb.IDs_IntId: + return fmt.Errorf("pk type is inconsistent between search results(string) and ground truth(int64)") + default: + return fmt.Errorf("unsupported pk type") + } + default: + return fmt.Errorf("unsupported pk type") + } +} + // Support wildcard in output fields: // // "*" - all fields diff --git a/internal/proxy/util_test.go b/internal/proxy/util_test.go index 5065f65bd16de..93bef5a1b0af0 100644 --- a/internal/proxy/util_test.go +++ b/internal/proxy/util_test.go @@ -3079,3 +3079,165 @@ func TestValidateFunctionBasicParams(t *testing.T) { assert.Error(t, err) }) } + +func TestComputeRecall(t *testing.T) { + t.Run("normal case1", func(t *testing.T) { + result1 := &schemapb.SearchResultData{ + NumQueries: 3, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"11", "9", "8", "5", "3", "1"}, + }, + }, + }, + Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.1}, + Topks: []int64{2, 2, 2}, + } + + gt := &schemapb.SearchResultData{ + NumQueries: 3, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"11", "10", "8", "5", "3", "1"}, + }, + }, + }, + Scores: []float32{1.1, 0.98, 0.8, 0.5, 0.3, 0.1}, + Topks: []int64{2, 2, 2}, + } + + err := computeRecall(result1, gt) + assert.NoError(t, err) + assert.Equal(t, result1.Recalls[0], float32(0.5)) + assert.Equal(t, result1.Recalls[1], float32(1.0)) + assert.Equal(t, result1.Recalls[2], float32(1.0)) + }) + + t.Run("normal case2", func(t *testing.T) { + result1 := &schemapb.SearchResultData{ + NumQueries: 2, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21}, + }, + }, + }, + Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4}, + Topks: []int64{5, 5}, + } + + gt := &schemapb.SearchResultData{ + NumQueries: 2, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{11, 9, 6, 5, 4, 1, 34, 23, 22, 20}, + }, + }, + }, + Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4}, + Topks: []int64{5, 5}, + } + + err := computeRecall(result1, gt) + assert.NoError(t, err) + assert.Equal(t, result1.Recalls[0], float32(0.6)) + assert.Equal(t, result1.Recalls[1], float32(0.8)) + }) + + t.Run("not match size", func(t *testing.T) { + result1 := &schemapb.SearchResultData{ + NumQueries: 2, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21}, + }, + }, + }, + Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4}, + Topks: []int64{5, 5}, + } + + gt := &schemapb.SearchResultData{ + NumQueries: 1, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{11, 9, 6, 5, 4}, + }, + }, + }, + Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3}, + Topks: []int64{5}, + } + + err := computeRecall(result1, gt) + assert.Error(t, err) + }) + + t.Run("not match type1", func(t *testing.T) { + result1 := &schemapb.SearchResultData{ + NumQueries: 2, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21}, + }, + }, + }, + Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4}, + Topks: []int64{5, 5}, 
+ } + + gt := &schemapb.SearchResultData{ + NumQueries: 2, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"11", "10", "8", "5", "3", "1", "23", "22", "21", "20"}, + }, + }, + }, + Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4}, + Topks: []int64{5, 5}, + } + + err := computeRecall(result1, gt) + assert.Error(t, err) + }) + + t.Run("not match type2", func(t *testing.T) { + result1 := &schemapb.SearchResultData{ + NumQueries: 2, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_StrId{ + StrId: &schemapb.StringArray{ + Data: []string{"11", "10", "8", "5", "3", "1", "23", "22", "21", "20"}, + }, + }, + }, + Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4}, + Topks: []int64{5, 5}, + } + + gt := &schemapb.SearchResultData{ + NumQueries: 2, + Ids: &schemapb.IDs{ + IdField: &schemapb.IDs_IntId{ + IntId: &schemapb.LongArray{ + Data: []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21}, + }, + }, + }, + Scores: []float32{1.1, 0.9, 0.8, 0.5, 0.3, 0.8, 0.7, 0.6, 0.5, 0.4}, + Topks: []int64{5, 5}, + } + + err := computeRecall(result1, gt) + assert.Error(t, err) + }) +} diff --git a/internal/querynodev2/segments/result.go b/internal/querynodev2/segments/result.go index 6eaaab1ad6715..bbe88aa34fbf3 100644 --- a/internal/querynodev2/segments/result.go +++ b/internal/querynodev2/segments/result.go @@ -65,6 +65,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult channelsMvcc := make(map[string]uint64) isTopkReduce := false + isRecallEvaluation := false for _, r := range results { for ch, ts := range r.GetChannelsMvcc() { channelsMvcc[ch] = ts @@ -72,6 +73,9 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult if r.GetIsTopkReduce() { isTopkReduce = true } + if r.GetIsRecallEvaluation() { + isRecallEvaluation = true + } // shouldn't let new SearchResults.MetricType to be empty, though the req.MetricType is empty if info.GetMetricType() == "" { info.SetMetricType(r.MetricType) @@ -126,6 +130,7 @@ func ReduceSearchResults(ctx context.Context, results []*internalpb.SearchResult searchResults.CostAggregation.TotalRelatedDataSize = relatedDataSize searchResults.ChannelsMvcc = channelsMvcc searchResults.IsTopkReduce = isTopkReduce + searchResults.IsRecallEvaluation = isRecallEvaluation return searchResults, nil } diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index 35a1b4fa89878..a998417a33a10 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -733,6 +733,7 @@ func (node *QueryNode) SearchSegments(ctx context.Context, req *querypb.SearchRe if req.GetReq().GetIsTopkReduce() { resp.IsTopkReduce = true } + resp.IsRecallEvaluation = req.GetReq().GetIsRecallEvaluation() return resp, nil } diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index 98ca580086f82..aa59ddda43231 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -1177,7 +1177,7 @@ func (suite *ServiceSuite) syncDistribution(ctx context.Context) { } // Test Search -func (suite *ServiceSuite) genCSearchRequest(nq int64, dataType schemapb.DataType, fieldID int64, metricType string, isTopkReduce bool) (*internalpb.SearchRequest, error) { +func (suite *ServiceSuite) genCSearchRequest(nq int64, dataType schemapb.DataType, fieldID int64, metricType string, isTopkReduce bool, isRecallEvaluation bool) (*internalpb.SearchRequest, error) { placeHolder, err := 
genPlaceHolderGroup(nq) if err != nil { return nil, err @@ -1202,6 +1202,7 @@ func (suite *ServiceSuite) genCSearchRequest(nq int64, dataType schemapb.DataTyp Nq: nq, MvccTimestamp: typeutil.MaxTimestamp, IsTopkReduce: isTopkReduce, + IsRecallEvaluation: isRecallEvaluation, }, nil } @@ -1212,7 +1213,7 @@ func (suite *ServiceSuite) TestSearch_Normal() { suite.TestLoadSegments_Int64() suite.syncDistribution(ctx) - creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false) + creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false, false) req := &querypb.SearchRequest{ Req: creq, @@ -1237,7 +1238,7 @@ func (suite *ServiceSuite) TestSearch_Concurrent() { futures := make([]*conc.Future[*internalpb.SearchResults], 0, concurrency) for i := 0; i < concurrency; i++ { future := conc.Go(func() (*internalpb.SearchResults, error) { - creq, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType, false) + creq, err := suite.genCSearchRequest(30, schemapb.DataType_FloatVector, 107, defaultMetricType, false, false) req := &querypb.SearchRequest{ Req: creq, @@ -1263,7 +1264,7 @@ func (suite *ServiceSuite) TestSearch_Failed() { // data schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) - creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, "invalidMetricType", false) + creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, "invalidMetricType", false, false) req := &querypb.SearchRequest{ Req: creq, @@ -1388,7 +1389,7 @@ func (suite *ServiceSuite) TestSearchSegments_Normal() { suite.TestWatchDmChannelsInt64() suite.TestLoadSegments_Int64() - creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false) + creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false, false) req := &querypb.SearchRequest{ Req: creq, @@ -1400,13 +1401,15 @@ func (suite *ServiceSuite) TestSearchSegments_Normal() { rsp, err := suite.node.SearchSegments(ctx, req) suite.NoError(err) suite.Equal(rsp.GetIsTopkReduce(), false) + suite.Equal(rsp.GetIsRecallEvaluation(), false) suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode()) - req.Req, err = suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, true) + req.Req, err = suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, true, true) suite.NoError(err) rsp, err = suite.node.SearchSegments(ctx, req) suite.NoError(err) suite.Equal(rsp.GetIsTopkReduce(), true) + suite.Equal(rsp.GetIsRecallEvaluation(), true) suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode()) } @@ -1416,7 +1419,7 @@ func (suite *ServiceSuite) TestStreamingSearch() { suite.TestWatchDmChannelsInt64() suite.TestLoadSegments_Int64() paramtable.Get().Save(paramtable.Get().QueryNodeCfg.UseStreamComputing.Key, "true") - creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false) + creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false, true) req := &querypb.SearchRequest{ Req: creq, FromShardLeader: true, @@ -1430,6 +1433,7 @@ func (suite *ServiceSuite) TestStreamingSearch() { rsp, err := suite.node.SearchSegments(ctx, req) suite.NoError(err) suite.Equal(false, rsp.GetIsTopkReduce()) + suite.Equal(true, rsp.GetIsRecallEvaluation()) 
suite.Equal(commonpb.ErrorCode_Success, rsp.GetStatus().GetErrorCode()) } @@ -1438,7 +1442,7 @@ func (suite *ServiceSuite) TestStreamingSearchGrowing() { // pre suite.TestWatchDmChannelsInt64() paramtable.Get().Save(paramtable.Get().QueryNodeCfg.UseStreamComputing.Key, "true") - creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false) + creq, err := suite.genCSearchRequest(10, schemapb.DataType_FloatVector, 107, defaultMetricType, false, false) req := &querypb.SearchRequest{ Req: creq, FromShardLeader: true, diff --git a/internal/util/searchutil/optimizers/query_hook.go b/internal/util/searchutil/optimizers/query_hook.go index 2a00d206866d3..7ed96cabb38a0 100644 --- a/internal/util/searchutil/optimizers/query_hook.go +++ b/internal/util/searchutil/optimizers/query_hook.go @@ -28,6 +28,7 @@ func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, query // no hook applied or disabled, just return if queryHook == nil || !paramtable.Get().AutoIndexConfig.Enable.GetAsBool() { req.Req.IsTopkReduce = false + req.Req.IsRecallEvaluation = false return req, nil } @@ -68,8 +69,9 @@ func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, query common.SegmentNumKey: estSegmentNum, common.WithFilterKey: withFilter, common.DataTypeKey: int32(plan.GetVectorAnns().GetVectorType()), - common.WithOptimizeKey: paramtable.Get().AutoIndexConfig.EnableOptimize.GetAsBool() && req.GetReq().GetIsTopkReduce(), + common.WithOptimizeKey: paramtable.Get().AutoIndexConfig.EnableOptimize.GetAsBool() && req.GetReq().GetIsTopkReduce() && queryInfo.GetGroupByFieldId() < 0, common.CollectionKey: req.GetReq().GetCollectionID(), + common.RecallEvalKey: req.GetReq().GetIsRecallEvaluation(), } if withFilter && channelNum > 1 { params[common.ChannelNumKey] = channelNum @@ -90,6 +92,11 @@ func OptimizeSearchParams(ctx context.Context, req *querypb.SearchRequest, query } req.Req.SerializedExprPlan = serializedExprPlan req.Req.IsTopkReduce = isTopkReduce + if isRecallEvaluation, ok := params[common.RecallEvalKey]; ok { + req.Req.IsRecallEvaluation = isRecallEvaluation.(bool) && queryInfo.GetGroupByFieldId() < 0 + } else { + req.Req.IsRecallEvaluation = false + } log.Debug("optimized search params done", zap.Any("queryInfo", queryInfo)) default: log.Warn("not supported node type", zap.String("nodeType", fmt.Sprintf("%T", plan.GetNode()))) diff --git a/internal/util/searchutil/optimizers/query_hook_test.go b/internal/util/searchutil/optimizers/query_hook_test.go index 6dc78bfeac403..a0abd38f7dff0 100644 --- a/internal/util/searchutil/optimizers/query_hook_test.go +++ b/internal/util/searchutil/optimizers/query_hook_test.go @@ -41,6 +41,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { mockHook.EXPECT().Run(mock.Anything).Run(func(params map[string]any) { params[common.TopKKey] = int64(50) params[common.SearchParamKey] = `{"param": 2}` + params[common.RecallEvalKey] = true }).Return(nil) suite.queryHook = mockHook defer func() { @@ -48,20 +49,21 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { suite.queryHook = nil }() - getPlan := func(topk int64) *planpb.PlanNode { + getPlan := func(topk int64, groupByField int64) *planpb.PlanNode { return &planpb.PlanNode{ Node: &planpb.PlanNode_VectorAnns{ VectorAnns: &planpb.VectorANNS{ QueryInfo: &planpb.QueryInfo{ - Topk: topk, - SearchParams: `{"param": 1}`, + Topk: topk, + SearchParams: `{"param": 1}`, + GroupByFieldId: groupByField, }, }, }, } } - bs, err := proto.Marshal(getPlan(100)) + 
bs, err := proto.Marshal(getPlan(100, 101)) suite.Require().NoError(err) req, err := OptimizeSearchParams(ctx, &querypb.SearchRequest{ @@ -72,9 +74,9 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { TotalChannelNum: 2, }, suite.queryHook, 2) suite.NoError(err) - suite.verifyQueryInfo(req, 50, true, `{"param": 2}`) + suite.verifyQueryInfo(req, 50, true, false, `{"param": 2}`) - bs, err = proto.Marshal(getPlan(50)) + bs, err = proto.Marshal(getPlan(50, -1)) suite.Require().NoError(err) req, err = OptimizeSearchParams(ctx, &querypb.SearchRequest{ Req: &internalpb.SearchRequest{ @@ -84,7 +86,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { TotalChannelNum: 2, }, suite.queryHook, 2) suite.NoError(err) - suite.verifyQueryInfo(req, 50, false, `{"param": 2}`) + suite.verifyQueryInfo(req, 50, false, true, `{"param": 2}`) }) suite.Run("disable optimization", func() { @@ -112,7 +114,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { TotalChannelNum: 2, }, suite.queryHook, 2) suite.NoError(err) - suite.verifyQueryInfo(req, 100, false, `{"param": 1}`) + suite.verifyQueryInfo(req, 100, false, false, `{"param": 1}`) }) suite.Run("no_hook", func() { @@ -140,7 +142,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { TotalChannelNum: 2, }, suite.queryHook, 2) suite.NoError(err) - suite.verifyQueryInfo(req, 100, false, `{"param": 1}`) + suite.verifyQueryInfo(req, 100, false, false, `{"param": 1}`) }) suite.Run("other_plannode", func() { @@ -221,7 +223,7 @@ func (suite *QueryHookSuite) TestOptimizeSearchParam() { }) } -func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, isTopkReduce bool, param string) { +func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK int64, isTopkReduce bool, isRecallEvaluation bool, param string) { planBytes := req.GetReq().GetSerializedExprPlan() plan := planpb.PlanNode{} @@ -232,6 +234,7 @@ func (suite *QueryHookSuite) verifyQueryInfo(req *querypb.SearchRequest, topK in suite.Equal(topK, queryInfo.GetTopk()) suite.Equal(param, queryInfo.GetSearchParams()) suite.Equal(isTopkReduce, req.GetReq().GetIsTopkReduce()) + suite.Equal(isRecallEvaluation, req.GetReq().GetIsRecallEvaluation()) } func TestOptimizeSearchParam(t *testing.T) { diff --git a/pkg/common/common.go b/pkg/common/common.go index d32ab48c6fe26..e40c5825db10c 100644 --- a/pkg/common/common.go +++ b/pkg/common/common.go @@ -123,6 +123,7 @@ const ( ChannelNumKey = "channel_num" WithOptimizeKey = "with_optimize" CollectionKey = "collection" + RecallEvalKey = "recall_eval" IndexParamsKey = "params" IndexTypeKey = "index_type" diff --git a/pkg/metrics/proxy_metrics.go b/pkg/metrics/proxy_metrics.go index 45a8a1ec2e260..1c28fbb46607a 100644 --- a/pkg/metrics/proxy_metrics.go +++ b/pkg/metrics/proxy_metrics.go @@ -408,6 +408,15 @@ var ( Name: "retry_search_result_insufficient_cnt", Help: "counter of retry search which does not have enough results", }, []string{nodeIDLabelName, queryTypeLabelName, collectionName}) + + // ProxyRecallSearchCount records the counter that users issue recall evaluation requests, which are cpu-intensive + ProxyRecallSearchCount = prometheus.NewCounterVec( + prometheus.CounterOpts{ + Namespace: milvusNamespace, + Subsystem: typeutil.ProxyRole, + Name: "recall_search_cnt", + Help: "counter of recall search", + }, []string{nodeIDLabelName, queryTypeLabelName, collectionName}) ) // RegisterProxy registers Proxy metrics @@ -468,6 +477,7 @@ func RegisterProxy(registry *prometheus.Registry) { 
registry.MustRegister(MaxInsertRate) registry.MustRegister(ProxyRetrySearchCount) registry.MustRegister(ProxyRetrySearchResultInsufficientCount) + registry.MustRegister(ProxyRecallSearchCount) RegisterStreamingServiceClient(registry) } @@ -593,4 +603,9 @@ func CleanupProxyCollectionMetrics(nodeID int64, collection string) { queryTypeLabelName: HybridSearchLabel, collectionName: collection, }) + ProxyRecallSearchCount.Delete(prometheus.Labels{ + nodeIDLabelName: strconv.FormatInt(nodeID, 10), + queryTypeLabelName: SearchLabel, + collectionName: collection, + }) }
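
For readers skimming the diff: the heart of the change is `computeRecall` in `internal/proxy/util.go`, which, per query, measures the overlap between the IDs returned by the optimized search and the IDs returned by a second, unoptimized ("ground truth") search, and stores the ratios in `SearchResultData.Recalls`. The sketch below is a minimal, self-contained restatement of that logic and is not part of the patch: the helper names `recallOne` and `recallPerQuery` are hypothetical, a plain `comparable` constraint stands in for the patch's `string | int64`, and the data is taken from the "normal case2" unit test above.

```go
package main

import "fmt"

// recallOne computes |results ∩ groundTruth| / |results| for a single query,
// mirroring the generic recallCal helper added in internal/proxy/util.go.
func recallOne[T comparable](results, groundTruth []T) float32 {
	hit := 0
	for _, r := range results {
		for _, gt := range groundTruth {
			if r == gt {
				hit++
				break
			}
		}
	}
	return float32(hit) / float32(len(results))
}

// recallPerQuery walks the flattened ID arrays the way computeRecall does:
// query i owns resultIDs[rOff : rOff+resultTopks[i]] and
// gtIDs[gOff : gOff+gtTopks[i]], with the two offsets advanced independently.
func recallPerQuery(resultIDs, gtIDs []int64, resultTopks, gtTopks []int64) []float32 {
	recalls := make([]float32, 0, len(resultTopks))
	rOff, gOff := int64(0), int64(0)
	for i := range resultTopks {
		recalls = append(recalls, recallOne(
			resultIDs[rOff:rOff+resultTopks[i]],
			gtIDs[gOff:gOff+gtTopks[i]],
		))
		rOff += resultTopks[i]
		gOff += gtTopks[i]
	}
	return recalls
}

func main() {
	// Data from the "normal case2" unit test: two queries, topk = 5 each.
	resultIDs := []int64{11, 9, 8, 5, 3, 1, 34, 23, 22, 21}
	gtIDs := []int64{11, 9, 6, 5, 4, 1, 34, 23, 22, 20}
	topks := []int64{5, 5}

	fmt.Println(recallPerQuery(resultIDs, gtIDs, topks, topks)) // [0.6 0.8]
}
```

As in the patch, the flattened ID arrays are walked with independent offsets for the results and the ground truth, so the two result sets do not need identical per-query top-k values.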
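
A second orientation note, on the gating added in `OptimizeSearchParams`: recall evaluation is only propagated to query nodes when the query hook set the `recall_eval` param (`common.RecallEvalKey`) and the plan has no group-by field, which is why the first assertion in the updated hook test (group-by field 101) expects `IsRecallEvaluation == false`. Below is a stand-alone restatement of that predicate, assuming a plain map in place of the hook's params and a hypothetical helper name; it is illustrative only.

```go
package main

import "fmt"

// recallEvalEnabled restates the check added in OptimizeSearchParams:
// recall evaluation is requested by the hook via the "recall_eval" param
// and is suppressed whenever the plan uses group-by.
func recallEvalEnabled(params map[string]any, groupByFieldID int64) bool {
	v, ok := params["recall_eval"]
	if !ok {
		return false
	}
	enabled, _ := v.(bool)
	return enabled && groupByFieldID < 0
}

func main() {
	params := map[string]any{"recall_eval": true}
	fmt.Println(recallEvalEnabled(params, 101)) // false: group-by field present
	fmt.Println(recallEvalEnabled(params, -1))  // true: no group-by field
}
```

The same group-by guard is applied to `common.WithOptimizeKey` in the patch, so neither top-k reduction nor recall evaluation is attempted for group-by searches.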