Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

enhance: fix inconsistency of alias and db for query iterator(#39045) #39301

Merged
merged 1 commit into from
Jan 16, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 16 additions & 13 deletions internal/proxy/search_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -52,21 +52,21 @@
}

// parseSearchInfo returns QueryInfo and offset
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, error) {
func parseSearchInfo(searchParamsPair []*commonpb.KeyValuePair, schema *schemapb.CollectionSchema, rankParams *rankParams) (*planpb.QueryInfo, int64, int64, error) {
var topK int64
isAdvanced := rankParams != nil
externalLimit := rankParams.GetLimit() + rankParams.GetOffset()
topKStr, err := funcutil.GetAttrByKeyFromRepeatedKV(TopKKey, searchParamsPair)
if err != nil {
if externalLimit <= 0 {
return nil, 0, fmt.Errorf("%s is required", TopKKey)
return nil, 0, 0, fmt.Errorf("%s is required", TopKKey)
}
topK = externalLimit
} else {
topKInParam, err := strconv.ParseInt(topKStr, 0, 64)
if err != nil {
if externalLimit <= 0 {
return nil, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
return nil, 0, 0, fmt.Errorf("%s [%s] is invalid", TopKKey, topKStr)
}
topK = externalLimit
} else {
Expand All @@ -76,13 +76,16 @@

isIterator, _ := funcutil.GetAttrByKeyFromRepeatedKV(IteratorField, searchParamsPair)

collectionIDStr, _ := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, searchParamsPair)
collectionId, _ := strconv.ParseInt(collectionIDStr, 0, 64)

if err := validateLimit(topK); err != nil {
if isIterator == "True" {
// 1. if the request is from iterator, we set topK to QuotaLimit as the iterator can resolve too large topK problem
// 2. GetAsInt64 has cached inside, no need to worry about cpu cost for parsing here
topK = Params.QuotaConfig.TopKLimit.GetAsInt64()
} else {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
return nil, 0, 0, fmt.Errorf("%s [%d] is invalid, %w", TopKKey, topK, err)
}
}

Expand All @@ -93,20 +96,20 @@
if err == nil {
offset, err = strconv.ParseInt(offsetStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
return nil, 0, 0, fmt.Errorf("%s [%s] is invalid", OffsetKey, offsetStr)
}

if offset != 0 {
if err := validateLimit(offset); err != nil {
return nil, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
return nil, 0, 0, fmt.Errorf("%s [%d] is invalid, %w", OffsetKey, offset, err)
}
}
}
}

queryTopK := topK + offset
if err := validateLimit(queryTopK); err != nil {
return nil, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)
return nil, 0, 0, fmt.Errorf("%s+%s [%d] is invalid, %w", OffsetKey, TopKKey, queryTopK, err)

Check warning on line 112 in internal/proxy/search_util.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/search_util.go#L112

Added line #L112 was not covered by tests
}

// 2. parse metrics type
Expand All @@ -123,11 +126,11 @@

roundDecimal, err := strconv.ParseInt(roundDecimalStr, 0, 64)
if err != nil {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
return nil, 0, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}

if roundDecimal != -1 && (roundDecimal > 6 || roundDecimal < 0) {
return nil, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
return nil, 0, 0, fmt.Errorf("%s [%s] is invalid, should be -1 or an integer in range [0, 6]", RoundDecimalKey, roundDecimalStr)
}

// 4. parse search param str
Expand All @@ -151,17 +154,17 @@
}
}
if groupByFieldId == -1 {
return nil, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")
return nil, 0, 0, merr.WrapErrFieldNotFound(groupByFieldName, "groupBy field not found in schema")

Check warning on line 157 in internal/proxy/search_util.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/search_util.go#L157

Added line #L157 was not covered by tests
}
}

// 6. disable groupBy for iterator and range search
if isIterator == "True" && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
return nil, 0, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do groupBy when doing iteration")
}
if strings.Contains(searchParamStr, radiusKey) && groupByFieldId > 0 {
return nil, 0, merr.WrapErrParameterInvalid("", "",
return nil, 0, 0, merr.WrapErrParameterInvalid("", "",
"Not allowed to do range-search when doing search-group-by")
}

Expand All @@ -171,7 +174,7 @@
SearchParams: searchParamStr,
RoundDecimal: roundDecimal,
GroupByFieldId: groupByFieldId,
}, offset, nil
}, offset, collectionId, nil
}

func getOutputFieldIDs(schema *schemaInfo, outputFields []string) (outputFieldIDs []UniqueID, err error) {
Expand Down
1 change: 1 addition & 0 deletions internal/proxy/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ const (
IgnoreGrowingKey = "ignore_growing"
ReduceStopForBestKey = "reduce_stop_for_best"
IteratorField = "iterator"
CollectionID = "collection_id"
GroupByFieldKey = "group_by_field"
AnnsFieldKey = "anns_field"
TopKKey = "topk"
Expand Down
16 changes: 16 additions & 0 deletions internal/proxy/task_query.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@
limit int64
offset int64
reduceStopForBest bool
collectionID int64
}

// translateToOutputFieldIDs translates output fields name to output fields id.
Expand Down Expand Up @@ -143,6 +144,7 @@
offset int64
reduceStopForBest bool
err error
collectionID int64
)
reduceStopForBestStr, err := funcutil.GetAttrByKeyFromRepeatedKV(ReduceStopForBestKey, queryParamsPair)
// if reduce_stop_for_best is provided
Expand All @@ -154,6 +156,15 @@
}
}

