Skip to content

Commit

Permalink
enhance: [GoSDK] Use variadic params for options (#36912)
Browse files Browse the repository at this point in the history
Use variadic parameter function for options make client options easier
to use.

Related to #31293

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia authored Oct 16, 2024
1 parent 903450f commit e5948bd
Show file tree
Hide file tree
Showing 14 changed files with 94 additions and 93 deletions.
12 changes: 6 additions & 6 deletions client/example/playground/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func Count() {
log.Println("flush done, elapsed", time.Since(start))

result, err := c.Query(ctx, milvusclient.NewQueryOption(collectionName).
WithOutputFields([]string{"count(*)"}).
WithOutputFields("count(*)").
WithConsistencyLevel(entity.ClStrong))
if err != nil {
log.Fatal("failed to connect to milvus, err: ", err.Error())
Expand All @@ -125,7 +125,7 @@ func Count() {
log.Println(rs)
}
result, err = c.Query(ctx, milvusclient.NewQueryOption(collectionName).
WithOutputFields([]string{"count(*)"}).
WithOutputFields("count(*)").
WithFilter("id > 0").
WithConsistencyLevel(entity.ClStrong))
if err != nil {
Expand All @@ -136,10 +136,10 @@ func Count() {
}
}

// err = c.DropCollection(ctx, milvusclient.NewDropCollectionOption(collectionName))
// if err != nil {
// log.Fatal("=== Failed to drop collection", err.Error())
// }
err = c.DropCollection(ctx, milvusclient.NewDropCollectionOption(collectionName))
if err != nil {
log.Fatal("=== Failed to drop collection", err.Error())
}
}

func HelloMilvus() {
Expand Down
2 changes: 1 addition & 1 deletion client/maintenance_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ func (opt *loadPartitionsOption) WithSkipLoadDynamicField(skipFlag bool) *loadPa
return opt
}

func NewLoadPartitionsOption(collectionName string, partitionsNames []string) *loadPartitionsOption {
func NewLoadPartitionsOption(collectionName string, partitionsNames ...string) *loadPartitionsOption {
return &loadPartitionsOption{
collectionName: collectionName,
partitionNames: partitionsNames,
Expand Down
4 changes: 2 additions & 2 deletions client/maintenance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() {
})
defer s.mock.EXPECT().GetLoadingProgress(mock.Anything, mock.Anything).Unset()

task, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, []string{partitionName}).
task, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, partitionName).
WithReplica(replicaNum).
WithLoadFields(fieldNames...).
WithSkipLoadDynamicField(true))
Expand Down Expand Up @@ -174,7 +174,7 @@ func (s *MaintenanceSuite) TestLoadPartitions() {

s.mock.EXPECT().LoadPartitions(mock.Anything, mock.Anything).Return(nil, merr.WrapErrServiceInternal("mocked")).Once()

_, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, []string{partitionName}))
_, err := s.client.LoadPartitions(ctx, NewLoadPartitionsOption(collectionName, partitionName))
s.Error(err)
})
}
Expand Down
1 change: 1 addition & 0 deletions client/read.go
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ func (c *Client) handleSearchResult(schema *entity.Schema, outputFields []string
entry := ResultSet{
ResultCount: rc,
Scores: results.GetScores()[offset : offset+rc],
sch: schema,
}

entry.IDs, entry.Err = column.IDColumns(schema, results.GetIds(), offset, offset+rc)
Expand Down
8 changes: 4 additions & 4 deletions client/read_options.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (opt *searchOption) WithOffset(offset int) *searchOption {
return opt
}

func (opt *searchOption) WithOutputFields(fieldNames []string) *searchOption {
func (opt *searchOption) WithOutputFields(fieldNames ...string) *searchOption {
opt.outputFields = fieldNames
return opt
}
Expand All @@ -137,7 +137,7 @@ func (opt *searchOption) WithANNSField(annsField string) *searchOption {
return opt
}

func (opt *searchOption) WithPartitions(partitionNames []string) *searchOption {
func (opt *searchOption) WithPartitions(partitionNames ...string) *searchOption {
opt.partitionNames = partitionNames
return opt
}
Expand Down Expand Up @@ -240,7 +240,7 @@ func (opt *queryOption) WithLimit(limit int) *queryOption {
return opt
}

func (opt *queryOption) WithOutputFields(fieldNames []string) *queryOption {
func (opt *queryOption) WithOutputFields(fieldNames ...string) *queryOption {
opt.outputFields = fieldNames
return opt
}
Expand All @@ -251,7 +251,7 @@ func (opt *queryOption) WithConsistencyLevel(consistencyLevel entity.Consistency
return opt
}

func (opt *queryOption) WithPartitions(partitionNames []string) *queryOption {
func (opt *queryOption) WithPartitions(partitionNames ...string) *queryOption {
opt.partitionNames = partitionNames
return opt
}
Expand Down
6 changes: 3 additions & 3 deletions client/read_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (s *ReadSuite) TestSearch() {
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
})),
}).WithPartitions([]string{partitionName}))
}).WithPartitions(partitionName))
s.NoError(err)
})

