diff --git a/internal/proxy/search_util.go b/internal/proxy/search_util.go index 33ddd45b5d557..8f4ac6c020b6f 100644 --- a/internal/proxy/search_util.go +++ b/internal/proxy/search_util.go @@ -52,21 +52,21 @@ func (r *rankParams) String() string { } // parseSearchInfo returns QueryInfo and offset -func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, error) { +func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, int64, error) { var topK int64 isAdvanced := rankParams != nil externalLimit := rankParams.GetLimit() + rankParams.GetOffset() topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair) if err != nil { if externalLimit <= 0 { - return nil, 0, fmt.Errorf("%s is required", TopKKey) + return nil, 0, 0, fmt.Errorf("%s is required", TopKKey) } topK = externalLimit } else { topKInParam, err := strconv.ParseInt(topKStr, 0, 64) if err != nil { if externalLimit <= 0 { - return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr) + return nil, 0, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr) } topK = externalLimit } else { @@ -76,13 +76,16 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair) + collectionIDStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, searchParamsPair) + collectionId, _ := strconv.ParseInt(collectionIDStr, 0, 64) + if err := validateLimit(topK); err != nil { if isIterator == "True" { // 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem // 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here topK = Params.QuotaConfig.TopKLimit.GetAsInt64() } else { - return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err) + return nil, 0, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err) } } @@ -93,12 +96,12 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb if err == nil { offset, err = strconv.ParseInt(offsetStr, 0, 64) if err != nil { - return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr) + return nil, 0, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr) } if offset != 0 { if err := validateLimit(offset); err != nil { - return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err) + return nil, 0, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err) } } } @@ -106,7 +109,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb queryTopK := topK + offset if err := validateLimit(queryTopK); err != nil { - return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err) + return nil, 0, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err) } // 2. parse metrics type @@ -123,11 +126,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64) if err != nil { - return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) + return nil, 0, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) } if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) { - return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) + return nil, 0, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr) } // 4. parse search param str @@ -151,17 +154,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb } } if groupByFieldId == -1 { - return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema") + return nil, 0, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema") } } // 6. disable groupBy for iterator and range search if isIterator == "True" && groupByFieldId > 0 { - return nil, 0, merr.WrapErrParameterInvalid("", "", + return nil, 0, 0, merr.WrapErrParameterInvalid("", "", "Not allowed to do groupBy when doing iteration") } if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 { - return nil, 0, merr.WrapErrParameterInvalid("", "", + return nil, 0, 0, merr.WrapErrParameterInvalid("", "", "Not allowed to do range-search when doing search-group-by") } @@ -171,7 +174,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb SearchParams: searchParamStr, RoundDecimal: roundDecimal, GroupByFieldId: groupByFieldId, - }, offset, nil + }, offset, collectionId, nil } func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) { diff --git a/internal/proxy/task.go b/internal/proxy/task.go index 160f2be3dc98d..f18b045fb50bb 100644 --- a/internal/proxy/task.go +++ b/internal/proxy/task.go @@ -49,6 +49,7 @@ const ( IgnoreGrowingKey = "ignore_growing" ReduceStopForBestKey = "reduce_stop_for_best" IteratorField = "iterator" + CollectionID = "collection_id" GroupByFieldKey = "group_by_field" AnnsFieldKey = "anns_field" TopKKey = "topk" diff --git a/internal/proxy/task_query.go b/internal/proxy/task_query.go index 260f77710934f..b7ec7e1b7f318 100644 --- a/internal/proxy/task_query.go +++ b/internal/proxy/task_query.go @@ -78,6 +78,7 @@ type queryParams struct { limit int64 offset int64 reduceStopForBest bool + collectionID int64 } // translateToOutputFieldIDs translates output fields name to output fields id. @@ -143,6 +144,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e offset int64 reduceStopForBest bool err error + collectionID int64 ) reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair) // if reduce_stop_for_best is provided @@ -154,6 +156,15 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e } } + collectionIdStr, err := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, queryParamsPair) + if err == nil { + collectionID, err = strconv.ParseInt(collectionIdStr, 0, 64) + if err != nil { + return nil, merr.WrapErrParameterInvalid("int value for collection_id", CollectionID, + "value for collection id is invalid") + } + } + limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair) // if limit is not provided if err != nil { @@ -182,6 +193,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e limit: limit, offset: offset, reduceStopForBest: reduceStopForBest, + collectionID: collectionID, }, nil } @@ -344,6 +356,10 @@ func (t *queryTask) PreExecute(ctx context.Context) error { return err } t.RetrieveRequest.ReduceStopForBest = queryParams.reduceStopForBest + if queryParams.collectionID > 0 && queryParams.collectionID != t.GetCollectionID() { + return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("Input collection id is not consistent to collectionID in the context," + + "alias or database may have changed")) + } t.queryParams = queryParams t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset diff --git a/internal/proxy/task_search.go b/internal/proxy/task_search.go index cf79e1b139948..282e527ff91dc 100644 --- a/internal/proxy/task_search.go +++ b/internal/proxy/task_search.go @@ -490,10 +490,15 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string } annsFieldName = vecFields[0].Name } - queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams) + queryInfo, offset, collectionID, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams) if parseErr != nil { return nil, nil, 0, parseErr } + if collectionID > 0 && collectionID != t.GetCollectionID() { + return nil, nil, 0, merr.WrapErrParameterInvalidMsg("collection id:%d in the request is not consistent to that in the search context,"+ + "alias or database may have been changed: %d", collectionID, t.GetCollectionID()) + } + annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName) if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector { return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column") diff --git a/internal/proxy/task_search_test.go b/internal/proxy/task_search_test.go index 802c69eb52d07..65f485a88ff82 100644 --- a/internal/proxy/task_search_test.go +++ b/internal/proxy/task_search_test.go @@ -303,6 +303,25 @@ func TestSearchTask_PreExecute(t *testing.T) { task.request.OutputFields = []string{testFloatVecField} assert.NoError(t, task.PreExecute(ctx)) }) + t.Run("search inconsistent collection_id", func(t *testing.T) { + collName := "search_inconsistent_collection" + funcutil.GenRandomStr() + createColl(t, collName, rc) + + st := getSearchTask(t, collName) + st.request.SearchParams = getValidSearchParams() + st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{ + Key: CollectionID, + Value: "8080", + }) + st.request.DslType = commonpb.DslType_BoolExprV1 + + _, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp) + enqueueTs := uint64(100000) + st.SetTs(enqueueTs) + assert.Error(t, st.PreExecute(ctx)) + }) } func getQueryCoord() *mocks.MockQueryCoord { @@ -1974,7 +1993,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - info, offset, err := parseSearchInfo(test.validParams, nil, nil) + info, offset, _, err := parseSearchInfo(test.validParams, nil, nil) assert.NoError(t, err) assert.NotNil(t, info) if test.description == "offsetParam" { @@ -1995,7 +2014,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { limit: externalLimit, } - info, offset, err := parseSearchInfo(offsetParam, nil, rank) + info, offset, _, err := parseSearchInfo(offsetParam, nil, rank) assert.NoError(t, err) assert.NotNil(t, info) assert.Equal(t, int64(10), info.GetTopk()) @@ -2081,7 +2100,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { for _, test := range tests { t.Run(test.description, func(t *testing.T) { - info, offset, err := parseSearchInfo(test.invalidParams, nil, nil) + info, offset, _, err := parseSearchInfo(test.invalidParams, nil, nil) assert.Error(t, err) assert.Nil(t, info) assert.Zero(t, offset) @@ -2108,7 +2127,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, nil) + info, _, _, err := parseSearchInfo(normalParam, schema, nil) assert.Nil(t, info) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) @@ -2127,7 +2146,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, nil) + info, _, _, err := parseSearchInfo(normalParam, schema, nil) assert.Nil(t, info) assert.ErrorIs(t, err, merr.ErrParameterInvalid) }) @@ -2146,7 +2165,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) { schema := &schemapb.CollectionSchema{ Fields: fields, } - info, _, err := parseSearchInfo(normalParam, schema, nil) + info, _, _, err := parseSearchInfo(normalParam, schema, nil) assert.NotNil(t, info) assert.NoError(t, err) assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), info.Topk)