diff --git a/client/data.go b/client/data.go index 0e61a71d..79e5a435 100644 --- a/client/data.go +++ b/client/data.go @@ -134,6 +134,7 @@ func (c *GrpcClient) handleSearchResult(schema *entity.Schema, outputFields []st for i := 0; i < int(results.GetNumQueries()); i++ { rc := int(results.GetTopks()[i]) // result entry count for current query entry := SearchResult{ + sch: schema, ResultCount: rc, Scores: results.GetScores()[offset : offset+rc], } diff --git a/client/results.go b/client/results.go index 418abb57..8aba95ac 100644 --- a/client/results.go +++ b/client/results.go @@ -1,6 +1,10 @@ package client import ( + "go/ast" + "reflect" + + "github.com/cockroachdb/errors" "github.com/milvus-io/milvus-sdk-go/v2/entity" ) @@ -9,6 +13,9 @@ import ( // Fields contains the data of `outputFieleds` specified or all columns if non // Scores is actually the distance between the vector current record contains and the search target vector type SearchResult struct { + // internal schema for unmarshaling + sch *entity.Schema + ResultCount int // the returning entry count GroupByValue entity.Column IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API @@ -44,6 +51,66 @@ func (sr *SearchResult) Slice(start, end int) *SearchResult { return result } +func (sr *SearchResult) Unmarshal(receiver interface{}) (err error) { + err = sr.Fields.Unmarshal(receiver) + if err != nil { + return err + } + return sr.fillPKEntry(receiver) +} + +func (sr *SearchResult) fillPKEntry(receiver interface{}) (err error) { + defer func() { + if x := recover(); x != nil { + err = errors.Newf("failed to unmarshal result set: %v", x) + } + }() + rr := reflect.ValueOf(receiver) + + if rr.Kind() == reflect.Ptr { + if rr.IsNil() && rr.CanAddr() { + rr.Set(reflect.New(rr.Type().Elem())) + } + rr = rr.Elem() + } + + rt := rr.Type() + rv := rr + + switch rt.Kind() { + case reflect.Slice: + pkField := sr.sch.PKField() + + et := rt.Elem() + for et.Kind() == reflect.Ptr { + et = et.Elem() + } + + candidates := parseCandidates(et) + candi, ok := candidates[pkField.Name] + if !ok { + // pk field not found in struct, skip + return nil + } + for i := 0; i < sr.IDs.Len(); i++ { + row := rv.Index(i) + for row.Kind() == reflect.Ptr { + row = row.Elem() + } + + val, err := sr.IDs.Get(i) + if err != nil { + return err + } + row.Field(candi).Set(reflect.ValueOf(val)) + } + rr.Set(rv) + default: + return errors.Newf("receiver need to be slice or array but get %v", rt.Kind()) + } + return nil +} + // ResultSet is an alias type for column slice. type ResultSet []entity.Column @@ -71,3 +138,87 @@ func (rs ResultSet) GetColumn(fieldName string) entity.Column { } return nil } + +func (rs ResultSet) Unmarshal(receiver interface{}) (err error) { + defer func() { + if x := recover(); x != nil { + err = errors.Newf("failed to unmarshal result set: %v", x) + } + }() + rr := reflect.ValueOf(receiver) + + if rr.Kind() == reflect.Ptr { + if rr.IsNil() && rr.CanAddr() { + rr.Set(reflect.New(rr.Type().Elem())) + } + rr = rr.Elem() + } + + rt := rr.Type() + rv := rr + + switch rt.Kind() { + // TODO maybe support Array and just fill data + // case reflect.Array: + case reflect.Slice: + et := rt.Elem() + if et.Kind() != reflect.Ptr { + return errors.Newf("receiver must be slice of pointers but get: %v", et.Kind()) + } + for et.Kind() == reflect.Ptr { + et = et.Elem() + } + for i := 0; i < rs.Len(); i++ { + data := reflect.New(et) + err := rs.fillData(data.Elem(), et, i) + if err != nil { + return err + } + rv = reflect.Append(rv, data) + } + rr.Set(rv) + default: + return errors.Newf("receiver need to be slice or array but get %v", rt.Kind()) + } + return nil +} + +func parseCandidates(dataType reflect.Type) map[string]int { + result := make(map[string]int) + for i := 0; i < dataType.NumField(); i++ { + f := dataType.Field(i) + // ignore anonymous field for now + if f.Anonymous || !ast.IsExported(f.Name) { + continue + } + + name := f.Name + tag := f.Tag.Get(entity.MilvusTag) + tagSettings := entity.ParseTagSetting(tag, entity.MilvusTagSep) + if tagName, has := tagSettings[entity.MilvusTagName]; has { + name = tagName + } + + result[name] = i + } + return result +} + +func (rs ResultSet) fillData(data reflect.Value, dataType reflect.Type, idx int) error { + m := parseCandidates(dataType) + for i := 0; i < len(rs); i++ { + name := rs[i].Name() + fidx, ok := m[name] + if !ok { + // maybe return error + continue + } + val, err := rs[i].Get(idx) + if err != nil { + return err + } + // TODO check datatype + data.Field(fidx).Set(reflect.ValueOf(val)) + } + return nil +} diff --git a/client/results_test.go b/client/results_test.go new file mode 100644 index 00000000..9181b70d --- /dev/null +++ b/client/results_test.go @@ -0,0 +1,109 @@ +package client + +import ( + "testing" + + "github.com/milvus-io/milvus-sdk-go/v2/entity" + "github.com/stretchr/testify/suite" +) + +type ResultSetSuite struct { + suite.Suite +} + +func (s *ResultSetSuite) TestResultsetUnmarshal() { + type MyData struct { + A int64 `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + type OtherData struct { + A string `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + + var ( + idData = []int64{1, 2, 3} + vectorData = [][]float32{ + {0.1, 0.2}, + {0.1, 0.2}, + {0.1, 0.2}, + } + ) + + rs := ResultSet([]entity.Column{ + entity.NewColumnInt64("id", idData), + entity.NewColumnFloatVector("vector", 2, vectorData), + }) + err := rs.Unmarshal([]MyData{}) + s.Error(err) + + receiver := []MyData{} + err = rs.Unmarshal(&receiver) + s.Error(err) + + var ptrReceiver []*MyData + err = rs.Unmarshal(&ptrReceiver) + s.NoError(err) + + for idx, row := range ptrReceiver { + s.Equal(row.A, idData[idx]) + s.Equal(row.V, vectorData[idx]) + } + + var otherReceiver []*OtherData + err = rs.Unmarshal(&otherReceiver) + s.Error(err) +} + +func (s *ResultSetSuite) TestSearchResultUnmarshal() { + type MyData struct { + A int64 `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + type OtherData struct { + A string `milvus:"name:id"` + V []float32 `milvus:"name:vector"` + } + + var ( + idData = []int64{1, 2, 3} + vectorData = [][]float32{ + {0.1, 0.2}, + {0.1, 0.2}, + {0.1, 0.2}, + } + ) + + sr := SearchResult{ + sch: entity.NewSchema(). + WithField(entity.NewField().WithName("id").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64)). + WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)), + IDs: entity.NewColumnInt64("id", idData), + Fields: ResultSet([]entity.Column{ + entity.NewColumnFloatVector("vector", 2, vectorData), + }), + } + err := sr.Unmarshal([]MyData{}) + s.Error(err) + + receiver := []MyData{} + err = sr.Unmarshal(&receiver) + s.Error(err) + + var ptrReceiver []*MyData + err = sr.Unmarshal(&ptrReceiver) + s.NoError(err) + + for idx, row := range ptrReceiver { + s.Equal(row.A, idData[idx]) + s.Equal(row.V, vectorData[idx]) + } + + var otherReceiver []*OtherData + err = sr.Unmarshal(&otherReceiver) + s.Error(err) +} + +func TestResults(t *testing.T) { + suite.Run(t, new(ResultSetSuite)) +}