Skip to content

Commit

Permalink
enhance: support recall estimation (#38017)
Browse files Browse the repository at this point in the history
issue: #37899 
Only the `search` API will be supported.

---------

Signed-off-by: chasingegg <[email protected]>
  • Loading branch information
chasingegg authored Dec 11, 2024
1 parent dc85d8e commit 8977454
Show file tree
Hide file tree
Showing 12 changed files with 331 additions and 41 deletions.
2 changes: 2 additions & 0 deletions internal/proto/internal.proto
Original file line number Diff line number Diff line change
Expand Up @@ -128,6 +128,7 @@ message SearchRequest {
int64 group_size = 24;
int64 field_id = 25;
bool is_topk_reduce = 26;
bool is_recall_evaluation = 27;
}

message SubSearchResults {
Expand Down Expand Up @@ -164,6 +165,7 @@ message SearchResults {
bool is_advanced = 16;
int64 all_search_count = 17;
bool is_topk_reduce = 18;
bool is_recall_evaluation = 19;
}

message CostAggregation {
Expand Down
59 changes: 37 additions & 22 deletions internal/proxy/impl.go
Original file line number Diff line number Diff line change
Expand Up @@ -3000,12 +3000,13 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
optimizedSearch := true
resultSizeInsufficient := false
isTopkReduce := false
isRecallEvaluation := false
err2 := retry.Handle(ctx, func() (bool, error) {
rsp, resultSizeInsufficient, isTopkReduce, err = node.search(ctx, request, optimizedSearch)
rsp, resultSizeInsufficient, isTopkReduce, isRecallEvaluation, err = node.search(ctx, request, optimizedSearch, false)
if merr.Ok(rsp.GetStatus()) && optimizedSearch && resultSizeInsufficient && isTopkReduce && paramtable.Get().AutoIndexConfig.EnableResultLimitCheck.GetAsBool() {
// without optimize search
optimizedSearch = false
rsp, resultSizeInsufficient, isTopkReduce, err = node.search(ctx, request, optimizedSearch)
rsp, resultSizeInsufficient, isTopkReduce, isRecallEvaluation, err = node.search(ctx, request, optimizedSearch, false)
metrics.ProxyRetrySearchCount.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
Expand All @@ -3023,6 +3024,23 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
if errors.Is(merr.Error(rsp.GetStatus()), merr.ErrInconsistentRequery) {
return true, merr.Error(rsp.GetStatus())
}
// search for ground truth and compute recall
if isRecallEvaluation && merr.Ok(rsp.GetStatus()) {
var rspGT *milvuspb.SearchResults
rspGT, _, _, _, err = node.search(ctx, request, false, true)
metrics.ProxyRecallSearchCount.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
request.GetCollectionName(),
).Inc()
if merr.Ok(rspGT.GetStatus()) {
return false, computeRecall(rsp.GetResults(), rspGT.GetResults())
}
if errors.Is(merr.Error(rspGT.GetStatus()), merr.ErrInconsistentRequery) {
return true, merr.Error(rspGT.GetStatus())
}
return false, merr.Error(rspGT.GetStatus())
}
return false, nil
})
if err2 != nil {
Expand All @@ -3031,13 +3049,11 @@ func (node *Proxy) Search(ctx context.Context, request *milvuspb.SearchRequest)
return rsp, err
}

func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, optimizedSearch bool) (*milvuspb.SearchResults, bool, bool, error) {
receiveSize := proto.Size(request)
metrics.ProxyReceiveBytes.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.SearchLabel,
request.GetCollectionName(),
).Add(float64(receiveSize))
func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest, optimizedSearch bool, isRecallEvaluation bool) (*milvuspb.SearchResults, bool, bool, bool, error) {
metrics.GetStats(ctx).
SetNodeID(paramtable.GetNodeID()).
SetInboundLabel(metrics.SearchLabel).
SetCollectionName(request.GetCollectionName())

metrics.ProxyReceivedNQ.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
Expand All @@ -3048,7 +3064,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, false, false, nil
}, false, false, false, nil
}

method := "Search"
Expand All @@ -3069,7 +3085,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
if err != nil {
return &milvuspb.SearchResults{
Status: merr.Status(err),
}, false, false, nil
}, false, false, false, nil
}

