Skip to content

Commit

Permalink
Add changes to support INT32 entities and request sources
Browse files Browse the repository at this point in the history
  • Loading branch information
msistla96 committed May 16, 2024
1 parent 4996dca commit 3384006
Show file tree
Hide file tree
Showing 7 changed files with 341 additions and 28 deletions.
46 changes: 44 additions & 2 deletions go/internal/feast/featurestore.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"errors"
"fmt"
"github.com/rs/zerolog/log"
"os"
"strings"

Expand Down Expand Up @@ -85,7 +86,7 @@ func (fs *FeatureStore) GetOnlineFeatures(
joinKeyToEntityValues map[string]*prototypes.RepeatedValue,
requestData map[string]*prototypes.RepeatedValue,
fullFeatureNames bool) ([]*onlineserving.FeatureVector, error) {
fvs, odFvs, err := fs.listAllViews()
fvs, odFvs, err := fs.ListAllViews()
if err != nil {
return nil, err
}
Expand All @@ -104,6 +105,7 @@ func (fs *FeatureStore) GetOnlineFeatures(
requestedFeatureViews, requestedOnDemandFeatureViews, err =
onlineserving.GetFeatureViewsToUseByFeatureRefs(featureRefs, fvs, odFvs)
}
log.Info().Msgf("requestedOnDemandFeatureViews %v", requestedOnDemandFeatureViews)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -230,7 +232,7 @@ func (fs *FeatureStore) GetFeatureService(name string) (*model.FeatureService, e
return fs.registry.GetFeatureService(fs.config.Project, name)
}

func (fs *FeatureStore) listAllViews() (map[string]*model.FeatureView, map[string]*model.OnDemandFeatureView, error) {
func (fs *FeatureStore) ListAllViews() (map[string]*model.FeatureView, map[string]*model.OnDemandFeatureView, error) {
fvs := make(map[string]*model.FeatureView)
odFvs := make(map[string]*model.OnDemandFeatureView)

Expand Down Expand Up @@ -291,6 +293,38 @@ func (fs *FeatureStore) ListEntities(hideDummyEntity bool) ([]*model.Entity, err
return entities, nil
}

func (fs *FeatureStore) GetEntityKeyTypeMaps() (map[string]prototypes.ValueType_Enum, error) {

entityKeyTypeMap := make(map[string]prototypes.ValueType_Enum, 0)
allEntities, _ := fs.registry.ListEntities(fs.config.Project)
if allEntities == nil || len(allEntities) <= 0 {
return nil, fmt.Errorf("No entities found for project %s", fs.config.Project)
}
for _, entity := range allEntities {
entityKeyTypeMap[entity.JoinKey] = entity.ValueType
}
return entityKeyTypeMap, nil
}
func (fs *FeatureStore) GetRequestSources(fVList []string) (map[string]prototypes.ValueType_Enum, error) {

requestSources := make(map[string]prototypes.ValueType_Enum, 0)
if fVList != nil && len(fVList) > 0 {
for _, fvName := range fVList {
odfv, err := fs.GetOnDemandFeatureView(fvName)
if err == nil {
schema := odfv.GetRequestDataSchema()
for name, dtype := range schema {
requestSources[name] = dtype
}
}
}
}
if len(requestSources) > 0 {
return requestSources, nil
}
return nil, fmt.Errorf("Request sources for feature views %v not found", fVList)
}

func (fs *FeatureStore) ListOnDemandFeatureViews() ([]*model.OnDemandFeatureView, error) {
return fs.registry.ListOnDemandFeatureViews(fs.config.Project)
}
Expand All @@ -311,6 +345,14 @@ func (fs *FeatureStore) GetFeatureView(featureViewName string, hideDummyEntity b
return fv, nil
}

func (fs *FeatureStore) GetOnDemandFeatureView(featureViewName string) (*model.OnDemandFeatureView, error) {
fv, err := fs.registry.GetOnDemandFeatureView(fs.config.Project, featureViewName)
if err != nil {
return nil, err
}
return fv, nil
}

func (fs *FeatureStore) readFromOnlineStore(ctx context.Context, entityRows []*prototypes.EntityKey,
requestedFeatureViewNames []string,
requestedFeatureNames []string,
Expand Down
165 changes: 164 additions & 1 deletion go/internal/feast/featurestore_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@ package feast

import (
"context"
"github.com/feast-dev/feast/go/protos/feast/core"
"path/filepath"
"runtime"
"testing"
Expand All @@ -10,7 +11,7 @@ import (

"github.com/feast-dev/feast/go/internal/feast/onlinestore"
"github.com/feast-dev/feast/go/internal/feast/registry"
"github.com/feast-dev/feast/go/protos/feast/types"
types "github.com/feast-dev/feast/go/protos/feast/types"
)

// Return absolute path to the test_repo registry regardless of the working directory
Expand Down Expand Up @@ -70,3 +71,165 @@ func TestGetOnlineFeaturesRedis(t *testing.T) {
assert.Nil(t, err)
assert.Len(t, response, 4) // 3 Features + 1 entity = 4 columns (feature vectors) in response
}

func getRepoConfig() (config registry.RepoConfig) {
return registry.RepoConfig{
Project: "feature_repo",
Registry: getRegistryPath(),
Provider: "local",
OnlineStore: map[string]interface{}{
"type": "redis",
"connection_string": "localhost:6379",
},
}
}
func TestGetEntityKeyTypeMapsReturnsExpectedResult(t *testing.T) {

config := getRepoConfig()
fs, _ := NewFeatureStore(&config, nil)
entity1 := &core.Entity{
Spec: &core.EntitySpecV2{
Name: "entity1",
JoinKey: "joinKey1",
ValueType: types.ValueType_INT64,
},
}
entity2 := &core.Entity{
Spec: &core.EntitySpecV2{
Name: "entity2",
JoinKey: "joinKey2",
ValueType: types.ValueType_INT32,
},
}
cachedEntities := make(map[string]map[string]*core.Entity)
cachedEntities["feature_repo"] = make(map[string]*core.Entity)
cachedEntities["feature_repo"]["entity1"] = entity1
cachedEntities["feature_repo"]["entity2"] = entity2

fs.registry.CachedEntities = cachedEntities

entityKeyTypeMap, err := fs.GetEntityKeyTypeMaps()

assert.Nil(t, err)
assert.Equal(t, 2, len(entityKeyTypeMap))
assert.Equal(t, types.ValueType_INT64, entityKeyTypeMap["joinKey1"])
assert.Equal(t, types.ValueType_INT32, entityKeyTypeMap["joinKey2"])
}

func TestGetEntityKeyTypeMapsReturnsErrorWhenNoEntities(t *testing.T) {

config := getRepoConfig()
fs, _ := NewFeatureStore(&config, nil)

cachedEntities := make(map[string]map[string]*core.Entity)
fs.registry.CachedEntities = cachedEntities

entityKeyTypeMap, err := fs.GetEntityKeyTypeMaps()

assert.NotNil(t, err)
assert.Equal(t, 0, len(entityKeyTypeMap))
}
func TestGetRequestSourcesWithValidFeatures(t *testing.T) {
config := getRepoConfig()
fs, _ := NewFeatureStore(&config, nil)
fVList := []string{"odfv1", "fv1"}

odfv := &core.OnDemandFeatureView{
Spec: &core.OnDemandFeatureViewSpec{
Name: "odfv1",
Project: "feature_repo",
Sources: map[string]*core.OnDemandSource{
"odfv1": {
Source: &core.OnDemandSource_RequestDataSource{
RequestDataSource: &core.DataSource{
Name: "request_source_1",
Type: core.DataSource_REQUEST_SOURCE,
Options: &core.DataSource_RequestDataOptions_{
RequestDataOptions: &core.DataSource_RequestDataOptions{
DeprecatedSchema: map[string]types.ValueType_Enum{
"feature1": types.ValueType_INT64,
},
Schema: []*core.FeatureSpecV2{
{
Name: "feat1",
ValueType: types.ValueType_INT64,
},
},
},
},
},
},
},
},
},
}

cachedOnDemandFVs := make(map[string]map[string]*core.OnDemandFeatureView)
cachedOnDemandFVs["feature_repo"] = make(map[string]*core.OnDemandFeatureView)
cachedOnDemandFVs["feature_repo"]["odfv1"] = odfv
fs.registry.CachedOnDemandFeatureViews = cachedOnDemandFVs
requestSources, err := fs.GetRequestSources(fVList)

assert.Nil(t, err)
assert.Equal(t, 1, len(requestSources))
assert.Equal(t, types.ValueType_INT64.Enum(), requestSources["feat1"].Enum())
}

func TestGetRequestSourcesWithInvalidFeatures(t *testing.T) {

config := getRepoConfig()
fs, _ := NewFeatureStore(&config, nil)
fVList := []string{"invalidFV", "fv1"}

odfv := &core.OnDemandFeatureView{
Spec: &core.OnDemandFeatureViewSpec{
Name: "odfv1",
Project: "feature_repo",
Sources: map[string]*core.OnDemandSource{
"odfv1": {
Source: &core.OnDemandSource_RequestDataSource{
RequestDataSource: &core.DataSource{
Name: "request_source_1",
Type: core.DataSource_REQUEST_SOURCE,
Options: &core.DataSource_RequestDataOptions_{
RequestDataOptions: &core.DataSource_RequestDataOptions{
DeprecatedSchema: map[string]types.ValueType_Enum{
"feature1": types.ValueType_INT64,
},
Schema: []*core.FeatureSpecV2{
{
Name: "feature1",
ValueType: types.ValueType_INT64,
},
},
},
},
},
},
},
},
},
}

cachedOnDemandFVs := make(map[string]map[string]*core.OnDemandFeatureView)
cachedOnDemandFVs["feature_repo"] = make(map[string]*core.OnDemandFeatureView)
cachedOnDemandFVs["feature_repo"]["odfv1"] = odfv
fs.registry.CachedOnDemandFeatureViews = cachedOnDemandFVs

requestSources, err := fs.GetRequestSources(fVList)

assert.NotNil(t, err)
assert.Equal(t, 0, len(requestSources))
}

func TestGetRequestSourcesWithNoFeatures(t *testing.T) {

config := getRepoConfig()
fs, _ := NewFeatureStore(&config, nil)
var fvList []string

requestSources, err := fs.GetRequestSources(fvList)

assert.NotNil(t, err)
assert.Equal(t, 0, len(requestSources))
}
11 changes: 7 additions & 4 deletions go/internal/feast/model/entity.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,16 +2,19 @@ package model

import (
"github.com/feast-dev/feast/go/protos/feast/core"
"github.com/feast-dev/feast/go/protos/feast/types"
)

type Entity struct {
Name string
JoinKey string
Name string
JoinKey string
ValueType types.ValueType_Enum
}

func NewEntityFromProto(proto *core.Entity) *Entity {
return &Entity{
Name: proto.Spec.Name,
JoinKey: proto.Spec.JoinKey,
Name: proto.Spec.Name,
JoinKey: proto.Spec.JoinKey,
ValueType: proto.Spec.ValueType,
}
}
28 changes: 14 additions & 14 deletions go/internal/feast/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,10 @@ type Registry struct {
project string
registryStore RegistryStore
cachedFeatureServices map[string]map[string]*core.FeatureService
cachedEntities map[string]map[string]*core.Entity
CachedEntities map[string]map[string]*core.Entity
cachedFeatureViews map[string]map[string]*core.FeatureView
cachedStreamFeatureViews map[string]map[string]*core.StreamFeatureView
cachedOnDemandFeatureViews map[string]map[string]*core.OnDemandFeatureView
CachedOnDemandFeatureViews map[string]map[string]*core.OnDemandFeatureView
cachedRegistry *core.Registry
cachedRegistryProtoLastUpdated time.Time
cachedRegistryProtoTtl time.Duration
Expand Down Expand Up @@ -114,10 +114,10 @@ func (r *Registry) load(registry *core.Registry) {
defer r.mu.Unlock()
r.cachedRegistry = registry
r.cachedFeatureServices = make(map[string]map[string]*core.FeatureService)
r.cachedEntities = make(map[string]map[string]*core.Entity)
r.CachedEntities = make(map[string]map[string]*core.Entity)
r.cachedFeatureViews = make(map[string]map[string]*core.FeatureView)
r.cachedStreamFeatureViews = make(map[string]map[string]*core.StreamFeatureView)
r.cachedOnDemandFeatureViews = make(map[string]map[string]*core.OnDemandFeatureView)
r.CachedOnDemandFeatureViews = make(map[string]map[string]*core.OnDemandFeatureView)
r.loadEntities(registry)
r.loadFeatureServices(registry)
r.loadFeatureViews(registry)
Expand All @@ -130,10 +130,10 @@ func (r *Registry) loadEntities(registry *core.Registry) {
entities := registry.Entities
for _, entity := range entities {
// fmt.Println("Entity load: ", entity.Spec.Name)
if _, ok := r.cachedEntities[r.project]; !ok {
r.cachedEntities[r.project] = make(map[string]*core.Entity)
if _, ok := r.CachedEntities[r.project]; !ok {
r.CachedEntities[r.project] = make(map[string]*core.Entity)
}
r.cachedEntities[r.project][entity.Spec.Name] = entity
r.CachedEntities[r.project][entity.Spec.Name] = entity
}
}

Expand Down Expand Up @@ -174,10 +174,10 @@ func (r *Registry) loadOnDemandFeatureViews(registry *core.Registry) {
onDemandFeatureViews := registry.OnDemandFeatureViews
for _, onDemandFeatureView := range onDemandFeatureViews {
// fmt.Println("onDemandFeatureView load: ", onDemandFeatureView.Spec.Name)
if _, ok := r.cachedOnDemandFeatureViews[r.project]; !ok {
r.cachedOnDemandFeatureViews[r.project] = make(map[string]*core.OnDemandFeatureView)
if _, ok := r.CachedOnDemandFeatureViews[r.project]; !ok {
r.CachedOnDemandFeatureViews[r.project] = make(map[string]*core.OnDemandFeatureView)
}
r.cachedOnDemandFeatureViews[r.project][onDemandFeatureView.Spec.Name] = onDemandFeatureView
r.CachedOnDemandFeatureViews[r.project][onDemandFeatureView.Spec.Name] = onDemandFeatureView
}
}

Expand All @@ -189,7 +189,7 @@ func (r *Registry) loadOnDemandFeatureViews(registry *core.Registry) {
func (r *Registry) ListEntities(project string) ([]*model.Entity, error) {
r.mu.RLock()
defer r.mu.RUnlock()
if cachedEntities, ok := r.cachedEntities[project]; !ok {
if cachedEntities, ok := r.CachedEntities[project]; !ok {
return []*model.Entity{}, nil
} else {
entities := make([]*model.Entity, len(cachedEntities))
Expand Down Expand Up @@ -273,7 +273,7 @@ func (r *Registry) ListFeatureServices(project string) ([]*model.FeatureService,
func (r *Registry) ListOnDemandFeatureViews(project string) ([]*model.OnDemandFeatureView, error) {
r.mu.RLock()
defer r.mu.RUnlock()
if cachedOnDemandFeatureViews, ok := r.cachedOnDemandFeatureViews[project]; !ok {
if cachedOnDemandFeatureViews, ok := r.CachedOnDemandFeatureViews[project]; !ok {
return []*model.OnDemandFeatureView{}, nil
} else {
onDemandFeatureViews := make([]*model.OnDemandFeatureView, len(cachedOnDemandFeatureViews))
Expand All @@ -289,7 +289,7 @@ func (r *Registry) ListOnDemandFeatureViews(project string) ([]*model.OnDemandFe
func (r *Registry) GetEntity(project, entityName string) (*model.Entity, error) {
r.mu.RLock()
defer r.mu.RUnlock()
if cachedEntities, ok := r.cachedEntities[project]; !ok {
if cachedEntities, ok := r.CachedEntities[project]; !ok {
return nil, fmt.Errorf("no cached entities found for project %s", project)
} else {
if entity, ok := cachedEntities[entityName]; !ok {
Expand Down Expand Up @@ -345,7 +345,7 @@ func (r *Registry) GetFeatureService(project, featureServiceName string) (*model
func (r *Registry) GetOnDemandFeatureView(project, onDemandFeatureViewName string) (*model.OnDemandFeatureView, error) {
r.mu.RLock()
defer r.mu.RUnlock()
if cachedOnDemandFeatureViews, ok := r.cachedOnDemandFeatureViews[project]; !ok {
if cachedOnDemandFeatureViews, ok := r.CachedOnDemandFeatureViews[project]; !ok {
return nil, fmt.Errorf("no cached on demand feature views found for project %s", project)
} else {
if onDemandFeatureViewProto, ok := cachedOnDemandFeatureViews[onDemandFeatureViewName]; !ok {
Expand Down
Loading

0 comments on commit 3384006

Please sign in to comment.