Skip to content

Commit

Permalink
enhance: fix inconsistenty of alias and db for query iterator(#39045) (
Browse files Browse the repository at this point in the history
…#39301)

related: #39045
pr: #39248

Signed-off-by: MrPresent-Han <[email protected]>
Co-authored-by: MrPresent-Han <[email protected]>
  • Loading branch information
MrPresent-Han and MrPresent-Han authored Jan 16, 2025
1 parent 622af57 commit 477425d
Show file tree
Hide file tree
Showing 5 changed files with 64 additions and 20 deletions.
29 changes: 16 additions & 13 deletions internal/proxy/search_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,21 @@ func (r *rankParams) String() string {
}

// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, error) {
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, int64, error) {
var topK int64
isAdvanced := rankParams != nil
externalLimit := rankParams.GetLimit() + rankParams.GetOffset()
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
if err != nil {
if externalLimit <= 0 {
return nil, 0, fmt.Errorf("%s is required", TopKKey)
return nil, 0, 0, fmt.Errorf("%s is required", TopKKey)
}
topK = externalLimit
} else {
topKInParam, err := strconv.ParseInt(topKStr, 0, 64)
if err != nil {
if externalLimit <= 0 {
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
return nil, 0, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
}
topK = externalLimit
} else {
Expand All @@ -76,13 +76,16 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb

isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)

collectionIDStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, searchParamsPair)
collectionId, _ := strconv.ParseInt(collectionIDStr, 0, 64)

if err := validateLimit(topK); err != nil {
if isIterator == "True" {
// 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem
// 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here
topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
} else {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
return nil, 0, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
}
}

Expand All @@ -93,20 +96,20 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
if err == nil {
offset, err = strconv.ParseInt(offsetStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
return nil, 0, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
}

if offset != 0 {
if err := validateLimit(offset); err != nil {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
return nil, 0, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
}
}
}
}

queryTopK := topK + offset
if err := validateLimit(queryTopK); err != nil {
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
return nil, 0, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
}

// 2. parse metrics type
Expand All @@ -123,11 +126,11 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb

roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
return nil, 0, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}

if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
return nil, 0, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}

// 4. parse search param str
Expand All @@ -151,17 +154,17 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
}
}
if groupByFieldId == -1 {
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
return nil, 0, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
}
}

// 6. disable groupBy for iterator and range search
if isIterator == "True" && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
return nil, 0, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do groupBy when doing iteration")
}
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
return nil, 0, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do range-search when doing search-group-by")
}

Expand All @@ -171,7 +174,7 @@ func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
}, offset, nil
}, offset, collectionId, nil
}

func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
Expand Down
1 change: 1 addition & 0 deletions internal/proxy/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const (
IgnoreGrowingKey = "ignore_growing"
ReduceStopForBestKey = "reduce_stop_for_best"
IteratorField = "iterator"
CollectionID = "collection_id"
GroupByFieldKey = "group_by_field"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
Expand Down
16 changes: 16 additions & 0 deletions internal/proxy/task_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ type queryParams struct {
limit int64
offset int64
reduceStopForBest bool
collectionID int64
}

// translateToOutputFieldIDs translates output fields name to output fields id.
Expand Down Expand Up @@ -143,6 +144,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
offset int64
reduceStopForBest bool
err error
collectionID int64
)
reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair)
// if reduce_stop_for_best is provided
Expand All @@ -154,6 +156,15 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
}
}

collectionIdStr, err := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, queryParamsPair)
if err == nil {
collectionID, err = strconv.ParseInt(collectionIdStr, 0, 64)
if err != nil {
return nil, merr.WrapErrParameterInvalid("int value for collection_id", CollectionID,
"value for collection id is invalid")
}
}

limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair)
// if limit is not provided
if err != nil {
Expand Down Expand Up @@ -182,6 +193,7 @@ func parseQueryParams(queryParamsPair []*commonpb.KeyValuePair) (*queryParams, e
limit: limit,
offset: offset,
reduceStopForBest: reduceStopForBest,
collectionID: collectionID,
}, nil
}

Expand Down Expand Up @@ -344,6 +356,10 @@ func (t *queryTask) PreExecute(ctx context.Context) error {
return err
}
t.RetrieveRequest.ReduceStopForBest = queryParams.reduceStopForBest
if queryParams.collectionID > 0 && queryParams.collectionID != t.GetCollectionID() {
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("Input collection id is not consistent to collectionID in the context," +
"alias or database may have changed"))
}

t.queryParams = queryParams
t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset
Expand Down
7 changes: 6 additions & 1 deletion internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,10 +490,15 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
}
annsFieldName = vecFields[0].Name
}
queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
queryInfo, offset, collectionID, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
if parseErr != nil {
return nil, nil, 0, parseErr
}
if collectionID > 0 && collectionID != t.GetCollectionID() {
return nil, nil, 0, merr.WrapErrParameterInvalidMsg("collection id:%d in the request is not consistent to that in the search context,"+
"alias or database may have been changed: %d", collectionID, t.GetCollectionID())
}

annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column")
Expand Down
31 changes: 25 additions & 6 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,25 @@ func TestSearchTask_PreExecute(t *testing.T) {
task.request.OutputFields = []string{testFloatVecField}
assert.NoError(t, task.PreExecute(ctx))
})
t.Run("search inconsistent collection_id", func(t *testing.T) {
collName := "search_inconsistent_collection" + funcutil.GenRandomStr()
createColl(t, collName, rc)

st := getSearchTask(t, collName)
st.request.SearchParams = getValidSearchParams()
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
Key: CollectionID,
Value: "8080",
})
st.request.DslType = commonpb.DslType_BoolExprV1

_, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
enqueueTs := uint64(100000)
st.SetTs(enqueueTs)
assert.Error(t, st.PreExecute(ctx))
})
}

func getQueryCoord() *mocks.MockQueryCoord {
Expand Down Expand Up @@ -1974,7 +1993,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.validParams, nil, nil)
info, offset, _, err := parseSearchInfo(test.validParams, nil, nil)
assert.NoError(t, err)
assert.NotNil(t, info)
if test.description == "offsetParam" {
Expand All @@ -1995,7 +2014,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
limit: externalLimit,
}

info, offset, err := parseSearchInfo(offsetParam, nil, rank)
info, offset, _, err := parseSearchInfo(offsetParam, nil, rank)
assert.NoError(t, err)
assert.NotNil(t, info)
assert.Equal(t, int64(10), info.GetTopk())
Expand Down Expand Up @@ -2081,7 +2100,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.invalidParams, nil, nil)
info, offset, _, err := parseSearchInfo(test.invalidParams, nil, nil)
assert.Error(t, err)
assert.Nil(t, info)
assert.Zero(t, offset)
Expand All @@ -2108,7 +2127,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
info, _, _, err := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
Expand All @@ -2127,7 +2146,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
info, _, _, err := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
Expand All @@ -2146,7 +2165,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
info, _, _, err := parseSearchInfo(normalParam, schema, nil)
assert.NotNil(t, info)
assert.NoError(t, err)
assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), info.Topk)
Expand Down

0 comments on commit 477425d

Please sign in to comment.