Expand Down Expand Up @@ -109,7 +109,7 @@ func (s *ReadSuite) TestSearch() {
entity.FloatVector(lo.RepeatBy(128, func(_ int) float32 {
return rand.Float32()
})),
}).WithPartitions([]string{partitionName}))
}).WithPartitions(partitionName))
s.NoError(err)
})

Expand Down Expand Up @@ -145,7 +145,7 @@ func (s *ReadSuite) TestQuery() {
return &milvuspb.QueryResults{}, nil
}).Once()

_, err := s.client.Query(ctx, NewQueryOption(collectionName).WithPartitions([]string{partitionName}))
_, err := s.client.Query(ctx, NewQueryOption(collectionName).WithPartitions(partitionName))
s.NoError(err)
})
}
Expand Down
5 changes: 3 additions & 2 deletions client/results.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package client

import (
"reflect"
"runtime/debug"

"github.com/cockroachdb/errors"

Expand Down Expand Up @@ -64,7 +65,7 @@ func (sr *ResultSet) Unmarshal(receiver any) (err error) {
func (sr *ResultSet) fillPKEntry(receiver any) (err error) {
defer func() {
if x := recover(); x != nil {
err = errors.Newf("failed to unmarshal result set: %v", x)
err = errors.Newf("failed to unmarshal result set: %v, stack: %v", x, string(debug.Stack()))
}
}()
rr := reflect.ValueOf(receiver)
Expand Down Expand Up @@ -132,7 +133,7 @@ func (ds DataSet) Len() int {
func (ds DataSet) Unmarshal(receiver any) (err error) {
defer func() {
if x := recover(); x != nil {
err = errors.Newf("failed to unmarshal result set: %v", x)
err = errors.Newf("failed to unmarshal result set: %v, stack: %v", x, string(debug.Stack()))
}
}()
rr := reflect.ValueOf(receiver)
Expand Down
10 changes: 5 additions & 5 deletions tests/go_client/testcases/delete_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func TestDeleteComplexExprWithoutLoad(t *testing.T) {
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))

res, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(fmt.Sprintf("%s >= 0 ", common.DefaultInt64FieldName)).
WithOutputFields([]string{common.QueryCountFieldName}).WithConsistencyLevel(entity.ClStrong))
WithOutputFields(common.QueryCountFieldName).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
count, _ := res.Fields[0].GetAsInt64(0)
require.Equal(t, int64(common.DefaultNb-5), count)
Expand Down Expand Up @@ -324,7 +324,7 @@ func TestDeleteDefaultPartitionName(t *testing.T) {
common.CheckErr(t, errQuery, true)
require.Zero(t, queryRes.ResultCount)

queryRes, errQuery = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithPartitions([]string{common.DefaultPartition, parName}).
queryRes, errQuery = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithPartitions(common.DefaultPartition, parName).
WithConsistencyLevel(entity.ClStrong).WithFilter(expr))
common.CheckErr(t, errQuery, true)
require.Zero(t, queryRes.ResultCount)
Expand Down Expand Up @@ -362,7 +362,7 @@ func TestDeleteEmptyPartitionName(t *testing.T) {
common.CheckErr(t, errQuery, true)
require.Zero(t, queryRes.ResultCount)

queryRes, errQuery = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithPartitions([]string{common.DefaultPartition, parName}).
queryRes, errQuery = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithPartitions(common.DefaultPartition, parName).
WithConsistencyLevel(entity.ClStrong).WithFilter(expr))
common.CheckErr(t, errQuery, true)
require.Zero(t, queryRes.ResultCount)
Expand Down Expand Up @@ -406,7 +406,7 @@ func TestDeletePartitionName(t *testing.T) {
require.Equal(t, int64(0), del2.DeleteCount)

// query and verify
resQuery, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(exprQuery).WithOutputFields([]string{common.QueryCountFieldName}).
resQuery, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(exprQuery).WithOutputFields(common.QueryCountFieldName).
WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
count, _ := resQuery.Fields[0].GetAsInt64(0)
Expand All @@ -427,7 +427,7 @@ func TestDeletePartitionName(t *testing.T) {
require.Equal(t, common.DefaultNb*2-200-1500, queryRes.ResultCount)

queryRes, errQuery = mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter(exprQuery).WithConsistencyLevel(entity.ClStrong).
WithPartitions([]string{common.DefaultPartition, parName}))
WithPartitions(common.DefaultPartition, parName))
common.CheckErr(t, errQuery, true)
require.Equal(t, common.DefaultNb*2-200-1500, queryRes.ResultCount)
}
Expand Down
8 changes: 4 additions & 4 deletions tests/go_client/testcases/index_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -274,7 +274,7 @@ func TestCreateAutoIndexAllFields(t *testing.T) {
// load -> search and output all vector fields
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
searchRes, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).WithOutputFields([]string{"*"}))
searchRes, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithANNSField(common.DefaultFloatVecFieldName).WithOutputFields("*"))
common.CheckErr(t, err, true)
common.CheckOutputFields(t, expFields, searchRes[0].Fields)
}
Expand Down Expand Up @@ -483,7 +483,7 @@ func TestCreateSortedScalarIndex(t *testing.T) {

queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
expr := fmt.Sprintf("%s > 10", common.DefaultInt64FieldName)
searchRes, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithFilter(expr).WithOutputFields([]string{"*"}))
searchRes, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithFilter(expr).WithOutputFields("*"))
common.CheckErr(t, err, true)
expFields := make([]string, 0, len(schema.Fields))
for _, field := range schema.Fields {
Expand Down Expand Up @@ -526,7 +526,7 @@ func TestCreateInvertedScalarIndex(t *testing.T) {

queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
expr := fmt.Sprintf("%s > 10", common.DefaultInt64FieldName)
searchRes, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithFilter(expr).WithOutputFields([]string{"*"}))
searchRes, err := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithFilter(expr).WithOutputFields("*"))
common.CheckErr(t, err, true)
expFields := make([]string, 0, len(schema.Fields))
for _, field := range schema.Fields {
Expand Down Expand Up @@ -691,7 +691,7 @@ func TestCreateInvertedIndexArrayField(t *testing.T) {
// load -> search and output all fields
prepare.Load(ctx, t, mc, hp.NewLoadParams(schema.CollectionName))
queryVec := hp.GenSearchVectors(common.DefaultNq, common.DefaultDim, entity.FieldTypeFloatVector)
searchRes, errSearch := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithConsistencyLevel(entity.ClStrong).WithOutputFields([]string{"*"}))
searchRes, errSearch := mc.Search(ctx, client.NewSearchOption(schema.CollectionName, common.DefaultLimit, queryVec).WithConsistencyLevel(entity.ClStrong).WithOutputFields("*"))
common.CheckErr(t, errSearch, true)
var expFields []string
for _, field := range schema.Fields {
Expand Down
4 changes: 2 additions & 2 deletions tests/go_client/testcases/insert_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ func TestInsertDynamicExtraColumn(t *testing.T) {
common.CheckErr(t, err, true)

// query
res, _ := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 == 3000").WithOutputFields([]string{"*"}))
res, _ := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithFilter("int64 == 3000").WithOutputFields("*"))
common.CheckOutputFields(t, []string{common.DefaultFloatVecFieldName, common.DefaultInt64FieldName, common.DefaultDynamicFieldName}, res.Fields)
for _, c := range res.Fields {
log.Debug("data", zap.Any("data", c.FieldData()))
Expand Down Expand Up @@ -454,7 +454,7 @@ func TestInsertReadSparseEmptyVector(t *testing.T) {
require.EqualValues(t, 1, insertRes.InsertCount)

// query and check vector is empty
resQuery, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithLimit(10).WithOutputFields([]string{common.DefaultSparseVecFieldName}).WithConsistencyLevel(entity.ClStrong))
resQuery, err := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithLimit(10).WithOutputFields(common.DefaultSparseVecFieldName).WithConsistencyLevel(entity.ClStrong))
common.CheckErr(t, err, true)
require.Equal(t, 1, resQuery.ResultCount)
log.Info("sparseVec", zap.Any("data", resQuery.GetColumn(common.DefaultSparseVecFieldName).(*column.ColumnSparseFloatVector).Data()))
Expand Down
2 changes: 1 addition & 1 deletion tests/go_client/testcases/partition_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ func TestDropPartitionData(t *testing.T) {

// insert data into partition -> query check
prepare.InsertData(ctx, t, mc, hp.NewInsertParams(schema).TWithPartitionName(parName), hp.TNewDataOption())
res, errQ := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithConsistencyLevel(entity.ClStrong).WithPartitions([]string{parName}).WithOutputFields([]string{common.QueryCountFieldName}))
res, errQ := mc.Query(ctx, client.NewQueryOption(schema.CollectionName).WithConsistencyLevel(entity.ClStrong).WithPartitions(parName).WithOutputFields(common.QueryCountFieldName))
common.CheckErr(t, errQ, true)
count, _ := res.GetColumn(common.QueryCountFieldName).Get(0)
require.EqualValues(t, common.DefaultNb, count)
Expand Down
Loading

0 comments on commit e5948bd

Please sign in to comment.