collectionIdStr, err := funcutil.GetAttrByKeyFromRepeatedKV(CollectionID, queryParamsPair)
if err == nil {
collectionID, err = strconv.ParseInt(collectionIdStr, 0, 64)
if err != nil {
return nil, merr.WrapErrParameterInvalid("int value for collection_id", CollectionID,
"value for collection id is invalid")
}

Check warning on line 165 in internal/proxy/task_query.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/task_query.go#L161-L165

Added lines #L161 - L165 were not covered by tests
}

limitStr, err := funcutil.GetAttrByKeyFromRepeatedKV(LimitKey, queryParamsPair)
// if limit is not provided
if err != nil {
Expand Down Expand Up @@ -182,6 +193,7 @@
limit: limit,
offset: offset,
reduceStopForBest: reduceStopForBest,
collectionID: collectionID,
}, nil
}

Expand Down Expand Up @@ -344,6 +356,10 @@
return err
}
t.RetrieveRequest.ReduceStopForBest = queryParams.reduceStopForBest
if queryParams.collectionID > 0 && queryParams.collectionID != t.GetCollectionID() {
return merr.WrapErrAsInputError(merr.WrapErrParameterInvalidMsg("Input collection id is not consistent to collectionID in the context," +
"alias or database may have changed"))
}

Check warning on line 362 in internal/proxy/task_query.go

View check run for this annotation

Codecov / codecov/patch

internal/proxy/task_query.go#L360-L362

Added lines #L360 - L362 were not covered by tests

t.queryParams = queryParams
t.RetrieveRequest.Limit = queryParams.limit + queryParams.offset
Expand Down
7 changes: 6 additions & 1 deletion internal/proxy/task_search.go
Original file line number Diff line number Diff line change
Expand Up @@ -490,10 +490,15 @@ func (t *searchTask) tryGeneratePlan(params []*commonpb.KeyValuePair, dsl string
}
annsFieldName = vecFields[0].Name
}
queryInfo, offset, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
queryInfo, offset, collectionID, parseErr := parseSearchInfo(params, t.schema.CollectionSchema, t.rankParams)
if parseErr != nil {
return nil, nil, 0, parseErr
}
if collectionID > 0 && collectionID != t.GetCollectionID() {
return nil, nil, 0, merr.WrapErrParameterInvalidMsg("collection id:%d in the request is not consistent to that in the search context,"+
"alias or database may have been changed: %d", collectionID, t.GetCollectionID())
}

annField := typeutil.GetFieldByName(t.schema.CollectionSchema, annsFieldName)
if queryInfo.GetGroupByFieldId() != -1 && annField.GetDataType() == schemapb.DataType_BinaryVector {
return nil, nil, 0, errors.New("not support search_group_by operation based on binary vector column")
Expand Down
31 changes: 25 additions & 6 deletions internal/proxy/task_search_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -303,6 +303,25 @@ func TestSearchTask_PreExecute(t *testing.T) {
task.request.OutputFields = []string{testFloatVecField}
assert.NoError(t, task.PreExecute(ctx))
})
t.Run("search inconsistent collection_id", func(t *testing.T) {
collName := "search_inconsistent_collection" + funcutil.GenRandomStr()
createColl(t, collName, rc)

st := getSearchTask(t, collName)
st.request.SearchParams = getValidSearchParams()
st.request.SearchParams = append(st.request.SearchParams, &commonpb.KeyValuePair{
Key: CollectionID,
Value: "8080",
})
st.request.DslType = commonpb.DslType_BoolExprV1

_, cancel := context.WithTimeout(ctx, time.Second)
defer cancel()
require.Equal(t, typeutil.ZeroTimestamp, st.TimeoutTimestamp)
enqueueTs := uint64(100000)
st.SetTs(enqueueTs)
assert.Error(t, st.PreExecute(ctx))
})
}

func getQueryCoord() *mocks.MockQueryCoord {
Expand Down Expand Up @@ -1974,7 +1993,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.validParams, nil, nil)
info, offset, _, err := parseSearchInfo(test.validParams, nil, nil)
assert.NoError(t, err)
assert.NotNil(t, info)
if test.description == "offsetParam" {
Expand All @@ -1995,7 +2014,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
limit: externalLimit,
}

info, offset, err := parseSearchInfo(offsetParam, nil, rank)
info, offset, _, err := parseSearchInfo(offsetParam, nil, rank)
assert.NoError(t, err)
assert.NotNil(t, info)
assert.Equal(t, int64(10), info.GetTopk())
Expand Down Expand Up @@ -2081,7 +2100,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {

for _, test := range tests {
t.Run(test.description, func(t *testing.T) {
info, offset, err := parseSearchInfo(test.invalidParams, nil, nil)
info, offset, _, err := parseSearchInfo(test.invalidParams, nil, nil)
assert.Error(t, err)
assert.Nil(t, info)
assert.Zero(t, offset)
Expand All @@ -2108,7 +2127,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
info, _, _, err := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
Expand All @@ -2127,7 +2146,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
info, _, _, err := parseSearchInfo(normalParam, schema, nil)
assert.Nil(t, info)
assert.ErrorIs(t, err, merr.ErrParameterInvalid)
})
Expand All @@ -2146,7 +2165,7 @@ func TestTaskSearch_parseSearchInfo(t *testing.T) {
schema := &schemapb.CollectionSchema{
Fields: fields,
}
info, _, err := parseSearchInfo(normalParam, schema, nil)
info, _, _, err := parseSearchInfo(normalParam, schema, nil)
assert.NotNil(t, info)
assert.NoError(t, err)
assert.Equal(t, Params.QuotaConfig.TopKLimit.GetAsInt64(), info.Topk)
Expand Down
Loading