diff --git a/internal/core/src/segcore/collection_c.cpp b/internal/core/src/segcore/collection_c.cpp index fde8119e086fe..39ebf29ba70a1 100644 --- a/internal/core/src/segcore/collection_c.cpp +++ b/internal/core/src/segcore/collection_c.cpp @@ -9,6 +9,7 @@ // is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express // or implied. See the License for the specific language governing permissions and limitations under the License +#include "common/type_c.h" #ifdef __linux__ #include #endif @@ -17,29 +18,41 @@ #include "segcore/collection_c.h" #include "segcore/Collection.h" -CCollection -NewCollection(const void* schema_proto_blob, const int64_t length) { - auto collection = std::make_unique( - schema_proto_blob, length); - return (void*)collection.release(); +CStatus +NewCollection(const void* schema_proto_blob, + const int64_t length, + CCollection* newCollection) { + try { + auto collection = std::make_unique( + schema_proto_blob, length); + *newCollection = collection.release(); + return milvus::SuccessCStatus(); + } catch (std::exception& e) { + return milvus::FailureCStatus(&e); + } } -void +CStatus SetIndexMeta(CCollection collection, const void* proto_blob, const int64_t length) { - auto col = (milvus::segcore::Collection*)collection; - col->parseIndexMeta(proto_blob, length); + try { + auto col = static_cast(collection); + col->parseIndexMeta(proto_blob, length); + return milvus::SuccessCStatus(); + } catch (std::exception& e) { + return milvus::FailureCStatus(&e); + } } void DeleteCollection(CCollection collection) { - auto col = (milvus::segcore::Collection*)collection; + auto col = static_cast(collection); delete col; } const char* GetCollectionName(CCollection collection) { - auto col = (milvus::segcore::Collection*)collection; + auto col = static_cast(collection); return strdup(col->get_collection_name().data()); } diff --git a/internal/core/src/segcore/collection_c.h b/internal/core/src/segcore/collection_c.h index b5c629754e23b..3cd20df10f6bb 100644 --- a/internal/core/src/segcore/collection_c.h +++ b/internal/core/src/segcore/collection_c.h @@ -12,6 +12,7 @@ #pragma once #include +#include "common/type_c.h" #ifdef __cplusplus extern "C" { @@ -19,10 +20,12 @@ extern "C" { typedef void* CCollection; -CCollection -NewCollection(const void* schema_proto_blob, const int64_t length); +CStatus +NewCollection(const void* schema_proto_blob, + const int64_t length, + CCollection* collection); -void +CStatus SetIndexMeta(CCollection collection, const void* proto_blob, const int64_t length); diff --git a/internal/querynodev2/delegator/delegator_data_test.go b/internal/querynodev2/delegator/delegator_data_test.go index 552d6a93a6930..695ac2bcc58b3 100644 --- a/internal/querynodev2/delegator/delegator_data_test.go +++ b/internal/querynodev2/delegator/delegator_data_test.go @@ -1518,9 +1518,10 @@ func (s *DelegatorDataSuite) TestLevel0Deletions() { s.Require().NoError(err) schema := mock_segcore.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true) - collection := segments.NewCollection(1, schema, nil, &querypb.LoadMetaInfo{ + collection, err := segments.NewCollection(1, schema, nil, &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) + s.NoError(err) l0, _ := segments.NewL0Segment(collection, segments.SegmentTypeSealed, 1, &querypb.SegmentLoadInfo{ CollectionID: 1, diff --git a/internal/querynodev2/local_worker_test.go b/internal/querynodev2/local_worker_test.go index af791f6ccbe07..47bcfdcb25d89 100644 --- a/internal/querynodev2/local_worker_test.go +++ b/internal/querynodev2/local_worker_test.go @@ -96,9 +96,10 @@ func (suite *LocalWorkerTestSuite) BeforeTest(suiteName, testName string) { suite.schema = mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) suite.indexMeta = mock_segcore.GenTestIndexMeta(suite.collectionID, suite.schema) - collection := segments.NewCollection(suite.collectionID, suite.schema, suite.indexMeta, &querypb.LoadMetaInfo{ + collection, err := segments.NewCollection(suite.collectionID, suite.schema, suite.indexMeta, &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) + suite.NoError(err) loadMata := &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, CollectionID: suite.collectionID, diff --git a/internal/querynodev2/pipeline/insert_node_test.go b/internal/querynodev2/pipeline/insert_node_test.go index dd98ea2fe8a6a..6985610026859 100644 --- a/internal/querynodev2/pipeline/insert_node_test.go +++ b/internal/querynodev2/pipeline/insert_node_test.go @@ -62,9 +62,10 @@ func (suite *InsertNodeSuite) TestBasic() { schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) in := suite.buildInsertNodeMsg(schema) - collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ + collection, err := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) + suite.NoError(err) collection.AddPartition(suite.partitionID) // init mock @@ -98,9 +99,10 @@ func (suite *InsertNodeSuite) TestDataTypeNotSupported() { schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) in := suite.buildInsertNodeMsg(schema) - collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ + collection, err := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) + suite.NoError(err) collection.AddPartition(suite.partitionID) // init mock diff --git a/internal/querynodev2/pipeline/pipeline_test.go b/internal/querynodev2/pipeline/pipeline_test.go index 4ceb237989e55..1c5e2ee039940 100644 --- a/internal/querynodev2/pipeline/pipeline_test.go +++ b/internal/querynodev2/pipeline/pipeline_test.go @@ -102,7 +102,7 @@ func (suite *PipelineTestSuite) TestBasic() { // init mock // mock collection manager schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, true) - collection := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ + collection, err := segments.NewCollection(suite.collectionID, schema, mock_segcore.GenTestIndexMeta(suite.collectionID, schema), &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, DbProperties: []*commonpb.KeyValuePair{ { @@ -111,6 +111,7 @@ func (suite *PipelineTestSuite) TestBasic() { }, }, }) + suite.Require().NoError(err) suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection) // mock mq factory diff --git a/internal/querynodev2/segments/collection.go b/internal/querynodev2/segments/collection.go index 207f15e2dc8a3..c10cbde9e542b 100644 --- a/internal/querynodev2/segments/collection.go +++ b/internal/querynodev2/segments/collection.go @@ -41,7 +41,7 @@ type CollectionManager interface { List() []int64 ListWithName() map[int64]string Get(collectionID int64) *Collection - PutOrRef(collectionID int64, schema *schemapb.CollectionSchema, meta *segcorepb.CollectionIndexMeta, loadMeta *querypb.LoadMetaInfo) + PutOrRef(collectionID int64, schema *schemapb.CollectionSchema, meta *segcorepb.CollectionIndexMeta, loadMeta *querypb.LoadMetaInfo) error Ref(collectionID int64, count uint32) bool // unref the collection, // returns true if the collection ref count goes 0, or the collection not exists, @@ -84,7 +84,7 @@ func (m *collectionManager) Get(collectionID int64) *Collection { return m.collections[collectionID] } -func (m *collectionManager) PutOrRef(collectionID int64, schema *schemapb.CollectionSchema, meta *segcorepb.CollectionIndexMeta, loadMeta *querypb.LoadMetaInfo) { +func (m *collectionManager) PutOrRef(collectionID int64, schema *schemapb.CollectionSchema, meta *segcorepb.CollectionIndexMeta, loadMeta *querypb.LoadMetaInfo) error { m.mut.Lock() defer m.mut.Unlock() @@ -92,14 +92,18 @@ func (m *collectionManager) PutOrRef(collectionID int64, schema *schemapb.Collec // the schema may be changed even the collection is loaded collection.schema.Store(schema) collection.Ref(1) - return + return nil } log.Info("put new collection", zap.Int64("collectionID", collectionID), zap.Any("schema", schema)) - collection := NewCollection(collectionID, schema, meta, loadMeta) + collection, err := NewCollection(collectionID, schema, meta, loadMeta) + if err != nil { + return err + } collection.Ref(1) m.collections[collectionID] = collection m.updateMetric() + return nil } func (m *collectionManager) updateMetric() { @@ -245,7 +249,7 @@ func (c *Collection) Unref(count uint32) uint32 { } // newCollection returns a new Collection -func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexMeta *segcorepb.CollectionIndexMeta, loadMetaInfo *querypb.LoadMetaInfo) *Collection { +func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexMeta *segcorepb.CollectionIndexMeta, loadMetaInfo *querypb.LoadMetaInfo) (*Collection, error) { /* CCollection NewCollection(const char* schema_proto_blob); @@ -281,7 +285,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM ccollection, err := segcore.CreateCCollection(req) if err != nil { log.Warn("create collection failed", zap.Error(err)) - return nil + return nil, err } coll := &Collection{ ccollection: ccollection, @@ -300,7 +304,7 @@ func NewCollection(collectionID int64, schema *schemapb.CollectionSchema, indexM } coll.schema.Store(schema) - return coll + return coll, nil } // Only for test diff --git a/internal/querynodev2/segments/manager_test.go b/internal/querynodev2/segments/manager_test.go index caf6f5a0ea887..5904f10bd8e74 100644 --- a/internal/querynodev2/segments/manager_test.go +++ b/internal/querynodev2/segments/manager_test.go @@ -52,11 +52,13 @@ func (s *ManagerSuite) SetupTest() { for i, id := range s.segmentIDs { schema := mock_segcore.GenTestCollectionSchema("manager-suite", schemapb.DataType_Int64, true) + collection, err := NewCollection(s.collectionIDs[i], schema, mock_segcore.GenTestIndexMeta(s.collectionIDs[i], schema), &querypb.LoadMetaInfo{ + LoadType: querypb.LoadType_LoadCollection, + }) + s.Require().NoError(err) segment, err := NewSegment( context.Background(), - NewCollection(s.collectionIDs[i], schema, mock_segcore.GenTestIndexMeta(s.collectionIDs[i], schema), &querypb.LoadMetaInfo{ - LoadType: querypb.LoadType_LoadCollection, - }), + collection, s.types[i], 0, &querypb.SegmentLoadInfo{ diff --git a/internal/querynodev2/segments/mock_collection_manager.go b/internal/querynodev2/segments/mock_collection_manager.go index 1e512316e3eb9..22e1aeb9df496 100644 --- a/internal/querynodev2/segments/mock_collection_manager.go +++ b/internal/querynodev2/segments/mock_collection_manager.go @@ -166,8 +166,21 @@ func (_c *MockCollectionManager_ListWithName_Call) RunAndReturn(run func() map[i } // PutOrRef provides a mock function with given fields: collectionID, schema, meta, loadMeta -func (_m *MockCollectionManager) PutOrRef(collectionID int64, schema *schemapb.CollectionSchema, meta *segcorepb.CollectionIndexMeta, loadMeta *querypb.LoadMetaInfo) { - _m.Called(collectionID, schema, meta, loadMeta) +func (_m *MockCollectionManager) PutOrRef(collectionID int64, schema *schemapb.CollectionSchema, meta *segcorepb.CollectionIndexMeta, loadMeta *querypb.LoadMetaInfo) error { + ret := _m.Called(collectionID, schema, meta, loadMeta) + + if len(ret) == 0 { + panic("no return value specified for PutOrRef") + } + + var r0 error + if rf, ok := ret.Get(0).(func(int64, *schemapb.CollectionSchema, *segcorepb.CollectionIndexMeta, *querypb.LoadMetaInfo) error); ok { + r0 = rf(collectionID, schema, meta, loadMeta) + } else { + r0 = ret.Error(0) + } + + return r0 } // MockCollectionManager_PutOrRef_Call is a *mock.Call that shadows Run/Return methods with type explicit version for method 'PutOrRef' @@ -191,12 +204,12 @@ func (_c *MockCollectionManager_PutOrRef_Call) Run(run func(collectionID int64, return _c } -func (_c *MockCollectionManager_PutOrRef_Call) Return() *MockCollectionManager_PutOrRef_Call { - _c.Call.Return() +func (_c *MockCollectionManager_PutOrRef_Call) Return(_a0 error) *MockCollectionManager_PutOrRef_Call { + _c.Call.Return(_a0) return _c } -func (_c *MockCollectionManager_PutOrRef_Call) RunAndReturn(run func(int64, *schemapb.CollectionSchema, *segcorepb.CollectionIndexMeta, *querypb.LoadMetaInfo)) *MockCollectionManager_PutOrRef_Call { +func (_c *MockCollectionManager_PutOrRef_Call) RunAndReturn(run func(int64, *schemapb.CollectionSchema, *segcorepb.CollectionIndexMeta, *querypb.LoadMetaInfo) error) *MockCollectionManager_PutOrRef_Call { _c.Call.Return(run) return _c } diff --git a/internal/querynodev2/segments/segment_loader_test.go b/internal/querynodev2/segments/segment_loader_test.go index b3b0db5df0acd..c233c8c74eb20 100644 --- a/internal/querynodev2/segments/segment_loader_test.go +++ b/internal/querynodev2/segments/segment_loader_test.go @@ -811,7 +811,8 @@ func (suite *SegmentLoaderDetailSuite) SetupTest() { PartitionIDs: []int64{suite.partitionID}, } - collection := NewCollection(suite.collectionID, schema, indexMeta, loadMeta) + collection, err := NewCollection(suite.collectionID, schema, indexMeta, loadMeta) + suite.Require().NoError(err) suite.collectionManager.EXPECT().Get(suite.collectionID).Return(collection).Maybe() } diff --git a/internal/querynodev2/server_test.go b/internal/querynodev2/server_test.go index bd857d538bfc1..1448a337cd119 100644 --- a/internal/querynodev2/server_test.go +++ b/internal/querynodev2/server_test.go @@ -221,9 +221,10 @@ func (suite *QueryNodeSuite) TestStop() { suite.node.manager = segments.NewManager() schema := mock_segcore.GenTestCollectionSchema("test_stop", schemapb.DataType_Int64, true) - collection := segments.NewCollection(1, schema, nil, &querypb.LoadMetaInfo{ + collection, err := segments.NewCollection(1, schema, nil, &querypb.LoadMetaInfo{ LoadType: querypb.LoadType_LoadCollection, }) + suite.Require().NoError(err) segment, err := segments.NewSegment( context.Background(), collection, diff --git a/internal/querynodev2/services.go b/internal/querynodev2/services.go index 3f5af3f7d0745..c0b2ee91d05f6 100644 --- a/internal/querynodev2/services.go +++ b/internal/querynodev2/services.go @@ -238,8 +238,12 @@ func (node *QueryNode) WatchDmChannels(ctx context.Context, req *querypb.WatchDm return merr.Success(), nil } - node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(), + err := node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(), node.composeIndexMeta(ctx, req.GetIndexInfoList(), req.Schema), req.GetLoadMeta()) + if err != nil { + log.Warn("failed to ref collection", zap.Error(err)) + return merr.Status(err), nil + } defer func() { if !merr.Ok(status) { node.manager.Collection.Unref(req.GetCollectionID(), 1) @@ -474,8 +478,12 @@ func (node *QueryNode) LoadSegments(ctx context.Context, req *querypb.LoadSegmen return merr.Success(), nil } - node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(), + err := node.manager.Collection.PutOrRef(req.GetCollectionID(), req.GetSchema(), node.composeIndexMeta(ctx, req.GetIndexInfoList(), req.GetSchema()), req.GetLoadMeta()) + if err != nil { + log.Warn("failed to ref collection", zap.Error(err)) + return merr.Status(err), nil + } defer node.manager.Collection.Unref(req.GetCollectionID(), 1) if req.GetLoadScope() == querypb.LoadScope_Delta { diff --git a/internal/querynodev2/services_test.go b/internal/querynodev2/services_test.go index a28367ec54895..40c021612e421 100644 --- a/internal/querynodev2/services_test.go +++ b/internal/querynodev2/services_test.go @@ -380,6 +380,69 @@ func (suite *ServiceSuite) TestWatchDmChannelsVarchar() { suite.Equal(commonpb.ErrorCode_Success, status.ErrorCode) } +func (suite *ServiceSuite) TestWatchDmChannels_BadIndexMeta() { + ctx := context.Background() + + // data + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_Int64, false) + deltaLogs, err := mock_segcore.SaveDeltaLog(suite.collectionID, + suite.partitionIDs[0], + suite.flushedSegmentIDs[0], + suite.node.chunkManager, + ) + suite.NoError(err) + + req := &querypb.WatchDmChannelsRequest{ + Base: &commonpb.MsgBase{ + MsgType: commonpb.MsgType_WatchDmChannels, + MsgID: rand.Int63(), + TargetID: suite.node.session.ServerID, + }, + NodeID: suite.node.session.ServerID, + CollectionID: suite.collectionID, + PartitionIDs: suite.partitionIDs, + Infos: []*datapb.VchannelInfo{ + { + CollectionID: suite.collectionID, + ChannelName: suite.vchannel, + SeekPosition: suite.position, + FlushedSegmentIds: suite.flushedSegmentIDs, + DroppedSegmentIds: suite.droppedSegmentIDs, + LevelZeroSegmentIds: suite.levelZeroSegmentIDs, + }, + }, + SegmentInfos: map[int64]*datapb.SegmentInfo{ + suite.levelZeroSegmentIDs[0]: { + ID: suite.levelZeroSegmentIDs[0], + CollectionID: suite.collectionID, + PartitionID: suite.partitionIDs[0], + InsertChannel: suite.vchannel, + Deltalogs: deltaLogs, + Level: datapb.SegmentLevel_L0, + }, + }, + Schema: schema, + LoadMeta: &querypb.LoadMetaInfo{ + LoadType: querypb.LoadType_LoadCollection, + CollectionID: suite.collectionID, + PartitionIDs: suite.partitionIDs, + MetricType: defaultMetricType, + }, + IndexInfoList: []*indexpb.IndexInfo{{ + IndexName: "bad_index", + FieldID: 100, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dup_key", Value: "val"}, + {Key: "dup_key", Value: "val"}, + }, + }}, + } + + // watchDmChannels + status, err := suite.node.WatchDmChannels(ctx, req) + suite.Error(merr.CheckRPCCall(status, err)) +} + func (suite *ServiceSuite) TestWatchDmChannels_Failed() { ctx := context.Background() @@ -658,6 +721,49 @@ func (suite *ServiceSuite) TestLoadSegments_VarChar() { } } +func (suite *ServiceSuite) TestLoadSegments_BadIndexMeta() { + ctx := context.Background() + suite.TestWatchDmChannelsVarchar() + // data + schema := mock_segcore.GenTestCollectionSchema(suite.collectionName, schemapb.DataType_VarChar, false) + loadMeta := &querypb.LoadMetaInfo{ + LoadType: querypb.LoadType_LoadCollection, + CollectionID: suite.collectionID, + PartitionIDs: suite.partitionIDs, + } + suite.node.manager.Collection = segments.NewCollectionManager() + // suite.node.manager.Collection.PutOrRef(suite.collectionID, schema, nil, loadMeta) + + infos := suite.genSegmentLoadInfos(schema, nil) + for _, info := range infos { + req := &querypb.LoadSegmentsRequest{ + Base: &commonpb.MsgBase{ + MsgID: rand.Int63(), + TargetID: suite.node.session.ServerID, + }, + CollectionID: suite.collectionID, + DstNodeID: suite.node.session.ServerID, + Infos: []*querypb.SegmentLoadInfo{info}, + Schema: schema, + DeltaPositions: []*msgpb.MsgPosition{{Timestamp: 20000}}, + NeedTransfer: true, + LoadMeta: loadMeta, + IndexInfoList: []*indexpb.IndexInfo{{ + IndexName: "bad_index", + FieldID: 100, + TypeParams: []*commonpb.KeyValuePair{ + {Key: "dup_key", Value: "val"}, + {Key: "dup_key", Value: "val"}, + }, + }}, + } + + // LoadSegment + status, err := suite.node.LoadSegments(ctx, req) + suite.Error(merr.CheckRPCCall(status, err)) + } +} + func (suite *ServiceSuite) TestLoadDeltaInt64() { ctx := context.Background() suite.TestLoadSegments_Int64() diff --git a/internal/util/segcore/collection.go b/internal/util/segcore/collection.go index 100e28e537f90..04a5a94f1cab6 100644 --- a/internal/util/segcore/collection.go +++ b/internal/util/segcore/collection.go @@ -38,9 +38,17 @@ func CreateCCollection(req *CreateCCollectionRequest) (*CCollection, error) { return nil, errors.New("marshal index meta failed") } } - ptr := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob))) + var ptr C.CCollection + status := C.NewCollection(unsafe.Pointer(&schemaBlob[0]), (C.int64_t)(len(schemaBlob)), &ptr) + if err := ConsumeCStatusIntoError(&status); err != nil { + return nil, err + } if indexMetaBlob != nil { - C.SetIndexMeta(ptr, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob))) + status = C.SetIndexMeta(ptr, unsafe.Pointer(&indexMetaBlob[0]), (C.int64_t)(len(indexMetaBlob))) + if err := ConsumeCStatusIntoError(&status); err != nil { + C.DeleteCollection(ptr) + return nil, err + } } return &CCollection{ collectionID: req.CollectionID,