diff --git a/go/internal/feast/onlinestore/cassandraonlinestore.go b/go/internal/feast/onlinestore/cassandraonlinestore.go
index 79c820cf97..ec988ea691 100644
--- a/go/internal/feast/onlinestore/cassandraonlinestore.go
+++ b/go/internal/feast/onlinestore/cassandraonlinestore.go
@@ -33,6 +33,9 @@ type CassandraOnlineStore struct {
 	session *gocql.Session
 
 	config *registry.RepoConfig
+
+	// The number of keys to include in a single CQL query for retrieval from the database
+	keyBatchSize int
 }
 
 type CassandraConfig struct {
@@ -44,6 +47,7 @@ type CassandraConfig struct {
 	loadBalancingPolicy     gocql.HostSelectionPolicy
 	connectionTimeoutMillis int64
 	requestTimeoutMillis    int64
+	keyBatchSize            int
 }
 
 func parseStringField(config map[string]any, fieldName string, defaultValue string) (string, error) {
@@ -156,6 +160,13 @@ func extractCassandraConfig(onlineStoreConfig map[string]any) (*CassandraConfig,
 	}
 	cassandraConfig.requestTimeoutMillis = int64(requestTimeoutMillis.(float64))
 
+	keyBatchSize, ok := onlineStoreConfig["key_batch_size"]
+	if !ok {
+		keyBatchSize = 10.0
+		log.Warn().Msg("key_batch_size not specified, defaulting to batches of size 10")
+	}
+	cassandraConfig.keyBatchSize = int(keyBatchSize.(float64))
+
 	return &cassandraConfig, nil
 }
 
@@ -176,8 +187,9 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online
 
 	store.clusterConfigs.PoolConfig.HostSelectionPolicy = cassandraConfig.loadBalancingPolicy
 
-	if cassandraConfig.username != "" && cassandraConfig.password != "" {
-		log.Warn().Msg("username/password not defined, will not be using authentication")
+	if cassandraConfig.username == "" || cassandraConfig.password == "" {
+		log.Warn().Msg("username and/or password not defined, will not be using authentication")
+	} else {
 		store.clusterConfigs.Authenticator = gocql.PasswordAuthenticator{
 			Username: cassandraConfig.username,
 			Password: cassandraConfig.password,
@@ -203,6 +215,16 @@ func NewCassandraOnlineStore(project string, config *registry.RepoConfig, online
 		return nil, fmt.Errorf("unable to connect to the ScyllaDB database")
 	}
 	store.session = createdSession
+
+	if cassandraConfig.keyBatchSize <= 0 {
+		return nil, fmt.Errorf("key_batch_size must be greater than zero")
+	} else if cassandraConfig.keyBatchSize == 1 {
+		log.Info().Msg("key batching is disabled")
+	} else {
+		log.Info().Msgf("key batching is enabled with a batch size of %d", cassandraConfig.keyBatchSize)
+	}
+	store.keyBatchSize = cassandraConfig.keyBatchSize
+
 	return &store, nil
 }
 
@@ -210,7 +232,23 @@ func (c *CassandraOnlineStore) getFqTableName(tableName string) string {
 	return fmt.Sprintf(`"%s"."%s_%s"`, c.clusterConfigs.Keyspace, c.project, tableName)
 }
 
-func (c *CassandraOnlineStore) getCQLStatement(tableName string, featureNames []string, nkeys int) string {
+// getSingleKeyCQLStatement builds the SELECT statement that reads the given features of one
+// entity key; the entity key stays a bind marker and tableName is inserted verbatim.
+func (c *CassandraOnlineStore) getSingleKeyCQLStatement(tableName string, featureNames []string) string {
+	// this prevents fetching unnecessary features
+	quotedFeatureNames := make([]string, len(featureNames))
+	for i, featureName := range featureNames {
+		quotedFeatureNames[i] = fmt.Sprintf(`'%s'`, featureName)
+	}
+
+	return fmt.Sprintf(
+		`SELECT "entity_key", "feature_name", "event_ts", "value" FROM %s WHERE "entity_key" = ? AND "feature_name" IN (%s)`,
+		tableName,
+		strings.Join(quotedFeatureNames, ","),
+	)
+}
+
+func (c *CassandraOnlineStore) getMultiKeyCQLStatement(tableName string, featureNames []string, nkeys int) string {
 	// this prevents fetching unnecessary features
 	quotedFeatureNames := make([]string, len(featureNames))
 	for i, featureName := range featureNames {
@@ -244,7 +282,151 @@ func (c *CassandraOnlineStore) buildCassandraEntityKeys(entityKeys []*types.Enti
 	}
 	return cassandraKeys, cassandraKeyToEntityIndex, nil
 }
-func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
+
+// UnbatchedKeysOnlineRead reads the requested features of a single feature view by running one
+// concurrent query per serialized entity key. Features without a stored value are returned as
+// nulls; the errors of all failed queries are joined into the returned error.
+func (c *CassandraOnlineStore) UnbatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
+	uniqueNames := make(map[string]int32)
+	for _, fvName := range featureViewNames {
+		uniqueNames[fvName] = 0
+	}
+	if len(uniqueNames) != 1 {
+		return nil, fmt.Errorf("rejecting OnlineRead as more than 1 feature view was tried to be read at once")
+	}
+
+	serializedEntityKeys, serializedEntityKeyToIndex, err := c.buildCassandraEntityKeys(entityKeys)
+
+	if err != nil {
+		return nil, fmt.Errorf("error when serializing entity keys for Cassandra")
+	}
+	results := make([][]FeatureData, len(entityKeys))
+	for i := range results {
+		results[i] = make([]FeatureData, len(featureNames))
+	}
+
+	featureNamesToIdx := make(map[string]int)
+	for idx, name := range featureNames {
+		featureNamesToIdx[name] = idx
+	}
+
+	featureViewName := featureViewNames[0]
+
+	// Prepare the query
+	tableName := c.getFqTableName(featureViewName)
+	cqlStatement := c.getSingleKeyCQLStatement(tableName, featureNames)
+
+	var waitGroup sync.WaitGroup
+	waitGroup.Add(len(serializedEntityKeys))
+
+	errorsChannel := make(chan error, len(serializedEntityKeys))
+	for _, serializedEntityKey := range serializedEntityKeys {
+		go func(serEntityKey any) {
+			defer waitGroup.Done()
+
+			iter := c.session.Query(cqlStatement, serEntityKey).WithContext(ctx).Iter()
+
+			rowIdx := serializedEntityKeyToIndex[serEntityKey.(string)]
+
+			// fill the row with nulls if not found
+			if iter.NumRows() == 0 {
+				// a failed query also reports zero rows; do not mistake it for a missing key
+				if err := iter.Close(); err != nil {
+					errorsChannel <- errors.New("failed to query features: " + err.Error())
+					return
+				}
+				for _, featName := range featureNames {
+					results[rowIdx][featureNamesToIdx[featName]] = FeatureData{
+						Reference: serving.FeatureReferenceV2{
+							FeatureViewName: featureViewName,
+							FeatureName:     featName,
+						},
+						Value: types.Value{
+							Val: &types.Value_NullVal{
+								NullVal: types.Null_NULL,
+							},
+						},
+					}
+				}
+				return
+			}
+
+			scanner := iter.Scanner()
+			var entityKey string
+			var featureName string
+			var eventTs time.Time
+			var valueStr []byte
+			var deserializedValue types.Value
+			rowFeatures := make(map[string]FeatureData)
+			for scanner.Next() {
+				err := scanner.Scan(&entityKey, &featureName, &eventTs, &valueStr)
+				if err != nil {
+					errorsChannel <- errors.New("could not read row in query for (entity key, feature name, value, event ts)")
+					return
+				}
+				if err := proto.Unmarshal(valueStr, &deserializedValue); err != nil {
+					errorsChannel <- errors.New("error converting parsed Cassandra Value to types.Value")
+					return
+				}
+
+				if deserializedValue.Val != nil {
+					// Convert the value to a FeatureData struct
+					rowFeatures[featureName] = FeatureData{
+						Reference: serving.FeatureReferenceV2{
+							FeatureViewName: featureViewName,
+							FeatureName:     featureName,
+						},
+						Timestamp: timestamppb.Timestamp{Seconds: eventTs.Unix(), Nanos: int32(eventTs.Nanosecond())},
+						Value: types.Value{
+							Val: deserializedValue.Val,
+						},
+					}
+				}
+			}
+
+			if err := scanner.Err(); err != nil {
+				errorsChannel <- errors.New("failed to scan features: " + err.Error())
+				return
+			}
+
+			for _, featName := range featureNames {
+				featureData, ok := rowFeatures[featName]
+				if !ok {
+					featureData = FeatureData{
+						Reference: serving.FeatureReferenceV2{
+							FeatureViewName: featureViewName,
+							FeatureName:     featName,
+						},
+						Value: types.Value{
+							Val: &types.Value_NullVal{
+								NullVal: types.Null_NULL,
+							},
+						},
+					}
+				}
+				results[rowIdx][featureNamesToIdx[featName]] = featureData
+			}
+		}(serializedEntityKey)
+	}
+
+	// wait until all concurrent single-key queries are done
+	waitGroup.Wait()
+	close(errorsChannel)
+
+	var collectedErrors []error
+	for err := range errorsChannel {
+		if err != nil {
+			collectedErrors = append(collectedErrors, err)
+		}
+	}
+	if len(collectedErrors) > 0 {
+		return nil, errors.Join(collectedErrors...)
+	}
+
+	return results, nil
+}
+
+func (c *CassandraOnlineStore) BatchedKeysOnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
 	uniqueNames := make(map[string]int32)
 	for _, fvName := range featureViewNames {
 		uniqueNames[fvName] = 0
@@ -273,9 +455,9 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
 	// Prepare the query
 	tableName := c.getFqTableName(featureViewName)
 
-	// do batching
+	// Key batching
 	nKeys := len(serializedEntityKeys)
-	batchSize := 20
+	batchSize := c.keyBatchSize
 	nBatches := int(math.Ceil(float64(nKeys) / float64(batchSize)))
 
 	batches := make([][]any, nBatches)
@@ -293,7 +475,6 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
 	waitGroup.Add(nBatches)
 
 	errorsChannel := make(chan error, nBatches)
-
 	var prevBatchLength int
 	var cqlStatement string
 	for _, batch := range batches {
@@ -302,7 +483,7 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
 
 			// this caches the previous batch query if it had the same number of keys
 			if len(keyBatch) != prevBatchLength {
-				cqlStatement = c.getCQLStatement(tableName, featureNames, len(keyBatch))
+				cqlStatement = c.getMultiKeyCQLStatement(tableName, featureNames, len(keyBatch))
 			}
 
 			iter := c.session.Query(cqlStatement, keyBatch...).WithContext(ctx).Iter()
@@ -327,7 +508,6 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
 				}
 
 				if deserializedValue.Val != nil {
-					// Convert the value to a FeatureData struct
 					if batchFeatures[entityKey] == nil {
 						batchFeatures[entityKey] = make(map[string]FeatureData)
 					}
@@ -388,6 +568,15 @@ func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*typ
 	return results, nil
 }
 
+// OnlineRead reads features from the online store, querying one entity key per statement when
+// keyBatchSize is 1 and delegating to the batched read otherwise.
+func (c *CassandraOnlineStore) OnlineRead(ctx context.Context, entityKeys []*types.EntityKey, featureViewNames []string, featureNames []string) ([][]FeatureData, error) {
+	if c.keyBatchSize == 1 {
+		return c.UnbatchedKeysOnlineRead(ctx, entityKeys, featureViewNames, featureNames)
+	}
+	return c.BatchedKeysOnlineRead(ctx, entityKeys, featureViewNames, featureNames)
+}
+
 func (c *CassandraOnlineStore) Destruct() {
 	c.session.Close()
 }
diff --git a/go/internal/feast/onlinestore/cassandraonlinestore_test.go b/go/internal/feast/onlinestore/cassandraonlinestore_test.go
index 67a9eea548..19c53506b3 100644
--- a/go/internal/feast/onlinestore/cassandraonlinestore_test.go
+++ b/go/internal/feast/onlinestore/cassandraonlinestore_test.go
@@ -60,17 +60,28 @@ func TestGetFqTableName(t *testing.T) {
 	assert.Equal(t, `"scylladb"."dummy_project_dummy_fv"`, fqTableName)
 }
 
-func TestGetCQLStatement(t *testing.T) {
+func TestGetSingleKeyCQLStatement(t *testing.T) {
 	store := CassandraOnlineStore{}
 	fqTableName := `"scylladb"."dummy_project_dummy_fv"`
 
-	cqlStatement := store.getCQLStatement(fqTableName, []string{"feat1", "feat2"})
+	cqlStatement := store.getSingleKeyCQLStatement(fqTableName, []string{"feat1", "feat2"})
 	assert.Equal(t,
 		`SELECT "entity_key", "feature_name", "event_ts", "value" FROM "scylladb"."dummy_project_dummy_fv" WHERE "entity_key" = ? AND "feature_name" IN ('feat1','feat2')`,
 		cqlStatement,
 	)
 }
 
+func TestGetMultiKeyCQLStatement(t *testing.T) {
+	store := CassandraOnlineStore{}
+	fqTableName := `"scylladb"."dummy_project_dummy_fv"`
+
+	cqlStatement := store.getMultiKeyCQLStatement(fqTableName, []string{"feat1", "feat2"}, 5)
+	assert.Equal(t,
+		`SELECT "entity_key", "feature_name", "event_ts", "value" FROM "scylladb"."dummy_project_dummy_fv" WHERE "entity_key" IN (?,?,?,?,?) AND "feature_name" IN ('feat1','feat2')`,
+		cqlStatement,
+	)
+}
+
 func TestOnlineRead_RejectsDifferentFeatureViewsInSameRead(t *testing.T) {
 	store := CassandraOnlineStore{}
 	_, err := store.OnlineRead(context.TODO(), nil, []string{"fv1", "fv2"}, []string{"feat1", "feat2"})