Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: trace resolver mode #89

Merged
merged 1 commit into from
Dec 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 12 additions & 6 deletions dbresolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ type DBResolver struct {
}

type Config struct {
Sources []gorm.Dialector
Replicas []gorm.Dialector
Policy Policy
datas []interface{}
Sources []gorm.Dialector
Replicas []gorm.Dialector
Policy Policy
datas []interface{}
TraceResolverMode bool
}

func Register(config Config, datas ...interface{}) *DBResolver {
Expand Down Expand Up @@ -76,8 +77,9 @@ func (dr *DBResolver) compileConfig(config Config) (err error) {
var (
connPool = dr.DB.Config.ConnPool
r = resolver{
dbResolver: dr,
policy: config.Policy,
dbResolver: dr,
policy: config.Policy,
traceResolverMode: config.TraceResolverMode,
}
)

Expand Down Expand Up @@ -122,6 +124,10 @@ func (dr *DBResolver) compileConfig(config Config) (err error) {
}
}

if config.TraceResolverMode {
dr.Logger = NewResolverModeLogger(dr.Logger)
}

return nil
}

Expand Down
13 changes: 11 additions & 2 deletions dbresolver_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@ package dbresolver_test

import (
"fmt"
"os"
"testing"

"gorm.io/driver/mysql"
"gorm.io/gorm"
"gorm.io/gorm/logger"
"gorm.io/plugin/dbresolver"
)

Expand Down Expand Up @@ -50,16 +52,23 @@ func TestDBResolver(t *testing.T) {
if err != nil {
t.Fatalf("failed to connect db, got error: %v", err)
}
if debug := os.Getenv("DEBUG"); debug == "true" {
DB.Logger = DB.Logger.LogMode(logger.Info)
} else if debug == "false" {
DB.Logger = DB.Logger.LogMode(logger.Silent)
}

if err := DB.Use(dbresolver.Register(dbresolver.Config{
Sources: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9911)/gorm?charset=utf8&parseTime=True&loc=Local")},
Replicas: []gorm.Dialector{
mysql.Open("gorm:gorm@tcp(localhost:9912)/gorm?charset=utf8&parseTime=True&loc=Local"),
mysql.Open("gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True&loc=Local"),
},
TraceResolverMode: true,
}).Register(dbresolver.Config{
Sources: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9914)/gorm?charset=utf8&parseTime=True&loc=Local")},
Replicas: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True&loc=Local")},
Sources: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9914)/gorm?charset=utf8&parseTime=True&loc=Local")},
Replicas: []gorm.Dialector{mysql.Open("gorm:gorm@tcp(localhost:9913)/gorm?charset=utf8&parseTime=True&loc=Local")},
TraceResolverMode: true,
}, "users", &Product{}).SetMaxOpenConns(5)); err != nil {
t.Fatalf("failed to use plugin, got error: %v", err)
}
Expand Down
54 changes: 54 additions & 0 deletions logger.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
package dbresolver

import (
"context"
"fmt"
"time"

"gorm.io/gorm"
"gorm.io/gorm/logger"
)

type ResolverModeKey string
type ResolverMode string

const resolverModeKey ResolverModeKey = "dbresolver:resolver_mode_key"
const (
ResolverModeSource ResolverMode = "source"
ResolverModeReplica ResolverMode = "replica"
)

type resolverModeLogger struct {
logger.Interface
}

func (l resolverModeLogger) Trace(ctx context.Context, begin time.Time, fc func() (sql string, rowsAffected int64), err error) {
var splitFn = func() (sql string, rowsAffected int64) {
sql, rowsAffected = fc()
op := ctx.Value(resolverModeKey)
if op != nil {
sql = fmt.Sprintf("[%s] %s", op, sql)
return
}

// the situation that dbresolver does not handle
// such as transactions, or some resolvers do not enable MarkResolverMode.
return
}
l.Interface.Trace(ctx, begin, splitFn, err)
}

func NewResolverModeLogger(l logger.Interface) logger.Interface {
if _, ok := l.(resolverModeLogger); ok {
return l
}
return resolverModeLogger{
Interface: l,
}
}

func markStmtResolverMode(stmt *gorm.Statement, mode ResolverMode) {
if _, ok := stmt.Logger.(resolverModeLogger); ok {
stmt.Context = context.WithValue(stmt.Context, resolverModeKey, mode)
}
}
18 changes: 14 additions & 4 deletions resolver.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,11 @@ import (
)

type resolver struct {
sources []gorm.ConnPool
replicas []gorm.ConnPool
policy Policy
dbResolver *DBResolver
sources []gorm.ConnPool
replicas []gorm.ConnPool
policy Policy
dbResolver *DBResolver
traceResolverMode bool
}

func (r *resolver) resolve(stmt *gorm.Statement, op Operation) (connPool gorm.ConnPool) {
Expand All @@ -18,10 +19,19 @@ func (r *resolver) resolve(stmt *gorm.Statement, op Operation) (connPool gorm.Co
} else {
connPool = r.policy.Resolve(r.replicas)
}
if r.traceResolverMode {
markStmtResolverMode(stmt, ResolverModeReplica)
}
} else if len(r.sources) == 1 {
connPool = r.sources[0]
if r.traceResolverMode {
markStmtResolverMode(stmt, ResolverModeSource)
}
} else {
connPool = r.policy.Resolve(r.sources)
if r.traceResolverMode {
markStmtResolverMode(stmt, ResolverModeSource)
}
}

if stmt.DB.PrepareStmt {
Expand Down