request.PlaceholderGroup = placeholderGroupBytes
Expand All @@ -3083,8 +3099,9 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
commonpbutil.WithMsgType(commonpb.MsgType_Search),
commonpbutil.WithSourceID(paramtable.GetNodeID()),
),
ReqID: paramtable.GetNodeID(),
IsTopkReduce: optimizedSearch,
ReqID: paramtable.GetNodeID(),
IsTopkReduce: optimizedSearch,
IsRecallEvaluation: isRecallEvaluation,
},
request: request,
tr: timerecord.NewTimeRecorder("search"),
Expand Down Expand Up @@ -3146,7 +3163,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,

return &milvuspb.SearchResults{
Status: merr.Status(err),
}, false, false, nil
}, false, false, false, nil
}
tr.CtxRecord(ctx, "search request enqueue")

Expand All @@ -3172,7 +3189,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,

return &milvuspb.SearchResults{
Status: merr.Status(err),
}, false, false, nil
}, false, false, false, nil
}

span := tr.CtxRecord(ctx, "wait search result")
Expand Down Expand Up @@ -3229,7 +3246,7 @@ func (node *Proxy) search(ctx context.Context, request *milvuspb.SearchRequest,
metrics.ProxyReportValue.WithLabelValues(nodeID, hookutil.OpTypeSearch, dbName, username).Add(float64(v))
}
}
return qt.result, qt.resultSizeInsufficient, qt.isTopkReduce, nil
return qt.result, qt.resultSizeInsufficient, qt.isTopkReduce, qt.isRecallEvaluation, nil
}

func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {
Expand Down Expand Up @@ -3272,12 +3289,10 @@ func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSea
}

