From 952ced4df3777192474bf1ad7a93c145c430c06b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E5=BE=90=E4=BA=91=E9=87=91YunjinXu?= Date: Mon, 16 Oct 2023 10:51:54 +0800 Subject: [PATCH] implement sqlserver storage using gorm.io/driver/sqlserver --- client/dtmcli/consts.go | 2 ++ client/dtmcli/dtmimp/consts.go | 2 ++ client/dtmcli/dtmimp/db_special.go | 21 +++++++++++++++++++++ client/dtmcli/dtmimp/utils.go | 14 ++++++++++++++ dtmsvr/config/config.go | 4 +++- dtmsvr/storage/registry/registry.go | 5 +++-- dtmsvr/storage/sql/sql.go | 26 +++++++++++++++++++------- dtmutil/db.go | 6 ++++++ test/main_test.go | 4 ++++ 9 files changed, 74 insertions(+), 10 deletions(-) diff --git a/client/dtmcli/consts.go b/client/dtmcli/consts.go index 7ae4fc63a..bb87a15eb 100644 --- a/client/dtmcli/consts.go +++ b/client/dtmcli/consts.go @@ -35,6 +35,8 @@ const ( DBTypeMysql = dtmimp.DBTypeMysql // DBTypePostgres const for driver postgres DBTypePostgres = dtmimp.DBTypePostgres + // DBTypeSqlServer const for driver SqlServer + DBTypeSqlServer = dtmimp.DBTypeSqlServer ) // MapSuccess HTTP result of SUCCESS diff --git a/client/dtmcli/dtmimp/consts.go b/client/dtmcli/dtmimp/consts.go index 6f4e6cd3d..036e6dc38 100644 --- a/client/dtmcli/dtmimp/consts.go +++ b/client/dtmcli/dtmimp/consts.go @@ -36,6 +36,8 @@ const ( DBTypeMysql = "mysql" // DBTypePostgres const for driver postgres DBTypePostgres = "postgres" + // DBTypeSqlServer const for driver SqlServer + DBTypeSqlServer = "sqlserver" // DBTypeRedis const for driver redis DBTypeRedis = "redis" // Jrpc const for json-rpc diff --git a/client/dtmcli/dtmimp/db_special.go b/client/dtmcli/dtmimp/db_special.go index d9128b151..a7617526b 100644 --- a/client/dtmcli/dtmimp/db_special.go +++ b/client/dtmcli/dtmimp/db_special.go @@ -78,6 +78,27 @@ func init() { dbSpecials[DBTypePostgres] = &postgresDBSpecial{} } +// SqlServer 版本的实现TODO +type sqlserverDBSpecial struct{} + +func (*sqlserverDBSpecial) GetPlaceHoldSQL(sql string) string { + return sql +} + +func (*sqlserverDBSpecial) GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string { + return fmt.Sprintf("insert ignore into %s", tableAndValues) //这个只有作为client的事务屏障才用的吧? Server端不用吧 +} + +func (*sqlserverDBSpecial) GetXaSQL(command string, xid string) string { + if command == "abort" { + command = "rollback" + } + return fmt.Sprintf("xa %s '%s'", command, xid) +} +func init() { + dbSpecials[DBTypeSqlServer] = &sqlserverDBSpecial{} +} + // GetDBSpecial get DBSpecial for currentDBType func GetDBSpecial(dbType string) DBSpecial { if dbType == "" { diff --git a/client/dtmcli/dtmimp/utils.go b/client/dtmcli/dtmimp/utils.go index b0a6a8a5e..fb510751a 100644 --- a/client/dtmcli/dtmimp/utils.go +++ b/client/dtmcli/dtmimp/utils.go @@ -221,13 +221,27 @@ func DBExec(dbType string, db DB, sql string, values ...interface{}) (affected i // GetDsn get dsn from map config func GetDsn(conf DBConf) string { + host := MayReplaceLocalhost(conf.Host) driver := conf.Driver + + query := url.Values{} + query.Add("database", conf.Db) + u := &url.URL{ + Scheme: "sqlserver", + User: url.UserPassword(conf.User, conf.Password), + Host: fmt.Sprintf("%s:%d", host, conf.Port), + // Path: instance, // if connecting to an instance instead of a port + RawQuery: query.Encode(), + } + dsn := map[string]string{ "mysql": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local&interpolateParams=true", conf.User, conf.Password, host, conf.Port, conf.Db), "postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' search_path=%s port=%d sslmode=disable", host, conf.User, conf.Password, conf.Db, conf.Schema, conf.Port), + // sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30 + "sqlserver": u.String(), }[driver] PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver)) return dsn diff --git a/dtmsvr/config/config.go b/dtmsvr/config/config.go index 2021db4fe..e9ba1e6fa 100644 --- a/dtmsvr/config/config.go +++ b/dtmsvr/config/config.go @@ -20,6 +20,8 @@ const ( BoltDb = "boltdb" // Postgres is postgres driver Postgres = "postgres" + // SqlServer is SQL Server driver + SqlServer = "sqlserver" ) // MicroService config type for microservice based grpc @@ -65,7 +67,7 @@ type Store struct { // IsDB checks config driver is mysql or postgres func (s *Store) IsDB() bool { - return s.Driver == dtmcli.DBTypeMysql || s.Driver == dtmcli.DBTypePostgres + return s.Driver == dtmcli.DBTypeMysql || s.Driver == dtmcli.DBTypePostgres || s.Driver == dtmcli.DBTypeSqlServer } // GetDBConf returns db conf info diff --git a/dtmsvr/storage/registry/registry.go b/dtmsvr/storage/registry/registry.go index 297c8751a..469d20d12 100644 --- a/dtmsvr/storage/registry/registry.go +++ b/dtmsvr/storage/registry/registry.go @@ -37,8 +37,9 @@ var storeFactorys = map[string]StorageFactory{ return &redis.Store{} }, }, - "mysql": sqlFac, - "postgres": sqlFac, + "mysql": sqlFac, + "postgres": sqlFac, + "sqlserver": sqlFac, } // GetStore returns storage.Store diff --git a/dtmsvr/storage/sql/sql.go b/dtmsvr/storage/sql/sql.go index cf8583c8f..6eea9405d 100644 --- a/dtmsvr/storage/sql/sql.go +++ b/dtmsvr/storage/sql/sql.go @@ -67,10 +67,10 @@ func (s *Store) ScanTransGlobalStores(position *string, limit int64, condition s query = query.Where("trans_type = ?", condition.TransType) } if !condition.CreateTimeStart.IsZero() { - query = query.Where("create_time >= ?", condition.CreateTimeStart.Format("2006-01-02 15:04:05")) + query = query.Where("create_time >= ?", condition.CreateTimeStart) } if !condition.CreateTimeEnd.IsZero() { - query = query.Where("create_time <= ?", condition.CreateTimeEnd.Format("2006-01-02 15:04:05")) + query = query.Where("create_time <= ?", condition.CreateTimeEnd) } dbr := query.Order("id desc").Limit(int(limit)).Find(&globals) @@ -103,7 +103,13 @@ func (s *Store) UpdateBranches(branches []storage.TransBranchStore, updates []st func (s *Store) LockGlobalSaveBranches(gid string, status string, branches []storage.TransBranchStore, branchStart int) { err := dbGet().Transaction(func(tx *gorm.DB) error { g := &storage.TransGlobalStore{} - dbr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(g).Where("gid=? and status=?", gid, status).First(g) + var dbr *gorm.DB + // sqlserver sql should be: SELECT * FROM "trans_global" with(RowLock,UpdLock) ,but gorm generates "FOR UPDATE" at the back, raw sql instead. + if conf.Store.Driver == config.SqlServer { + dbr = tx.Raw("SELECT * FROM trans_global with(RowLock,UpdLock) WHERE gid=? and status=? ORDER BY id OFFSET 0 ROW FETCH NEXT 1 ROWS ONLY ", gid, status).First(g) + } else { + dbr = tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(g).Where("gid=? and status=?", gid, status).First(g) + } if dbr.Error == nil { if branchStart == -1 { dbr = tx.Create(branches) @@ -160,11 +166,13 @@ func (s *Store) LockOneGlobalTrans(expireIn time.Duration) *storage.TransGlobalS owner := shortuuid.New() nextCronTime := getTimeStr(int64(expireIn / time.Second)) where := map[string]string{ - dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted') limit 1`, nextCronTime), - dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted') limit 1 )`, nextCronTime), + dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted') limit 1`, nextCronTime), + dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted') limit 1 )`, nextCronTime), + dtmimp.DBTypeSqlServer: fmt.Sprintf(`id in (select top 1 id from trans_global where next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted') )`, nextCronTime), }[conf.Store.Driver] ssql := fmt.Sprintf(`select count(1) from trans_global where %s`, where) + var cnt int64 err := db.ToSQLDB().QueryRow(ssql).Scan(&cnt) dtmimp.PanicIf(err != nil, err) @@ -193,8 +201,9 @@ func (s *Store) LockOneGlobalTrans(expireIn time.Duration) *storage.TransGlobalS func (s *Store) ResetCronTime(after time.Duration, limit int64) (succeedCount int64, hasRemaining bool, err error) { nextCronTime := getTimeStr(int64(after / time.Second)) where := map[string]string{ - dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d`, nextCronTime, limit), - dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d )`, nextCronTime, limit), + dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d`, nextCronTime, limit), + dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d )`, nextCronTime, limit), + dtmimp.DBTypeSqlServer: fmt.Sprintf(`id in (select top %d id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') )`, limit, nextCronTime), }[conf.Store.Driver] sql := fmt.Sprintf(`UPDATE trans_global SET update_time='%s',next_cron_time='%s' WHERE %s`, @@ -312,5 +321,8 @@ func wrapError(err error) error { } func getTimeStr(afterSecond int64) string { + if conf.Store.Driver == config.SqlServer { + return dtmutil.GetNextTime(afterSecond).Format(time.RFC3339) + } return dtmutil.GetNextTime(afterSecond).Format("2006-01-02 15:04:05") } diff --git a/dtmutil/db.go b/dtmutil/db.go index 7e8423e47..4237ccffc 100644 --- a/dtmutil/db.go +++ b/dtmutil/db.go @@ -11,8 +11,11 @@ import ( "github.com/dtm-labs/logger" _ "github.com/go-sql-driver/mysql" // register mysql driver _ "github.com/lib/pq" // register postgres driver + + // _ "github.com/microsoft/go-mssqldb" // Microsoft's package conflicts with gorm's package: panic: sql: Register called twice for driver mssql "gorm.io/driver/mysql" "gorm.io/driver/postgres" + "gorm.io/driver/sqlserver" // register sqlserver driver, "gorm.io/gorm" ) @@ -27,6 +30,9 @@ func getGormDialetor(driver string, dsn string) gorm.Dialector { if driver == dtmcli.DBTypePostgres { return postgres.Open(dsn) } + if driver == dtmcli.DBTypeSqlServer { + return sqlserver.Open(dsn) + } dtmimp.PanicIf(driver != dtmcli.DBTypeMysql, fmt.Errorf("unknown driver: %s", driver)) return mysql.Open(dsn) } diff --git a/test/main_test.go b/test/main_test.go index b91e9dd85..5b2c1cecb 100644 --- a/test/main_test.go +++ b/test/main_test.go @@ -53,6 +53,10 @@ func TestMain(m *testing.M) { conf.Store.User = "" conf.Store.Password = "" conf.Store.Port = 6379 + } else if tenv == config.SqlServer { + conf.Store.User = "sa" + conf.Store.Password = "p@ssw0rd" + conf.Store.Port = 1433 } conf.Store.Db = "" registry.WaitStoreUp()