Skip to content

Commit

Permalink
implement sqlserver storage using gorm.io/driver/sqlserver
Browse files Browse the repository at this point in the history
  • Loading branch information
wooln committed Oct 19, 2023
1 parent ee7f283 commit 952ced4
Show file tree
Hide file tree
Showing 9 changed files with 74 additions and 10 deletions.
2 changes: 2 additions & 0 deletions client/dtmcli/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ const (
DBTypeMysql = dtmimp.DBTypeMysql
// DBTypePostgres const for driver postgres
DBTypePostgres = dtmimp.DBTypePostgres
// DBTypeSqlServer const for driver SqlServer
DBTypeSqlServer = dtmimp.DBTypeSqlServer

Check failure on line 39 in client/dtmcli/consts.go

View workflow job for this annotation

GitHub Actions / CI

const DBTypeSqlServer should be DBTypeSQLServer
)

// MapSuccess HTTP result of SUCCESS
Expand Down
2 changes: 2 additions & 0 deletions client/dtmcli/dtmimp/consts.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ const (
DBTypeMysql = "mysql"
// DBTypePostgres const for driver postgres
DBTypePostgres = "postgres"
// DBTypeSqlServer const for driver SqlServer
DBTypeSqlServer = "sqlserver"

Check failure on line 40 in client/dtmcli/dtmimp/consts.go

View workflow job for this annotation

GitHub Actions / CI

const DBTypeSqlServer should be DBTypeSQLServer
// DBTypeRedis const for driver redis
DBTypeRedis = "redis"
// Jrpc const for json-rpc
Expand Down
21 changes: 21 additions & 0 deletions client/dtmcli/dtmimp/db_special.go
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,27 @@ func init() {
dbSpecials[DBTypePostgres] = &postgresDBSpecial{}
}

// SqlServer 版本的实现TODO
type sqlserverDBSpecial struct{}

func (*sqlserverDBSpecial) GetPlaceHoldSQL(sql string) string {
return sql
}

func (*sqlserverDBSpecial) GetInsertIgnoreTemplate(tableAndValues string, pgConstraint string) string {
return fmt.Sprintf("insert ignore into %s", tableAndValues) //这个只有作为client的事务屏障才用的吧? Server端不用吧
}

func (*sqlserverDBSpecial) GetXaSQL(command string, xid string) string {
if command == "abort" {
command = "rollback"
}
return fmt.Sprintf("xa %s '%s'", command, xid)
}
func init() {
dbSpecials[DBTypeSqlServer] = &sqlserverDBSpecial{}
}

// GetDBSpecial get DBSpecial for currentDBType
func GetDBSpecial(dbType string) DBSpecial {
if dbType == "" {
Expand Down
14 changes: 14 additions & 0 deletions client/dtmcli/dtmimp/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -221,13 +221,27 @@ func DBExec(dbType string, db DB, sql string, values ...interface{}) (affected i

// GetDsn get dsn from map config
func GetDsn(conf DBConf) string {

host := MayReplaceLocalhost(conf.Host)
driver := conf.Driver

query := url.Values{}
query.Add("database", conf.Db)
u := &url.URL{
Scheme: "sqlserver",
User: url.UserPassword(conf.User, conf.Password),
Host: fmt.Sprintf("%s:%d", host, conf.Port),
// Path: instance, // if connecting to an instance instead of a port
RawQuery: query.Encode(),
}

dsn := map[string]string{
"mysql": fmt.Sprintf("%s:%s@tcp(%s:%d)/%s?charset=utf8mb4&parseTime=true&loc=Local&interpolateParams=true",
conf.User, conf.Password, host, conf.Port, conf.Db),
"postgres": fmt.Sprintf("host=%s user=%s password=%s dbname='%s' search_path=%s port=%d sslmode=disable",
host, conf.User, conf.Password, conf.Db, conf.Schema, conf.Port),
// sqlserver://sa:mypass@localhost:1234?database=master&connection+timeout=30
"sqlserver": u.String(),
}[driver]
PanicIf(dsn == "", fmt.Errorf("unknow driver: %s", driver))
return dsn
Expand Down
4 changes: 3 additions & 1 deletion dtmsvr/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ const (
BoltDb = "boltdb"
// Postgres is postgres driver
Postgres = "postgres"
// SqlServer is SQL Server driver
SqlServer = "sqlserver"

Check failure on line 24 in dtmsvr/config/config.go

View workflow job for this annotation

GitHub Actions / CI

const SqlServer should be SQLServer
)

// MicroService config type for microservice based grpc
Expand Down Expand Up @@ -65,7 +67,7 @@ type Store struct {

// IsDB checks config driver is mysql or postgres
func (s *Store) IsDB() bool {
return s.Driver == dtmcli.DBTypeMysql || s.Driver == dtmcli.DBTypePostgres
return s.Driver == dtmcli.DBTypeMysql || s.Driver == dtmcli.DBTypePostgres || s.Driver == dtmcli.DBTypeSqlServer
}

// GetDBConf returns db conf info
Expand Down
5 changes: 3 additions & 2 deletions dtmsvr/storage/registry/registry.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,9 @@ var storeFactorys = map[string]StorageFactory{
return &redis.Store{}
},
},
"mysql": sqlFac,
"postgres": sqlFac,
"mysql": sqlFac,
"postgres": sqlFac,
"sqlserver": sqlFac,
}

// GetStore returns storage.Store
Expand Down
26 changes: 19 additions & 7 deletions dtmsvr/storage/sql/sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -67,10 +67,10 @@ func (s *Store) ScanTransGlobalStores(position *string, limit int64, condition s
query = query.Where("trans_type = ?", condition.TransType)
}
if !condition.CreateTimeStart.IsZero() {
query = query.Where("create_time >= ?", condition.CreateTimeStart.Format("2006-01-02 15:04:05"))
query = query.Where("create_time >= ?", condition.CreateTimeStart)
}
if !condition.CreateTimeEnd.IsZero() {
query = query.Where("create_time <= ?", condition.CreateTimeEnd.Format("2006-01-02 15:04:05"))
query = query.Where("create_time <= ?", condition.CreateTimeEnd)
}

dbr := query.Order("id desc").Limit(int(limit)).Find(&globals)
Expand Down Expand Up @@ -103,7 +103,13 @@ func (s *Store) UpdateBranches(branches []storage.TransBranchStore, updates []st
func (s *Store) LockGlobalSaveBranches(gid string, status string, branches []storage.TransBranchStore, branchStart int) {
err := dbGet().Transaction(func(tx *gorm.DB) error {
g := &storage.TransGlobalStore{}
dbr := tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(g).Where("gid=? and status=?", gid, status).First(g)
var dbr *gorm.DB
// sqlserver sql should be: SELECT * FROM "trans_global" with(RowLock,UpdLock) ,but gorm generates "FOR UPDATE" at the back, raw sql instead.
if conf.Store.Driver == config.SqlServer {
dbr = tx.Raw("SELECT * FROM trans_global with(RowLock,UpdLock) WHERE gid=? and status=? ORDER BY id OFFSET 0 ROW FETCH NEXT 1 ROWS ONLY ", gid, status).First(g)
} else {
dbr = tx.Clauses(clause.Locking{Strength: "UPDATE"}).Model(g).Where("gid=? and status=?", gid, status).First(g)
}
if dbr.Error == nil {
if branchStart == -1 {
dbr = tx.Create(branches)
Expand Down Expand Up @@ -160,11 +166,13 @@ func (s *Store) LockOneGlobalTrans(expireIn time.Duration) *storage.TransGlobalS
owner := shortuuid.New()
nextCronTime := getTimeStr(int64(expireIn / time.Second))
where := map[string]string{
dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted') limit 1`, nextCronTime),
dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted') limit 1 )`, nextCronTime),
dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted') limit 1`, nextCronTime),
dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted') limit 1 )`, nextCronTime),
dtmimp.DBTypeSqlServer: fmt.Sprintf(`id in (select top 1 id from trans_global where next_cron_time < '%s' and status in ('prepared', 'aborting', 'submitted') )`, nextCronTime),
}[conf.Store.Driver]

ssql := fmt.Sprintf(`select count(1) from trans_global where %s`, where)

var cnt int64
err := db.ToSQLDB().QueryRow(ssql).Scan(&cnt)
dtmimp.PanicIf(err != nil, err)
Expand Down Expand Up @@ -193,8 +201,9 @@ func (s *Store) LockOneGlobalTrans(expireIn time.Duration) *storage.TransGlobalS
func (s *Store) ResetCronTime(after time.Duration, limit int64) (succeedCount int64, hasRemaining bool, err error) {
nextCronTime := getTimeStr(int64(after / time.Second))
where := map[string]string{
dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d`, nextCronTime, limit),
dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d )`, nextCronTime, limit),
dtmimp.DBTypeMysql: fmt.Sprintf(`next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d`, nextCronTime, limit),
dtmimp.DBTypePostgres: fmt.Sprintf(`id in (select id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') limit %d )`, nextCronTime, limit),
dtmimp.DBTypeSqlServer: fmt.Sprintf(`id in (select top %d id from trans_global where next_cron_time > '%s' and status in ('prepared', 'aborting', 'submitted') )`, limit, nextCronTime),
}[conf.Store.Driver]

sql := fmt.Sprintf(`UPDATE trans_global SET update_time='%s',next_cron_time='%s' WHERE %s`,
Expand Down Expand Up @@ -312,5 +321,8 @@ func wrapError(err error) error {
}

func getTimeStr(afterSecond int64) string {
if conf.Store.Driver == config.SqlServer {
return dtmutil.GetNextTime(afterSecond).Format(time.RFC3339)
}
return dtmutil.GetNextTime(afterSecond).Format("2006-01-02 15:04:05")
}
6 changes: 6 additions & 0 deletions dtmutil/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,11 @@ import (
"github.com/dtm-labs/logger"
_ "github.com/go-sql-driver/mysql" // register mysql driver
_ "github.com/lib/pq" // register postgres driver

// _ "github.com/microsoft/go-mssqldb" // Microsoft's package conflicts with gorm's package: panic: sql: Register called twice for driver mssql
"gorm.io/driver/mysql"
"gorm.io/driver/postgres"
"gorm.io/driver/sqlserver" // register sqlserver driver,
"gorm.io/gorm"
)

Expand All @@ -27,6 +30,9 @@ func getGormDialetor(driver string, dsn string) gorm.Dialector {
if driver == dtmcli.DBTypePostgres {
return postgres.Open(dsn)
}
if driver == dtmcli.DBTypeSqlServer {
return sqlserver.Open(dsn)
}
dtmimp.PanicIf(driver != dtmcli.DBTypeMysql, fmt.Errorf("unknown driver: %s", driver))
return mysql.Open(dsn)
}
Expand Down
4 changes: 4 additions & 0 deletions test/main_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,10 @@ func TestMain(m *testing.M) {
conf.Store.User = ""
conf.Store.Password = ""
conf.Store.Port = 6379
} else if tenv == config.SqlServer {
conf.Store.User = "sa"
conf.Store.Password = "p@ssw0rd"
conf.Store.Port = 1433
}
conf.Store.Db = ""
registry.WaitStoreUp()
Expand Down

0 comments on commit 952ced4

Please sign in to comment.