Skip to content

Commit

Permalink
fix: fix sql render bug after use ReplaceDB & UseDB lost db context
Browse files Browse the repository at this point in the history
  • Loading branch information
clcy1243 committed Nov 2, 2023
1 parent f534623 commit fe492f5
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 2 deletions.
5 changes: 3 additions & 2 deletions do.go
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ var (

// UseDB specify a db connection(*gorm.DB)
func (d *DO) UseDB(db *gorm.DB, opts ...DOOption) {
db = db.Session(&gorm.Session{Context: context.Background()})
db = db.Session(&gorm.Session{Context: db.Statement.Context})
d.db = db
config := &DOConfig{}
for _, opt := range opts {
Expand All @@ -65,7 +65,8 @@ func (d *DO) UseDB(db *gorm.DB, opts ...DOOption) {

// ReplaceDB replace db connection
func (d *DO) ReplaceDB(db *gorm.DB) {
d.db = db.Session(&gorm.Session{})
d.db = db.Session(&gorm.Session{Context: db.Statement.Context})
d.UseModel(reflect.New(d.modelType).Interface())
}

// ReplaceConnPool replace db connection pool
Expand Down
54 changes: 54 additions & 0 deletions do_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"reflect"
"strings"
"testing"
"context"

"gorm.io/datatypes"
"gorm.io/gorm"
Expand Down Expand Up @@ -73,6 +74,59 @@ func build(stmt *gorm.Statement, opts ...stmtOpt) *gorm.Statement {
return stmt
}

func TestDO_ReplaceDB(t *testing.T) {
globalUDB := u.db
globalStDB := student.db
defer func() {
u.db = globalUDB
student.db = globalStDB
}()
buildSql := func(u *user, student *Student) string {
e := u.Select().Where(
u.Columns(u.ID).Eq(
u.Select(u.ID).Where(
u.Columns(u.Name).Eq(
student.Select(student.Name).Where(student.ID.Eq(1)),
),
),
),
)
stmt := build(e.underlyingDB().Statement)
return strings.TrimSpace(stmt.SQL.String())
}

currentDB := db.Session(&gorm.Session{Context: context.TODO()})
// old func, Session Context is nil, so u and student will use same db.Statement, one point, same address
u.db = currentDB.Session(&gorm.Session{})
student.db = currentDB.Session(&gorm.Session{})
sql := buildSql(u, student) // sql will lost table name
result := "SELECT * WHERE `id` = (SELECT `id` FROM ` WHERE `name` = (SELECT `student`.`name` FROM ` WHERE `student`.`id` = ?))"
if sql != result {
t.Errorf("SQL expects %v got %v", result, sql)
}

u.UseModel(User{})
student.UseModel(StudentRaw{})
sql = buildSql(u, student) // sql will render with wrong table name
result = "SELECT * WHERE `id` = (SELECT `id` FROM `users_info` WHERE `name` = (SELECT `student`.`name` FROM `users_info` WHERE `student`.`id` = ?))"
if sql != result {
t.Errorf("SQL expects %v got %v", result, sql)
}

u.db = globalUDB
student.db = globalStDB

// new func
u.ReplaceDB(db)
student.ReplaceDB(db)

sql = buildSql(u, student) // sql will be right
result = "SELECT * WHERE `id` = (SELECT `id` FROM `users_info` WHERE `name` = (SELECT `student`.`name` FROM `student` WHERE `student`.`id` = ?))"
if sql != result {
t.Errorf("SQL expects %v got %v", result, sql)
}
}

func TestDO_methods(t *testing.T) {
testcases := []struct {
Expr SubQuery
Expand Down

0 comments on commit fe492f5

Please sign in to comment.