Skip to content

Commit

Permalink
Support BM25
Browse files Browse the repository at this point in the history
Signed-off-by: wayblink <[email protected]>
  • Loading branch information
wayblink committed Nov 29, 2024
1 parent c682188 commit d1cbf15
Show file tree
Hide file tree
Showing 14 changed files with 1,140 additions and 296 deletions.
25 changes: 22 additions & 3 deletions core/backup_impl_create_backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,6 +235,13 @@ func (b *BackupContext) backupCollectionPrepare(ctx context.Context, backupInfo
log.Error("fail in DescribeCollection", zap.Error(err))
return err
}
// todo temporary solution, migrate to sdk V2
completeCollectionV2, err := b.getMilvusClient().DescribeCollectionV2(b.ctx, collection.db, collection.collectionName)
if err != nil {
log.Error("fail in DescribeCollection v2", zap.Error(err))
return err
}

fields := make([]*backuppb.FieldSchema, 0)
for _, field := range completeCollection.Schema.Fields {
fieldBak := &backuppb.FieldSchema{
Expand All @@ -261,22 +268,33 @@ func (b *BackupContext) backupCollectionPrepare(ctx context.Context, backupInfo
}
fields = append(fields, fieldBak)
}

functions := make([]*backuppb.FunctionSchema, 0)
for _, function := range completeCollectionV2.Schema.Functions {
functionBak := &backuppb.FunctionSchema{
Name: function.Name,
Description: function.Description,
Type: backuppb.FunctionType(function.Type),
InputFieldNames: function.InputFieldNames,
OutputFieldNames: function.OutputFieldNames,
Params: utils.MapToKVPair(function.Params),
}
functions = append(functions, functionBak)
}
schema := &backuppb.CollectionSchema{
Name: completeCollection.Schema.CollectionName,
Description: completeCollection.Schema.Description,
AutoID: completeCollection.Schema.AutoID,
Fields: fields,
EnableDynamicField: completeCollection.Schema.EnableDynamicField,
Functions: functions,
}

indexInfos := make([]*backuppb.IndexInfo, 0)
indexDict := make(map[string]*backuppb.IndexInfo, 0)
log.Info("try to get index",
zap.String("collection_name", completeCollection.Name))
for _, field := range completeCollection.Schema.Fields {
//if field.DataType != entity.FieldTypeBinaryVector && field.DataType != entity.FieldTypeFloatVector {
// continue
//}
fieldIndex, err := b.getMilvusClient().DescribeIndex(b.ctx, collection.db, completeCollection.Name, field.Name)
if err != nil {
if strings.Contains(err.Error(), "index not found") ||
Expand Down Expand Up @@ -989,6 +1007,7 @@ func (b *BackupContext) fillSegmentBackupInfo(ctx context.Context, segmentBackup
}

deltaLogPath := fmt.Sprintf("%s%s/%v/%v/%v/", rootPath, "delta_log", segmentBackupInfo.GetCollectionId(), segmentBackupInfo.GetPartitionId(), segmentBackupInfo.GetSegmentId())
log.Debug("deltaPath", zap.String("bucket", b.milvusBucketName), zap.String("deltaPath", deltaLogPath))
deltaFieldsLogDir, _, _ := b.getMilvusStorageClient().ListWithPrefix(ctx, b.milvusBucketName, deltaLogPath, false)
deltaLogs := make([]*backuppb.FieldBinlog, 0)
for _, deltaFieldLogDir := range deltaFieldsLogDir {
Expand Down
45 changes: 22 additions & 23 deletions core/backup_impl_restore_backup.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package core
import (
"context"
"fmt"

"path"
"strings"
"time"
Expand All @@ -13,6 +14,8 @@ import (
"github.com/milvus-io/milvus-proto/go-api/v2/schemapb"
gomilvus "github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"
entityV2 "github.com/milvus-io/milvus/client/v2/entity"
indexV2 "github.com/milvus-io/milvus/client/v2/index"
"github.com/samber/lo"
"go.uber.org/atomic"
"go.uber.org/zap"
Expand Down Expand Up @@ -419,22 +422,22 @@ func (b *BackupContext) executeRestoreCollectionTask(ctx context.Context, backup
zap.String("backupBucketName", backupBucketName),
zap.String("backupPath", backupPath))
// create collection
fields := make([]*entity.Field, 0)
fields := make([]*entityV2.Field, 0)
hasPartitionKey := false
for _, field := range task.GetCollBackup().GetSchema().GetFields() {
fieldRestore := &entity.Field{
fieldRestore := &entityV2.Field{
ID: field.GetFieldID(),
Name: field.GetName(),
PrimaryKey: field.GetIsPrimaryKey(),
AutoID: field.GetAutoID(),
Description: field.GetDescription(),
DataType: entity.FieldType(field.GetDataType()),
DataType: entityV2.FieldType(field.GetDataType()),
TypeParams: utils.KvPairsMap(field.GetTypeParams()),
IndexParams: utils.KvPairsMap(field.GetIndexParams()),
IsDynamic: field.GetIsDynamic(),
IsPartitionKey: field.GetIsPartitionKey(),
Nullable: field.GetNullable(),
ElementType: entity.FieldType(field.GetElementType()),
ElementType: entityV2.FieldType(field.GetElementType()),
}
if field.DefaultValueProto != "" {
defaultValue := &schemapb.ValueField{}
Expand All @@ -453,7 +456,7 @@ func (b *BackupContext) executeRestoreCollectionTask(ctx context.Context, backup

log.Info("collection schema", zap.Any("fields", fields))

collectionSchema := &entity.Schema{
collectionSchema := &entityV2.Schema{
CollectionName: targetCollectionName,
Description: task.GetCollBackup().GetSchema().GetDescription(),
AutoID: task.GetCollBackup().GetSchema().GetAutoID(),
Expand Down Expand Up @@ -491,22 +494,12 @@ func (b *BackupContext) executeRestoreCollectionTask(ctx context.Context, backup
log.Info("overwrite shardNum by request parameter", zap.Int32("oldShardNum", task.GetCollBackup().GetShardsNum()), zap.Int32("newShardNum", shardNum))

}

if hasPartitionKey {
partitionNum := len(task.GetCollBackup().GetPartitionBackups())
return b.getMilvusClient().CreateCollection(
ctx,
targetDBName,
collectionSchema,
shardNum,
gomilvus.WithConsistencyLevel(entity.ConsistencyLevel(task.GetCollBackup().GetConsistencyLevel())),
gomilvus.WithPartitionNum(int64(partitionNum)))
return b.getMilvusClient().CreateCollectionV2(ctx, targetDBName, collectionSchema, shardNum, entityV2.ConsistencyLevel(task.GetCollBackup().GetConsistencyLevel()), int64(partitionNum))
}
return b.getMilvusClient().CreateCollection(
ctx,
targetDBName,
collectionSchema,
shardNum,
gomilvus.WithConsistencyLevel(entity.ConsistencyLevel(task.GetCollBackup().GetConsistencyLevel())))
return b.getMilvusClient().CreateCollectionV2(ctx, targetDBName, collectionSchema, shardNum, entityV2.ConsistencyLevel(task.GetCollBackup().GetConsistencyLevel()), 0)
}, retry.Attempts(10), retry.Sleep(1*time.Second))
if err != nil {
errorMsg := fmt.Sprintf("fail to create collection, targetCollectionName: %s err: %s", targetCollectionName, err)
Expand Down Expand Up @@ -560,7 +553,7 @@ func (b *BackupContext) executeRestoreCollectionTask(ctx context.Context, backup
}
indexes := task.GetCollBackup().GetIndexInfos()
for _, index := range indexes {
var idx entity.Index
var idx indexV2.Index
log.Info("source index",
zap.String("indexName", index.GetIndexName()),
zap.String("indexType", index.GetIndexType()),
Expand All @@ -571,7 +564,10 @@ func (b *BackupContext) executeRestoreCollectionTask(ctx context.Context, backup
// auto index only support index_type and metric_type in params
params["index_type"] = "AUTOINDEX"
params["metric_type"] = index.GetParams()["metric_type"]
idx = entity.NewGenericIndex(index.GetIndexName(), entity.AUTOINDEX, params)
// v1
//idx = entity.NewGenericIndex(index.GetIndexName(), entity.AUTOINDEX, params)
// v2
idx = indexV2.NewGenericIndex(index.GetIndexName(), params)
} else {
log.Info("not auto index")
indexType := index.GetIndexType()
Expand All @@ -582,9 +578,12 @@ func (b *BackupContext) executeRestoreCollectionTask(ctx context.Context, backup
if params["index_type"] == "marisa-trie" {
params["index_type"] = "Trie"
}
idx = entity.NewGenericIndex(index.GetIndexName(), entity.IndexType(indexType), index.GetParams())
// v1
//idx = entityV2.NewGenericIndex(index.GetIndexName(), entity.IndexType(indexType), index.GetParams())
// v2
idx = indexV2.NewGenericIndex(index.GetIndexName(), params)
}
err := b.getMilvusClient().CreateIndex(ctx, targetDBName, targetCollectionName, index.GetFieldName(), idx, true)
err := b.getMilvusClient().CreateIndexV2(ctx, targetDBName, targetCollectionName, index.GetFieldName(), idx, true)
if err != nil {
log.Warn("Fail to restore index", zap.Error(err))
return task, err
Expand Down Expand Up @@ -663,14 +662,14 @@ func (b *BackupContext) executeRestoreCollectionTask(ctx context.Context, backup
return task, err
}
if !exist {
log.Info("create partition", zap.String("partitionName", partitionBackup.GetPartitionName()))
err = retry.Do(ctx, func() error {
return b.getMilvusClient().CreatePartition(ctx, targetDBName, targetCollectionName, partitionBackup.GetPartitionName())
}, retry.Attempts(10), retry.Sleep(1*time.Second))
if err != nil {
log.Error("fail to create partition", zap.Error(err))
return task, err
}
log.Info("create partition", zap.String("partitionName", partitionBackup.GetPartitionName()))
}

type restoreGroup struct {
Expand Down
26 changes: 0 additions & 26 deletions core/milvus_sdk_wrapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,6 @@ import (
gomilvus "github.com/milvus-io/milvus-sdk-go/v2/client"
"github.com/milvus-io/milvus-sdk-go/v2/entity"

entityV2 "github.com/milvus-io/milvus/client/v2/entity"
milvusClientV2 "github.com/milvus-io/milvus/client/v2/milvusclient"

"github.com/zilliztech/milvus-backup/core/paramtable"
Expand Down Expand Up @@ -153,31 +152,6 @@ func (m *MilvusClient) ListCollections(ctx context.Context, db string) ([]*entit
return m.client.ListCollections(ctx)
}

// ListCollectionsV2 selects database db on the V2 client, lists the
// collections in it, and describes every listed collection in turn.
//
// The described collections are returned in listing order. On the first
// failure (selecting the database, listing, or describing) the error is
// returned together with a nil slice. The client mutex is held for the whole
// call.
func (m *MilvusClient) ListCollectionsV2(ctx context.Context, db string) ([]*entityV2.Collection, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if useErr := m.milvusClientV2.UsingDatabase(ctx, milvusClientV2.NewUsingDatabaseOption(db)); useErr != nil {
		return nil, useErr
	}

	listed, listErr := m.milvusClientV2.ListCollections(ctx, milvusClientV2.NewListCollectionOption())
	if listErr != nil {
		return nil, listErr
	}

	described := make([]*entityV2.Collection, 0)
	for i := range listed {
		info, descErr := m.milvusClientV2.DescribeCollection(ctx, milvusClientV2.NewDescribeCollectionOption(listed[i]))
		if descErr != nil {
			return nil, descErr
		}
		described = append(described, info)
	}

	return described, nil
}

func (m *MilvusClient) HasCollection(ctx context.Context, db, collName string) (bool, error) {
m.mu.Lock()
defer m.mu.Unlock()
Expand Down
84 changes: 84 additions & 0 deletions core/milvus_sdk_wrapper_v2.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
package core

import (
"context"
"time"

entityV2 "github.com/milvus-io/milvus/client/v2/entity"
indexV2 "github.com/milvus-io/milvus/client/v2/index"
milvusClientV2 "github.com/milvus-io/milvus/client/v2/milvusclient"
"github.com/zilliztech/milvus-backup/internal/util/retry"
)

// ListCollectionsV2 switches the V2 client to database db, lists the
// collections there, and describes each of them.
//
// It returns one described collection per listed entry, in listing order.
// If selecting the database, listing, or any describe call fails, that error
// is returned with a nil slice. The client mutex is held for the entire call,
// including every per-collection describe request.
func (m *MilvusClient) ListCollectionsV2(ctx context.Context, db string) ([]*entityV2.Collection, error) {
	m.mu.Lock()
	defer m.mu.Unlock()
	err := m.milvusClientV2.UsingDatabase(ctx, milvusClientV2.NewUsingDatabaseOption(db))
	if err != nil {
		return nil, err
	}

	collections, err := m.milvusClientV2.ListCollections(ctx, milvusClientV2.NewListCollectionOption())
	if err != nil {
		return nil, err
	}

	// The final length is known up front, so reserve it once instead of
	// letting append grow the backing array repeatedly.
	collectionEntities := make([]*entityV2.Collection, 0, len(collections))
	for _, collection := range collections {
		coll, err := m.milvusClientV2.DescribeCollection(ctx, milvusClientV2.NewDescribeCollectionOption(collection))
		if err != nil {
			return nil, err
		}
		collectionEntities = append(collectionEntities, coll)
	}

	return collectionEntities, nil
}

// DescribeCollectionV2 selects database db on the V2 client and returns the
// description of collection collName from it.
//
// A failure to select the database is returned with a nil collection;
// otherwise the result and error of the describe call are passed through
// unchanged. The client mutex is held for the duration of the call.
func (m *MilvusClient) DescribeCollectionV2(ctx context.Context, db, collName string) (*entityV2.Collection, error) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if useErr := m.milvusClientV2.UsingDatabase(ctx, milvusClientV2.NewUsingDatabaseOption(db)); useErr != nil {
		return nil, useErr
	}

	describeOpt := milvusClientV2.NewDescribeCollectionOption(collName)
	return m.milvusClientV2.DescribeCollection(ctx, describeOpt)
}

// CreateCollectionV2 selects database db on the V2 client and creates a
// collection from schema with the given shard count and consistency level.
//
// partitionNum is applied to the create option only when it is non-zero.
// The create call is retried (up to 10 attempts, 2s sleep between them) so
// that a transient rejection such as rate limiting does not fail the
// operation outright.
//
// NOTE(review): the client mutex is held across all retry attempts, so other
// wrapper calls on this client wait until the retries finish.
func (m *MilvusClient) CreateCollectionV2(ctx context.Context, db string, schema *entityV2.Schema, shardsNum int32, cl entityV2.ConsistencyLevel, partitionNum int64) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if useErr := m.milvusClientV2.UsingDatabase(ctx, milvusClientV2.NewUsingDatabaseOption(db)); useErr != nil {
		return useErr
	}

	// Each attempt builds a fresh schema copy and a fresh option before
	// issuing the create request.
	attempt := func() error {
		schemaCopy := &entityV2.Schema{
			CollectionName:     schema.CollectionName,
			Description:        schema.Description,
			AutoID:             schema.AutoID,
			Fields:             schema.Fields,
			EnableDynamicField: schema.EnableDynamicField,
			Functions:          schema.Functions,
		}
		createOpt := milvusClientV2.NewCreateCollectionOption(schema.CollectionName, schemaCopy).WithShardNum(shardsNum).WithConsistencyLevel(cl)
		if partitionNum == 0 {
			return m.milvusClientV2.CreateCollection(ctx, createOpt)
		}
		createOpt = createOpt.WithPartitionNum(partitionNum)
		return m.milvusClientV2.CreateCollection(ctx, createOpt)
	}

	return retry.Do(ctx, attempt, retry.Sleep(2*time.Second), retry.Attempts(10))
}

// CreateIndexV2 selects database db on the V2 client and creates index idx on
// field fieldName of collection collName.
//
// When async is true the function returns as soon as the create-index request
// has been accepted. When async is false it additionally waits on the task
// returned by the client until the index build completes or ctx is done.
// Previously the flag was ignored and the task discarded, so a synchronous
// caller never actually waited.
//
// The client mutex is held for the whole call, including the wait, so the
// database selected above cannot be switched by another wrapper call while
// the task is being awaited.
func (m *MilvusClient) CreateIndexV2(ctx context.Context, db, collName string, fieldName string, idx indexV2.Index, async bool) error {
	m.mu.Lock()
	defer m.mu.Unlock()
	if err := m.milvusClientV2.UsingDatabase(ctx, milvusClientV2.NewUsingDatabaseOption(db)); err != nil {
		return err
	}
	task, err := m.milvusClientV2.CreateIndex(ctx, milvusClientV2.NewCreateIndexOption(collName, fieldName, idx))
	if err != nil {
		return err
	}
	if async {
		return nil
	}
	// Synchronous mode: block until the server reports the index as built.
	return task.Await(ctx)
}
20 changes: 20 additions & 0 deletions core/proto/backup.proto
Original file line number Diff line number Diff line change
Expand Up @@ -464,6 +464,26 @@ message CollectionSchema {
bool autoID = 3; // deprecated later, keep compatible with c++ part now
repeated FieldSchema fields = 4;
bool enable_dynamic_field = 5; // mark whether this table has the dynamic field function enabled.
repeated KeyValuePair properties = 6;
repeated FunctionSchema functions = 7;
}

// FunctionType enumerates the kinds of function a FunctionSchema can describe.
// NOTE(review): the backup code fills this by casting the SDK's function type
// value directly (backuppb.FunctionType(function.Type)), so these numbers are
// presumably expected to match the Milvus function type numbering — confirm
// before adding or renumbering values.
enum FunctionType{
Unknown =0;
BM25 =1;
TextEmbedding =2;
}

// FunctionSchema is the backup representation of one function attached to a
// collection schema (stored in CollectionSchema.functions).
// NOTE(review): the backup path visible in this change populates name,
// description, type, input_field_names, output_field_names and params only;
// id, input_field_ids and output_field_ids appear to be left at their zero
// values — confirm whether they are filled elsewhere.
message FunctionSchema {
// Function name.
string name = 1;
// Numeric function identifier.
int64 id =2;
// Free-form description of the function.
string description = 3;
// Kind of function, see FunctionType.
FunctionType type = 4;
// Names of the fields the function reads from.
repeated string input_field_names = 5;
// IDs of the input fields.
repeated int64 input_field_ids = 6;
// Names of the fields the function writes to.
repeated string output_field_names = 7;
// IDs of the output fields.
repeated int64 output_field_ids = 8;
// Function parameters as key/value pairs.
repeated KeyValuePair params = 9;
}

message CheckRequest {
Expand Down
Loading

0 comments on commit d1cbf15

Please sign in to comment.