From fe492f5109e338dc3ad0412ad5fe2baeb3b65da1 Mon Sep 17 00:00:00 2001 From: Will Date: Thu, 2 Nov 2023 14:18:16 +0000 Subject: [PATCH] fix: fix sql render bug after use ReplaceDB & UseDB lost db context --- do.go | 5 +++-- do_test.go | 54 ++++++++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 57 insertions(+), 2 deletions(-) diff --git a/do.go b/do.go index aca161a5..21b42abb 100644 --- a/do.go +++ b/do.go @@ -50,7 +50,7 @@ var ( // UseDB specify a db connection(*gorm.DB) func (d *DO) UseDB(db *gorm.DB, opts ...DOOption) { - db = db.Session(&gorm.Session{Context: context.Background()}) + db = db.Session(&gorm.Session{Context: db.Statement.Context}) d.db = db config := &DOConfig{} for _, opt := range opts { @@ -65,7 +65,8 @@ func (d *DO) UseDB(db *gorm.DB, opts ...DOOption) { // ReplaceDB replace db connection func (d *DO) ReplaceDB(db *gorm.DB) { - d.db = db.Session(&gorm.Session{}) + d.db = db.Session(&gorm.Session{Context: db.Statement.Context}) + d.UseModel(reflect.New(d.modelType).Interface()) } // ReplaceConnPool replace db connection pool diff --git a/do_test.go b/do_test.go index 75f1e8a0..9b32a8f0 100644 --- a/do_test.go +++ b/do_test.go @@ -4,6 +4,7 @@ import ( "reflect" "strings" "testing" + "context" "gorm.io/datatypes" "gorm.io/gorm" @@ -73,6 +74,59 @@ func build(stmt *gorm.Statement, opts ...stmtOpt) *gorm.Statement { return stmt } +func TestDO_ReplaceDB(t *testing.T) { + globalUDB := u.db + globalStDB := student.db + defer func() { + u.db = globalUDB + student.db = globalStDB + }() + buildSql := func(u *user, student *Student) string { + e := u.Select().Where( + u.Columns(u.ID).Eq( + u.Select(u.ID).Where( + u.Columns(u.Name).Eq( + student.Select(student.Name).Where(student.ID.Eq(1)), + ), + ), + ), + ) + stmt := build(e.underlyingDB().Statement) + return strings.TrimSpace(stmt.SQL.String()) + } + + currentDB := db.Session(&gorm.Session{Context: context.TODO()}) + // old func, Session Context is nil, so u and student will use same db.Statement, one point, same address + u.db = currentDB.Session(&gorm.Session{}) + student.db = currentDB.Session(&gorm.Session{}) + sql := buildSql(u, student) // sql will lost table name + result := "SELECT * WHERE `id` = (SELECT `id` FROM ` WHERE `name` = (SELECT `student`.`name` FROM ` WHERE `student`.`id` = ?))" + if sql != result { + t.Errorf("SQL expects %v got %v", result, sql) + } + + u.UseModel(User{}) + student.UseModel(StudentRaw{}) + sql = buildSql(u, student) // sql will render with wrong table name + result = "SELECT * WHERE `id` = (SELECT `id` FROM `users_info` WHERE `name` = (SELECT `student`.`name` FROM `users_info` WHERE `student`.`id` = ?))" + if sql != result { + t.Errorf("SQL expects %v got %v", result, sql) + } + + u.db = globalUDB + student.db = globalStDB + + // new func + u.ReplaceDB(db) + student.ReplaceDB(db) + + sql = buildSql(u, student) // sql will be right + result = "SELECT * WHERE `id` = (SELECT `id` FROM `users_info` WHERE `name` = (SELECT `student`.`name` FROM `student` WHERE `student`.`id` = ?))" + if sql != result { + t.Errorf("SQL expects %v got %v", result, sql) + } +} + func TestDO_methods(t *testing.T) { testcases := []struct { Expr SubQuery