func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest, optimizedSearch bool) (*milvuspb.SearchResults, bool, bool, error) {
receiveSize := proto.Size(request)
metrics.ProxyReceiveBytes.WithLabelValues(
strconv.FormatInt(paramtable.GetNodeID(), 10),
metrics.HybridSearchLabel,
request.GetCollectionName(),
).Add(float64(receiveSize))
metrics.GetStats(ctx).
SetNodeID(paramtable.GetNodeID()).
SetInboundLabel(metrics.HybridSearchLabel).
SetCollectionName(request.GetCollectionName())

if err := merr.CheckHealthy(node.GetStateCode()); err != nil {
return &milvuspb.SearchResults{
Expand Down
6 changes: 6 additions & 0 deletions internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@ type searchTask struct {
mustUsePartitionKey bool
resultSizeInsufficient bool
isTopkReduce bool
isRecallEvaluation bool

userOutputFields []string
userDynamicFields []string
Expand Down Expand Up @@ -647,10 +648,14 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
t.queryChannelsTs = make(map[string]uint64)
t.relatedDataSize = 0
isTopkReduce := false
isRecallEvaluation := false
for _, r := range toReduceResults {
if r.GetIsTopkReduce() {
isTopkReduce = true
}
if r.GetIsRecallEvaluation() {
isRecallEvaluation = true
}
t.relatedDataSize += r.GetCostAggregation().GetTotalRelatedDataSize()
for ch, ts := range r.GetChannelsMvcc() {
t.queryChannelsTs[ch] = ts
Expand Down Expand Up @@ -731,6 +736,7 @@ func (t *searchTask) PostExecute(ctx context.Context) error {
}
t.resultSizeInsufficient = resultSizeInsufficient
t.isTopkReduce = isTopkReduce
t.isRecallEvaluation = isRecallEvaluation
t.result.CollectionName = t.collectionName
t.fillInFieldInfo()

Expand Down
69 changes: 69 additions & 0 deletions internal/proxy/util.go
Original file line number Diff line number Diff line change
Expand Up @@ -1253,6 +1253,75 @@ func translatePkOutputFields(schema *schemapb.CollectionSchema) ([]string, []int
return pkNames, fieldIDs
}

// recallCal returns the fraction of entries in results that also appear in
// gts (the ground truth), i.e. hits / len(results).
//
// Every entry of results is checked independently, so an ID that is repeated
// in results is counted once per occurrence.
//
// When results is empty the ratio is undefined (0/0 would produce NaN), so a
// finite value is returned instead: 1 if gts is empty as well (nothing was
// missed), 0 otherwise.
func recallCal[T string | int64](results []T, gts []T) float32 {
	if len(results) == 0 {
		if len(gts) == 0 {
			return 1
		}
		return 0
	}

	// Index the ground truth once so each result is a single lookup instead
	// of a scan over gts.
	gtSet := make(map[T]struct{}, len(gts))
	for _, gt := range gts {
		gtSet[gt] = struct{}{}
	}

	hit := 0
	for _, r := range results {
		if _, ok := gtSet[r]; ok {
			hit++
		}
	}
	return float32(hit) / float32(len(results))
}

// computeRecall computes one recall value per query by comparing the primary
// keys in results against the primary keys in gts (the ground truth), and
// stores the values in results.Recalls.
//
// Both inputs are read as flattened per-query ID lists: query i owns the next
// Topks[i] entries of the ID data. An error is returned, and results is left
// untouched, when:
//   - the number of queries differs between results and gts,
//   - the primary-key types differ or are not int64/string,
//   - the Topks / ID data are too short for the declared number of queries.
func computeRecall(results *schemapb.SearchResultData, gts *schemapb.SearchResultData) error {
	if results.GetNumQueries() != gts.GetNumQueries() {
		return fmt.Errorf("num of queries is inconsistent between search results(%d) and ground truth(%d)", results.GetNumQueries(), gts.GetNumQueries())
	}

	var (
		recalls []float32
		err     error
	)
	switch results.GetIds().GetIdField().(type) {
	case *schemapb.IDs_IntId:
		switch gts.GetIds().GetIdField().(type) {
		case *schemapb.IDs_IntId:
			recalls, err = recallsByQuery(
				results.GetIds().GetIntId().GetData(), gts.GetIds().GetIntId().GetData(),
				results.GetTopks(), gts.GetTopks(), results.GetNumQueries())
		case *schemapb.IDs_StrId:
			return fmt.Errorf("pk type is inconsistent between search results(int64) and ground truth(string)")
		default:
			return fmt.Errorf("unsupported pk type")
		}

	case *schemapb.IDs_StrId:
		switch gts.GetIds().GetIdField().(type) {
		case *schemapb.IDs_StrId:
			recalls, err = recallsByQuery(
				results.GetIds().GetStrId().GetData(), gts.GetIds().GetStrId().GetData(),
				results.GetTopks(), gts.GetTopks(), results.GetNumQueries())
		case *schemapb.IDs_IntId:
			return fmt.Errorf("pk type is inconsistent between search results(string) and ground truth(int64)")
		default:
			return fmt.Errorf("unsupported pk type")
		}

	default:
		return fmt.Errorf("unsupported pk type")
	}

	if err != nil {
		return err
	}
	results.Recalls = recalls
	return nil
}

// recallsByQuery walks the flattened ID lists query by query and returns the
// recall of each query. It validates every offset before slicing so that a
// malformed response produces an error instead of an out-of-range panic.
func recallsByQuery[T string | int64](resultIDs, gtIDs []T, resultTopks, gtTopks []int64, nq int64) ([]float32, error) {
	if nq < 0 {
		return nil, fmt.Errorf("invalid num of queries(%d)", nq)
	}
	if int64(len(resultTopks)) < nq || int64(len(gtTopks)) < nq {
		return nil, fmt.Errorf("num of topks is less than num of queries(%d): search results(%d), ground truth(%d)",
			nq, len(resultTopks), len(gtTopks))
	}

	recalls := make([]float32, 0, nq)
	resultOffset, gtOffset := int64(0), int64(0)
	for i := int64(0); i < nq; i++ {
		resultTopk, gtTopk := resultTopks[i], gtTopks[i]
		// Compare against the remaining length rather than offset+topk so the
		// check itself cannot overflow.
		if resultTopk < 0 || resultTopk > int64(len(resultIDs))-resultOffset {
			return nil, fmt.Errorf("topk(%d) of query %d exceeds the remaining ids of search results(%d)",
				resultTopk, i, int64(len(resultIDs))-resultOffset)
		}
		if gtTopk < 0 || gtTopk > int64(len(gtIDs))-gtOffset {
			return nil, fmt.Errorf("topk(%d) of query %d exceeds the remaining ids of ground truth(%d)",
				gtTopk, i, int64(len(gtIDs))-gtOffset)
		}

		recalls = append(recalls, recallCal(
			resultIDs[resultOffset:resultOffset+resultTopk],
			gtIDs[gtOffset:gtOffset+gtTopk]))
		resultOffset += resultTopk
		gtOffset += gtTopk
	}
	return recalls, nil
}

// Support wildcard in output fields:
//
// "*" - all fields
Expand Down
Loading

0 comments on commit 8977454

Please sign in to comment.