Skip to content

Commit

Permalink
feat: Support unmarshal resultset into orm receiver
Browse files Browse the repository at this point in the history
Related to milvus-io#800

Signed-off-by: Congqi Xia <[email protected]>
  • Loading branch information
congqixia committed Sep 27, 2024
1 parent 6eef344 commit 864e395
Show file tree
Hide file tree
Showing 3 changed files with 261 additions and 0 deletions.
1 change: 1 addition & 0 deletions client/data.go
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ func (c *GrpcClient) handleSearchResult(schema *entity.Schema, outputFields []st
for i := 0; i < int(results.GetNumQueries()); i++ {
rc := int(results.GetTopks()[i]) // result entry count for current query
entry := SearchResult{
sch: schema,
ResultCount: rc,
Scores: results.GetScores()[offset : offset+rc],
}
Expand Down
151 changes: 151 additions & 0 deletions client/results.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,10 @@
package client

import (
"go/ast"
"reflect"

"github.com/cockroachdb/errors"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
)

Expand All @@ -9,6 +13,9 @@ import (
// Fields contains the data of `outputFieleds` specified or all columns if non
// Scores is actually the distance between the vector current record contains and the search target vector
type SearchResult struct {
// internal schema for unmarshaling
sch *entity.Schema

ResultCount int // the returning entry count
GroupByValue entity.Column
IDs entity.Column // auto generated id, can be mapped to the columns from `Insert` API
Expand Down Expand Up @@ -44,6 +51,66 @@ func (sr *SearchResult) Slice(start, end int) *SearchResult {
return result
}

func (sr *SearchResult) Unmarshal(receiver interface{}) (err error) {
err = sr.Fields.Unmarshal(receiver)
if err != nil {
return err
}
return sr.fillPKEntry(receiver)
}

func (sr *SearchResult) fillPKEntry(receiver interface{}) (err error) {
defer func() {
if x := recover(); x != nil {
err = errors.Newf("failed to unmarshal result set: %v", x)
}
}()
rr := reflect.ValueOf(receiver)

if rr.Kind() == reflect.Ptr {
if rr.IsNil() && rr.CanAddr() {
rr.Set(reflect.New(rr.Type().Elem()))
}
rr = rr.Elem()
}

rt := rr.Type()
rv := rr

switch rt.Kind() {
case reflect.Slice:
pkField := sr.sch.PKField()

et := rt.Elem()
for et.Kind() == reflect.Ptr {
et = et.Elem()
}

candidates := parseCandidates(et)
candi, ok := candidates[pkField.Name]
if !ok {
// pk field not found in struct, skip
return nil
}
for i := 0; i < sr.IDs.Len(); i++ {
row := rv.Index(i)
for row.Kind() == reflect.Ptr {
row = row.Elem()
}

val, err := sr.IDs.Get(i)
if err != nil {
return err
}
row.Field(candi).Set(reflect.ValueOf(val))
}
rr.Set(rv)
default:
return errors.Newf("receiver need to be slice or array but get %v", rt.Kind())
}
return nil
}

// ResultSet is an alias type for column slice.
type ResultSet []entity.Column

Expand Down Expand Up @@ -71,3 +138,87 @@ func (rs ResultSet) GetColumn(fieldName string) entity.Column {
}
return nil
}

func (rs ResultSet) Unmarshal(receiver interface{}) (err error) {
defer func() {
if x := recover(); x != nil {
err = errors.Newf("failed to unmarshal result set: %v", x)
}
}()
rr := reflect.ValueOf(receiver)

if rr.Kind() == reflect.Ptr {
if rr.IsNil() && rr.CanAddr() {
rr.Set(reflect.New(rr.Type().Elem()))
}
rr = rr.Elem()
}

rt := rr.Type()
rv := rr

switch rt.Kind() {
// TODO maybe support Array and just fill data
// case reflect.Array:
case reflect.Slice:
et := rt.Elem()
if et.Kind() != reflect.Ptr {
return errors.Newf("receiver must be slice of pointers but get: %v", et.Kind())
}
for et.Kind() == reflect.Ptr {
et = et.Elem()
}
for i := 0; i < rs.Len(); i++ {
data := reflect.New(et)
err := rs.fillData(data.Elem(), et, i)
if err != nil {
return err
}
rv = reflect.Append(rv, data)
}
rr.Set(rv)
default:
return errors.Newf("receiver need to be slice or array but get %v", rt.Kind())
}
return nil
}

func parseCandidates(dataType reflect.Type) map[string]int {
result := make(map[string]int)
for i := 0; i < dataType.NumField(); i++ {
f := dataType.Field(i)
// ignore anonymous field for now
if f.Anonymous || !ast.IsExported(f.Name) {
continue
}

name := f.Name
tag := f.Tag.Get(entity.MilvusTag)
tagSettings := entity.ParseTagSetting(tag, entity.MilvusTagSep)
if tagName, has := tagSettings[entity.MilvusTagName]; has {
name = tagName
}

result[name] = i
}
return result
}

func (rs ResultSet) fillData(data reflect.Value, dataType reflect.Type, idx int) error {
m := parseCandidates(dataType)
for i := 0; i < len(rs); i++ {
name := rs[i].Name()
fidx, ok := m[name]
if !ok {
// maybe return error
continue
}
val, err := rs[i].Get(idx)
if err != nil {
return err
}
// TODO check datatype
data.Field(fidx).Set(reflect.ValueOf(val))
}
return nil
}
109 changes: 109 additions & 0 deletions client/results_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package client

import (
"testing"

"github.com/milvus-io/milvus-sdk-go/v2/entity"
"github.com/stretchr/testify/suite"
)

type ResultSetSuite struct {
suite.Suite
}

func (s *ResultSetSuite) TestResultsetUnmarshal() {
type MyData struct {
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
}
type OtherData struct {
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
}

var (
idData = []int64{1, 2, 3}
vectorData = [][]float32{
{0.1, 0.2},
{0.1, 0.2},
{0.1, 0.2},
}
)

rs := ResultSet([]entity.Column{
entity.NewColumnInt64("id", idData),
entity.NewColumnFloatVector("vector", 2, vectorData),
})
err := rs.Unmarshal([]MyData{})
s.Error(err)

receiver := []MyData{}
err = rs.Unmarshal(&receiver)
s.Error(err)

var ptrReceiver []*MyData
err = rs.Unmarshal(&ptrReceiver)
s.NoError(err)

for idx, row := range ptrReceiver {
s.Equal(row.A, idData[idx])
s.Equal(row.V, vectorData[idx])
}

var otherReceiver []*OtherData
err = rs.Unmarshal(&otherReceiver)
s.Error(err)
}

func (s *ResultSetSuite) TestSearchResultUnmarshal() {
type MyData struct {
A int64 `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
}
type OtherData struct {
A string `milvus:"name:id"`
V []float32 `milvus:"name:vector"`
}

var (
idData = []int64{1, 2, 3}
vectorData = [][]float32{
{0.1, 0.2},
{0.1, 0.2},
{0.1, 0.2},
}
)

sr := SearchResult{
sch: entity.NewSchema().
WithField(entity.NewField().WithName("id").WithIsPrimaryKey(true).WithDataType(entity.FieldTypeInt64)).
WithField(entity.NewField().WithName("vector").WithDim(2).WithDataType(entity.FieldTypeFloatVector)),
IDs: entity.NewColumnInt64("id", idData),
Fields: ResultSet([]entity.Column{
entity.NewColumnFloatVector("vector", 2, vectorData),
}),
}
err := sr.Unmarshal([]MyData{})
s.Error(err)

receiver := []MyData{}
err = sr.Unmarshal(&receiver)
s.Error(err)

var ptrReceiver []*MyData
err = sr.Unmarshal(&ptrReceiver)
s.NoError(err)

for idx, row := range ptrReceiver {
s.Equal(row.A, idData[idx])
s.Equal(row.V, vectorData[idx])
}

var otherReceiver []*OtherData
err = sr.Unmarshal(&otherReceiver)
s.Error(err)
}

func TestResults(t *testing.T) {
suite.Run(t, new(ResultSetSuite))
}

0 comments on commit 864e395

Please sign in to comment.