Milvus 2.4 多向量搜索（HybridSearch）源码分析
API 入口
HybridSearch是多向量搜索的API。
func (node *Proxy) HybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {var err errorrsp := &milvuspb.SearchResults{Status: merr.Success(),}err2 := retry.Handle(ctx, func() (bool, error) {rsp, err = node.hybridSearch(ctx, request)if errors.Is(merr.Error(rsp.GetStatus()), merr.ErrInconsistentRequery) {return true, merr.Error(rsp.GetStatus())}return false, nil})if err2 != nil {rsp.Status = merr.Status(err2)}return rsp, err
}func (node *Proxy) hybridSearch(ctx context.Context, request *milvuspb.HybridSearchRequest) (*milvuspb.SearchResults, error) {......// 转换为milvuspb.SearchRequestnewSearchReq := convertHybridSearchToSearch(request)qt := &searchTask{ctx: ctx,Condition: NewTaskCondition(ctx),SearchRequest: &internalpb.SearchRequest{Base: commonpbutil.NewMsgBase(commonpbutil.WithMsgType(commonpb.MsgType_Search),commonpbutil.WithSourceID(paramtable.GetNodeID()),),ReqID: paramtable.GetNodeID(),},request: newSearchReq,tr: timerecord.NewTimeRecorder(method),qc: node.queryCoord,node: node,lb: node.lbPolicy,mustUsePartitionKey: Params.ProxyCfg.MustUsePartitionKey.GetAsBool(),}guaranteeTs := request.GuaranteeTimestamplog := log.Ctx(ctx).With(zap.String("role", typeutil.ProxyRole),zap.String("db", request.DbName),zap.String("collection", request.CollectionName),zap.Any("partitions", request.PartitionNames),zap.Any("OutputFields", request.OutputFields),zap.Uint64("guarantee_timestamp", guaranteeTs),)defer func() {span := tr.ElapseSpan()if span >= paramtable.Get().ProxyCfg.SlowQuerySpanInSeconds.GetAsDuration(time.Second) {log.Info(rpcSlow(method), zap.Duration("duration", span))metrics.ProxySlowQueryCount.WithLabelValues(strconv.FormatInt(paramtable.GetNodeID(), 10),metrics.HybridSearchLabel,).Inc()}}()log.Debug(rpcReceived(method))if err := node.sched.dqQueue.Enqueue(qt); err != nil {......}......
}
从代码中可以看出，HybridSearch() 最终创建并入队的，是与 Search() API 相同的 searchTask。
milvuspb.HybridSearchRequest 有一个类型为 []*SearchRequest 的 Requests 字段，用来存放多个子查询请求。普通的 Search() 直接传入一个 SearchRequest 结构体；HybridSearch() 传入的是多个子查询，因此需要先通过 convertHybridSearchToSearch() 把它们转换成一个 SearchRequest。
convertHybridSearchToSearch
进入convertHybridSearchToSearch()看看是如何转换的。
func convertHybridSearchToSearch(req *milvuspb.HybridSearchRequest) *milvuspb.SearchRequest {ret := &milvuspb.SearchRequest{Base: req.GetBase(),DbName: req.GetDbName(),CollectionName: req.GetCollectionName(),PartitionNames: req.GetPartitionNames(),OutputFields: req.GetOutputFields(),SearchParams: req.GetRankParams(),TravelTimestamp: req.GetTravelTimestamp(),GuaranteeTimestamp: req.GetGuaranteeTimestamp(),Nq: 0,NotReturnAllMeta: req.GetNotReturnAllMeta(),ConsistencyLevel: req.GetConsistencyLevel(),UseDefaultConsistency: req.GetUseDefaultConsistency(),SearchByPrimaryKeys: false,SubReqs: nil,}for _, sub := range req.GetRequests() {subReq := &milvuspb.SubSearchRequest{Dsl: sub.GetDsl(),PlaceholderGroup: sub.GetPlaceholderGroup(),DslType: sub.GetDslType(),SearchParams: sub.GetSearchParams(),Nq: sub.GetNq(),}ret.SubReqs = append(ret.SubReqs, subReq)}return ret
}
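milvuspb.SearchRequest 结构体增加了一个 SubReqs 字段，类型为 []*SubSearchRequest。转换时，每个子请求的 Dsl、PlaceholderGroup、DslType、SearchParams 和 Nq 被复制到一个 SubSearchRequest 中；HybridSearchRequest 的 RankParams 则放进顶层的 SearchParams，顶层 Nq 置为 0。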
type SubSearchRequest struct {Dsl stringPlaceholderGroup []byteDslType commonpb.DslTypeSearchParams []*commonpb.KeyValuePairNq int64XXX_NoUnkeyedLiteral struct{}XXX_unrecognized []byteXXX_sizecache int32
}
searchTask
PreExecute()
t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0
Search() 和 HybridSearch() 最终都由 searchTask 执行。IsAdvanced 的取值只取决于请求中是否带有 SubReqs：HybridSearch() 转换后的请求带有子请求，IsAdvanced 被置为 true；普通 Search() 不带 SubReqs，IsAdvanced 为 false。
func (t *searchTask) PreExecute(ctx context.Context) error {......t.SearchRequest.IsAdvanced = len(t.request.GetSubReqs()) > 0......if t.SearchRequest.GetIsAdvanced() {if len(t.request.GetSubReqs()) > defaultMaxSearchRequest {return errors.New(fmt.Sprintf("maximum of ann search requests is %d", defaultMaxSearchRequest))}}if t.SearchRequest.GetIsAdvanced() {t.rankParams, err = parseRankParams(t.request.GetSearchParams())if err != nil {return err}}......if t.SearchRequest.GetIsAdvanced() {t.requery = len(t.request.OutputFields) > 0err = t.initAdvancedSearchRequest(ctx)} else {t.requery = len(vectorOutputFields) > 0err = t.initSearchRequest(ctx)}......
}
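当 IsAdvanced 为 true 时，PreExecute() 会先校验子请求数量不超过 defaultMaxSearchRequest，再用 parseRankParams() 解析顶层 SearchParams 中的排序参数，并在 initAdvancedSearchRequest() 中填充 SubReqs；否则调用 initSearchRequest() 初始化普通搜索请求。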
shardDelegator.Search()
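在 shardDelegator.Search() 中，如果请求的 IsAdvanced 为 true，会为每个子请求构造一个 IsAdvanced 为 false 的普通 internalpb.SearchRequest，分别执行搜索并各自 reduce 结果，最后通过 segments.MergeToAdvancedResults() 合并成一个结果返回；否则直接调用 sd.search()。也就是说，一次 HybridSearch 最终会转换为多次 search。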
// Search preforms search operation on shard.
func (sd *shardDelegator) Search(ctx context.Context, req *querypb.SearchRequest) ([]*internalpb.SearchResults, error) {......if req.GetReq().GetIsAdvanced() {futures := make([]*conc.Future[*internalpb.SearchResults], len(req.GetReq().GetSubReqs()))// 多次调用searchfor index, subReq := range req.GetReq().GetSubReqs() {newRequest := &internalpb.SearchRequest{Base: req.GetReq().GetBase(),ReqID: req.GetReq().GetReqID(),DbID: req.GetReq().GetDbID(),CollectionID: req.GetReq().GetCollectionID(),PartitionIDs: subReq.GetPartitionIDs(),Dsl: subReq.GetDsl(),PlaceholderGroup: subReq.GetPlaceholderGroup(),DslType: subReq.GetDslType(),SerializedExprPlan: subReq.GetSerializedExprPlan(),OutputFieldsId: req.GetReq().GetOutputFieldsId(),MvccTimestamp: req.GetReq().GetMvccTimestamp(),GuaranteeTimestamp: req.GetReq().GetGuaranteeTimestamp(),TimeoutTimestamp: req.GetReq().GetTimeoutTimestamp(),Nq: subReq.GetNq(),Topk: subReq.GetTopk(),MetricType: subReq.GetMetricType(),IgnoreGrowing: req.GetReq().GetIgnoreGrowing(),Username: req.GetReq().GetUsername(),IsAdvanced: false,}future := conc.Go(func() (*internalpb.SearchResults, error) {searchReq := &querypb.SearchRequest{Req: newRequest,DmlChannels: req.GetDmlChannels(),TotalChannelNum: req.GetTotalChannelNum(),}searchReq.Req.GuaranteeTimestamp = req.GetReq().GetGuaranteeTimestamp()searchReq.Req.TimeoutTimestamp = req.GetReq().GetTimeoutTimestamp()if searchReq.GetReq().GetMvccTimestamp() == 0 {searchReq.GetReq().MvccTimestamp = tSafe}// 执行搜索results, err := sd.search(ctx, searchReq, sealed, growing)if err != nil {return nil, err}return segments.ReduceSearchResults(ctx,results,searchReq.Req.GetNq(),searchReq.Req.GetTopk(),searchReq.Req.GetMetricType())})futures[index] = future}// 等待所有任务执行完成err = conc.AwaitAll(futures...)if err != nil {return nil, err}results := make([]*internalpb.SearchResults, len(futures))for i, future := range futures {result := future.Value()if result.GetStatus().GetErrorCode() != commonpb.ErrorCode_Success 
{log.Debug("delegator hybrid search failed",zap.String("reason", result.GetStatus().GetReason()))return nil, merr.Error(result.GetStatus())}results[i] = result}var ret *internalpb.SearchResultsret, err = segments.MergeToAdvancedResults(ctx, results)if err != nil {return nil, err}// 走这里return []*internalpb.SearchResults{ret}, nil}return sd.search(ctx, req, sealed, growing)
}
总结
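HybridSearch() 在 Proxy 层被转换为带有 SubReqs 的 SearchRequest，复用 searchTask；到了 shardDelegator 中，再拆分为多个普通 Search 分别执行，各自 reduce 后由 MergeToAdvancedResults() 合并为一个结果。也就是说，HybridSearch 最终会转换为多个 Search 搜索。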