diff --git a/db/db.go b/db/db.go index 961a9ee..a71460d 100644 --- a/db/db.go +++ b/db/db.go @@ -1,6 +1,9 @@ package db -import "context" +import ( + "context" + "reflect" +) // Database is responsible for inserting data into the database type Database interface { @@ -9,6 +12,9 @@ type Database interface { // insertList inserts a list of data into the database InsertList(context.Context, InserListParams) ([]interface{}, error) + + // GenCustomType generates a non-zero value for custom types + GenCustomType(reflect.Type) (interface{}, bool) } // InsertParams is a struct that holds the parameters for the Insert method diff --git a/db/gormf/gormf.go b/db/gormf/gormf.go index 24de0ef..38d42db 100644 --- a/db/gormf/gormf.go +++ b/db/gormf/gormf.go @@ -2,8 +2,11 @@ package gormf import ( "context" + "reflect" + "time" "github.com/eyo-chen/gofacto/db" + "gorm.io/datatypes" "gorm.io/gorm" ) @@ -38,3 +41,29 @@ func (c *config) InsertList(ctx context.Context, params db.InserListParams) ([]i return params.Values, nil } + +func (c *config) GenCustomType(t reflect.Type) (interface{}, bool) { + // Check if the type is a pointer + if t.Kind() == reflect.Ptr { + v, ok := c.GenCustomType(t.Elem()) + if !ok { + return nil, false + } + + ptr := reflect.New(reflect.TypeOf(v)) + ptr.Elem().Set(reflect.ValueOf(v)) + return ptr.Interface(), true + } + + // Handle specific types + switch t.String() { + case jsonType: + return datatypes.JSON([]byte(`{"test": "test"}`)), true + case dateType: + return datatypes.Date(time.Now()), true + case timeType: + return datatypes.NewTime(1, 2, 3, 0), true + default: + return nil, false + } +} diff --git a/db/gormf/gormf_test.go b/db/gormf/gormf_test.go index 0e7e676..4410030 100644 --- a/db/gormf/gormf_test.go +++ b/db/gormf/gormf_test.go @@ -9,6 +9,7 @@ import ( "testing" "time" + "gorm.io/datatypes" "gorm.io/driver/mysql" "gorm.io/gorm" @@ -36,22 +37,27 @@ type Author struct { WebsiteURL *string FanCount *int64 ProfilePicture []byte + BornTime datatypes.Time + BornTime1 *datatypes.Time } type Book struct { - ID int64 - AuthorID int64 `gofacto:"Author,authors"` - Title string - ISBN *string - PublicationDate *time.Time - Genre *string - Price *float64 - PageCount *int32 - Description *string - InStock bool - CoverImage []byte - CreatedAt time.Time - UpdatedAt time.Time + ID int64 + AuthorID int64 `gofacto:"Author,authors"` + Title string + ISBN *string + PublicationDate datatypes.Date + PublicationDate1 *datatypes.Date + Genre *string + Price *float64 + PageCount *int32 + Description *string + InStock bool + CoverImage []byte + Data datatypes.JSON + Data1 *datatypes.JSON + CreatedAt time.Time + UpdatedAt time.Time } type testingSuite struct { diff --git a/db/gormf/schema.sql b/db/gormf/schema.sql index d74140e..e9e2bed 100644 --- a/db/gormf/schema.sql +++ b/db/gormf/schema.sql @@ -7,12 +7,14 @@ CREATE TABLE IF NOT EXISTS authors ( email VARCHAR(100) UNIQUE, biography TEXT, is_active BOOLEAN DEFAULT TRUE, - rating DECIMAL(3,2), + rating DECIMAL(4,2), books_written INT UNSIGNED, last_publication_time TIMESTAMP DEFAULT CURRENT_TIMESTAMP, website_url VARCHAR(255), fan_count BIGINT UNSIGNED, - profile_picture BLOB + profile_picture BLOB, + born_time TIME, + born_time1 TIME ); CREATE TABLE IF NOT EXISTS books ( @@ -21,12 +23,15 @@ CREATE TABLE IF NOT EXISTS books ( title VARCHAR(255) NOT NULL, isbn CHAR(13) UNIQUE, publication_date DATE, + publication_date1 DATE, genre ENUM('Fiction', 'Non-Fiction', 'Science', 'History', 'Biography', 'Other'), price DECIMAL(10,2), page_count SMALLINT UNSIGNED, description TEXT, in_stock BOOLEAN DEFAULT TRUE, cover_image MEDIUMBLOB, + data JSON, + data1 JSON, created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP, updated_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP ON UPDATE CURRENT_TIMESTAMP, FOREIGN KEY (author_id) REFERENCES authors(id) ON DELETE SET NULL diff --git a/db/gormf/type.go b/db/gormf/type.go new file mode 100644 index 0000000..deb4f48 --- /dev/null +++ b/db/gormf/type.go @@ -0,0 +1,12 @@ +package gormf + +var ( + // json is the string representation of datatypes.JSON + jsonType = "datatypes.JSON" + + // date is the string representation of datatypes.Date + dateType = "datatypes.Date" + + // time is the string representation of datatypes.Time + timeType = "datatypes.Time" +) diff --git a/db/mongof/mongof.go b/db/mongof/mongof.go index 5b51cb5..8f584c4 100644 --- a/db/mongof/mongof.go +++ b/db/mongof/mongof.go @@ -48,6 +48,10 @@ func (c *config) InsertList(ctx context.Context, params db.InserListParams) ([]i return params.Values, nil } +func (c *config) GenCustomType(t reflect.Type) (interface{}, bool) { + return nil, false +} + // setIDField sets the ID field of the value to the given ID func setIDField(val interface{}, id primitive.ObjectID) { v := reflect.ValueOf(val).Elem().FieldByName("ID") diff --git a/go.mod b/go.mod index b3850c3..3a43aff 100644 --- a/go.mod +++ b/go.mod @@ -49,6 +49,7 @@ require ( golang.org/x/text v0.16.0 // indirect golang.org/x/tools v0.23.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect + gorm.io/datatypes v1.2.1 // indirect gorm.io/driver/mysql v1.5.7 // indirect gorm.io/gorm v1.25.11 // indirect ) diff --git a/go.sum b/go.sum index 06001a0..3e1c4f7 100644 --- a/go.sum +++ b/go.sum @@ -165,6 +165,8 @@ gopkg.in/yaml.v2 v2.4.0 h1:D8xgwECY7CYvx+Y2n4sBz93Jn9JRvxdiyyo8CTfuKaY= gopkg.in/yaml.v2 v2.4.0/go.mod h1:RDklbk79AGWmwhnvt/jBztapEOGDOx6ZbXqjP6csGnQ= gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/datatypes v1.2.1 h1:r+g0bk4LPCW2v4+Ls7aeNgGme7JYdNDQ2VtvlNUfBh0= +gorm.io/datatypes v1.2.1/go.mod h1:hYK6OTb/1x+m96PgoZZq10UXJ6RvEBb9kRDQ2yyhzGs= gorm.io/driver/mysql v1.5.7 h1:MndhOPYOfEp2rHKgkZIhJ16eVUIRf2HmzgoPmh7FCWo= gorm.io/driver/mysql v1.5.7/go.mod h1:sEtPWMiqiN1N1cMXoXmBbd8C6/l+TESwriotuRRpkDM= gorm.io/gorm v1.25.7/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= diff --git a/gofacto.go b/gofacto.go index 31fa929..c3a9f53 100644 --- a/gofacto.go +++ b/gofacto.go @@ -152,7 +152,7 @@ func (f *Factory[T]) Build(ctx context.Context) *builder[T] { } if f.isSetZeroValue { - setNonZeroValues(f.index, &v, f.ignoreFields) + f.setNonZeroValues(&v) } f.index++ @@ -181,7 +181,7 @@ func (f *Factory[T]) BuildList(ctx context.Context, n int) *builderList[T] { } if f.isSetZeroValue { - setNonZeroValues(f.index, &v, f.ignoreFields) + f.setNonZeroValues(&v) } list[i] = &v @@ -458,7 +458,7 @@ func (b *builder[T]) WithOne(v interface{}, ignoreFields ...string) *builder[T] b.f.tagToInfo = t } - if err := setAssValue(v, b.f.tagToInfo, b.f.index, "WithOne", ignoreFields); err != nil { + if err := b.f.setAssValue(v); err != nil { b.errors = append(b.errors, err) return b } @@ -485,7 +485,7 @@ func (b *builderList[T]) WithOne(v interface{}, ignoreFields ...string) *builder b.f.tagToInfo = t } - if err := setAssValue(v, b.f.tagToInfo, b.f.index, "WithOne", ignoreFields); err != nil { + if err := b.f.setAssValue(v); err != nil { b.errors = append(b.errors, err) return b } @@ -514,7 +514,7 @@ func (b *builderList[T]) WithMany(values []interface{}, ignoreFields ...string) var curValName string for _, v := range values { - if err := setAssValue(v, b.f.tagToInfo, b.f.index, "WithMany", ignoreFields); err != nil { + if err := b.f.setAssValue(v); err != nil { b.errors = append(b.errors, err) return b } @@ -539,7 +539,7 @@ func (b *builderList[T]) WithMany(values []interface{}, ignoreFields ...string) // setAss sets and inserts the associations func (b *builder[T]) setAss() error { // insert the associations - if err := insertAss(b.ctx, b.f.db, b.f.associations, b.f.tagToInfo); err != nil { + if err := b.f.insertAss(b.ctx); err != nil { return err } @@ -563,7 +563,7 @@ func (b *builder[T]) setAss() error { // setAss sets and inserts the associations func (b *builderList[T]) setAss() error { // insert the associations - if err := insertAss(b.ctx, b.f.db, b.f.associations, b.f.tagToInfo); err != nil { + if err := b.f.insertAss(b.ctx); err != nil { return err } diff --git a/helpers.go b/helpers.go index 160ebc1..e63db93 100644 --- a/helpers.go +++ b/helpers.go @@ -17,52 +17,9 @@ const ( packageName = "gofacto" ) -// copyValues copys non-zero values from src to dest -func copyValues[T any](dest *T, src T) error { - destValue := reflect.ValueOf(dest).Elem() - srcValue := reflect.ValueOf(src) - - if destValue.Kind() != reflect.Struct { - return errors.New("destination value is not a struct") - } - - if srcValue.Kind() != reflect.Struct { - return errors.New("source value is not a struct") - } - - if destValue.Type() != srcValue.Type() { - return errors.New("destination and source type is different") - } - - for i := 0; i < destValue.NumField(); i++ { - destField := destValue.Field(i) - srcField := srcValue.FieldByName(destValue.Type().Field(i).Name) - - if srcField.IsValid() && destField.Type() == srcField.Type() && !srcField.IsZero() { - destField.Set(srcField) - } - } - - return nil -} - -// genFinalError generates a final error message from the given errors -func genFinalError(errs []error) error { - if len(errs) == 0 { - return nil - } - - errorMessages := make([]string, len(errs)) - for i, err := range errs { - errorMessages[i] = err.Error() - } - - return fmt.Errorf(strings.Join(errorMessages, "\n")) -} - -// setNonZeroValues sets non-zero values to the given struct -// v must be a pointer to a struct -func setNonZeroValues(i int, v interface{}, ignoreFields []string) { +// setNonZeroValues sets non-zero values to the given struct. +// Parameter v must be a pointer to a struct +func (f *Factory[T]) setNonZeroValues(v interface{}) { val := reflect.ValueOf(v).Elem() typeOfVal := val.Type() @@ -71,7 +28,7 @@ func setNonZeroValues(i int, v interface{}, ignoreFields []string) { curField := typeOfVal.Field(k) // skip ignored fields - if len(ignoreFields) > 0 && slices.Contains(ignoreFields, curField.Name) { + if slices.Contains(f.ignoreFields, curField.Name) { continue } @@ -80,6 +37,14 @@ func setNonZeroValues(i int, v interface{}, ignoreFields []string) { continue } + // handle custom types + if f.db != nil { + if customValue, ok := f.db.GenCustomType(curField.Type); ok { + curVal.Set(reflect.ValueOf(customValue)) + continue + } + } + // handle time.Time if curField.Type == reflect.TypeOf(time.Time{}) { curVal.Set(reflect.ValueOf(time.Now())) @@ -95,48 +60,48 @@ func setNonZeroValues(i int, v interface{}, ignoreFields []string) { // handle struct if curField.Type.Kind() == reflect.Struct { - setNonZeroValues(i, curVal.Addr().Interface(), ignoreFields) + f.setNonZeroValues(curVal.Addr().Interface()) continue } // handle pointer to struct if curField.Type.Kind() == reflect.Ptr && curField.Type.Elem().Kind() == reflect.Struct { newInstance := reflect.New(curField.Type.Elem()).Elem() - setNonZeroValues(i, newInstance.Addr().Interface(), ignoreFields) + f.setNonZeroValues(newInstance.Addr().Interface()) curVal.Set(newInstance.Addr()) continue } // handle slice if curField.Type.Kind() == reflect.Slice { - setNonZeroValuesForSlice(i, curVal.Addr().Interface(), ignoreFields) + f.setNonZeroSlice(curVal.Addr().Interface()) continue } // handle pointer to slice if curField.Type.Kind() == reflect.Ptr && curField.Type.Elem().Kind() == reflect.Slice { newInstance := reflect.New(curField.Type.Elem()).Elem() - setNonZeroValuesForSlice(i, newInstance.Addr().Interface(), ignoreFields) + f.setNonZeroSlice(newInstance.Addr().Interface()) curVal.Set(newInstance.Addr()) continue } // For other types, set non-zero values if the field is zero - if v := genNonZeroValue(curField.Type, i); v != nil { + if v := genNonZeroValue(curField.Type, f.index); v != nil { curVal.Set(reflect.ValueOf(v)) } } } -// setNonZeroValuesForSlice sets non-zero values to the given slice. +// setNonZeroSlice sets non-zero values to the given slice. // Parameter v must be a pointer to a slice -func setNonZeroValuesForSlice(i int, v interface{}, ignoreFields []string) { +func (f *Factory[T]) setNonZeroSlice(v interface{}) { val := reflect.ValueOf(v).Elem() // handle slice if val.Type().Elem().Kind() == reflect.Slice { e := reflect.New(val.Type().Elem()).Elem() - setNonZeroValuesForSlice(i, e.Addr().Interface(), ignoreFields) + f.setNonZeroSlice(e.Addr().Interface()) val.Set(reflect.Append(val, e)) return } @@ -144,7 +109,7 @@ func setNonZeroValuesForSlice(i int, v interface{}, ignoreFields []string) { // handle slice of pointers if val.Type().Elem().Kind() == reflect.Ptr && val.Type().Elem().Elem().Kind() == reflect.Slice { e := reflect.New(val.Type().Elem().Elem()).Elem() - setNonZeroValuesForSlice(i, e.Addr().Interface(), ignoreFields) + f.setNonZeroSlice(e.Addr().Interface()) val.Set(reflect.Append(val, e.Addr())) return } @@ -152,7 +117,7 @@ func setNonZeroValuesForSlice(i int, v interface{}, ignoreFields []string) { // handle struct if val.Type().Elem().Kind() == reflect.Struct { e := reflect.New(val.Type().Elem()).Elem() - setNonZeroValues(i, e.Addr().Interface(), ignoreFields) + f.setNonZeroValues(e.Addr().Interface()) val.Set(reflect.Append(val, e)) return } @@ -160,18 +125,102 @@ func setNonZeroValuesForSlice(i int, v interface{}, ignoreFields []string) { // handle pointer to struct if val.Type().Elem().Kind() == reflect.Ptr && val.Type().Elem().Elem().Kind() == reflect.Struct { e := reflect.New(val.Type().Elem().Elem()) - setNonZeroValues(i, e.Interface(), ignoreFields) + f.setNonZeroValues(e.Interface()) val.Set(reflect.Append(val, e)) return } // handle other types t := val.Type().Elem() - if tv := genNonZeroValue(t, i); tv != nil { + if tv := genNonZeroValue(t, f.index); tv != nil { val.Set(reflect.Append(val, reflect.ValueOf(tv))) } } +// setAssValue sets the value to the associations value +func (f *Factory[T]) setAssValue(v interface{}) error { + typeOfV := reflect.TypeOf(v) + + // check if it's a pointer + if typeOfV.Kind() != reflect.Ptr { + name := typeOfV.Name() + return fmt.Errorf("type %s, value %v is not a pointer", name, v) + } + + name := typeOfV.Elem().Name() + // check if it's a pointer to a struct + if typeOfV.Elem().Kind() != reflect.Struct { + return fmt.Errorf("type %s, value %v is not a pointer to a struct", name, v) + } + + // check if it's existed in tagToInfo + if _, ok := f.tagToInfo[name]; !ok { + return fmt.Errorf("type %s, value %v is not found at tag", name, v) + } + + f.setNonZeroValues(v) + return nil +} + +// genAndInsertAss inserts the associations value into the database +func (f *Factory[T]) insertAss(ctx context.Context) error { + if len(f.tagToInfo) == 0 { + return errors.New("tagToInfo is not set") + } + + for name, vals := range f.associations { + tableName := f.tagToInfo[name].tableName + if _, err := f.db.InsertList(ctx, db.InserListParams{StorageName: tableName, Values: vals}); err != nil { + return err + } + } + + return nil +} + +// copyValues copys non-zero values from src to dest +func copyValues[T any](dest *T, src T) error { + destValue := reflect.ValueOf(dest).Elem() + srcValue := reflect.ValueOf(src) + + if destValue.Kind() != reflect.Struct { + return errors.New("destination value is not a struct") + } + + if srcValue.Kind() != reflect.Struct { + return errors.New("source value is not a struct") + } + + if destValue.Type() != srcValue.Type() { + return errors.New("destination and source type is different") + } + + for i := 0; i < destValue.NumField(); i++ { + destField := destValue.Field(i) + srcField := srcValue.FieldByName(destValue.Type().Field(i).Name) + + if srcField.IsValid() && destField.Type() == srcField.Type() && !srcField.IsZero() { + destField.Set(srcField) + } + } + + return nil +} + +// genFinalError generates a final error message from the given errors +func genFinalError(errs []error) error { + if len(errs) == 0 { + return nil + } + + errorMessages := make([]string, len(errs)) + for i, err := range errs { + errorMessages[i] = err.Error() + } + + return fmt.Errorf(strings.Join(errorMessages, "\n")) +} + // genNonZeroValue generates a non-zero value for the given type func genNonZeroValue(t reflect.Type, i int) interface{} { switch t.Kind() { @@ -239,51 +288,6 @@ func setField(target interface{}, name string, source interface{}, sourceFn stri return nil } -// setAssValue sets the value to the associations value -func setAssValue(v interface{}, tagToInfo map[string]tagInfo, index int, sourceFn string, ignoreFields []string) error { - typeOfV := reflect.TypeOf(v) - - // check if it's a pointer - if typeOfV.Kind() != reflect.Ptr { - name := typeOfV.Name() - return fmt.Errorf("%s: type %s, value %v is not a pointer", sourceFn, name, v) - } - - name := typeOfV.Elem().Name() - // check if it's a pointer to a struct - if typeOfV.Elem().Kind() != reflect.Struct { - return fmt.Errorf("%s: type %s, value %v is not a pointer to a struct", sourceFn, name, v) - } - - // check if it's existed in tagToInfo - if _, ok := tagToInfo[name]; !ok { - return fmt.Errorf("%s: type %s, value %v is not found at tag", sourceFn, name, v) - } - - setNonZeroValues(index, v, ignoreFields) - return nil -} - -// genAndInsertAss inserts the associations value into the database -func insertAss(ctx context.Context, d db.Database, associations map[string][]interface{}, tagToInfo map[string]tagInfo) error { - if len(tagToInfo) == 0 { - return errors.New("tagToInfo is not set") - } - - if len(associations) == 0 { - return errors.New("inserting associations without any associations") - } - - for name, vals := range associations { - tableName := tagToInfo[name].tableName - if _, err := d.InsertList(ctx, db.InserListParams{StorageName: tableName, Values: vals}); err != nil { - return err - } - } - - return nil -} - // genTagToInfo generates the map from tag to metadata func genTagToInfo(dataType reflect.Type) (map[string]tagInfo, error) { tagToInfo := map[string]tagInfo{} @@ -339,10 +343,12 @@ func setFieldValue(target, source reflect.Value) { target.SetUint(uint64(source.Int())) } +// isIntType checks if the kind is an integer type func isIntType(k reflect.Kind) bool { return k >= reflect.Int && k <= reflect.Int64 } +// isUintType checks if the kind is an unsigned integer type func isUintType(k reflect.Kind) bool { return k >= reflect.Uint && k <= reflect.Uint64 } diff --git a/internal/sqllib/sqllib.go b/internal/sqllib/sqllib.go index a580012..6242411 100644 --- a/internal/sqllib/sqllib.go +++ b/internal/sqllib/sqllib.go @@ -115,6 +115,10 @@ func (c *Config) InsertList(ctx context.Context, params db.InserListParams) ([]i return result, nil } +func (c *Config) GenCustomType(t reflect.Type) (interface{}, bool) { + return nil, false +} + // prepareStmtAndVals prepares the SQL insert statement and the values to be inserted // values are the pointer to the struct func (c *Config) prepareStmtAndVals(tableName string, values ...interface{}) (string, [][]interface{}) {