diff --git a/middleware/mysql/global.go b/middleware/mysql/global.go index 08e2369..be73b76 100644 --- a/middleware/mysql/global.go +++ b/middleware/mysql/global.go @@ -2,8 +2,6 @@ package mysql import ( "errors" - - "github.com/siddontang/go-mysql/mysql" ) var _globalPool *Pool @@ -76,7 +74,12 @@ func Close() error { // Get get gets a connection from pool and validate it, // if there is no valid connection in the pool, it will create a new connection func Get() (*PoolConn, error) { - return _globalPool.Get() + conn, err := _globalPool.Get() + if err != nil { + return nil, err + } + + return conn.(*PoolConn), nil } // Release releases given number of connections of global pool, each connection will disconnect with database @@ -85,7 +88,7 @@ func Release(num int) error { } // Execute execute given sql statement -func Execute(sql string, args ...interface{}) (*mysql.Result, error) { +func Execute(sql string, args ...interface{}) (interface{}, error) { if _globalPool == nil { return nil, errors.New("global pool is nil, please initiate it first") } diff --git a/middleware/mysql/global_test.go b/middleware/mysql/global_test.go index 8b66d5b..afceff9 100644 --- a/middleware/mysql/global_test.go +++ b/middleware/mysql/global_test.go @@ -15,7 +15,7 @@ func TestMySQLGlobalPool(t *testing.T) { err error conn *PoolConn slaveList []string - result *mysql.Result + result interface{} ) asst := assert.New(t) @@ -43,15 +43,15 @@ func TestMySQLGlobalPool(t *testing.T) { err = conn.Close() asst.Nil(err, "close connection failed.") - sql := "select 1 as ok;" - result, err = Execute(sql) + sql := "select ? as ok;" + result, err = Execute(sql, 1) asst.Nil(err, "execute sql with global pool failed.") - actual, err := result.GetIntByName(0, "ok") + actual, err := result.(*mysql.Result).GetIntByName(0, "ok") asst.Nil(err, "execute sql with global pool failed.") asst.Equal(int64(1), actual, "expected and actual values are not equal.") // sleep to test maintain mechanism - time.Sleep(60 * time.Second) + time.Sleep(20 * time.Second) err = Close() asst.Nil(err, "close global pool failed.") diff --git a/middleware/mysql/pool_test.go b/middleware/mysql/pool_test.go index c1d0628..ea3f36d 100644 --- a/middleware/mysql/pool_test.go +++ b/middleware/mysql/pool_test.go @@ -8,13 +8,15 @@ import ( "github.com/siddontang/go-mysql/mysql" "github.com/stretchr/testify/assert" "go.uber.org/zap/zapcore" + + "github.com/romberli/go-util/middleware" ) func TestMySQLPool(t *testing.T) { var ( err error pool *Pool - conn *PoolConn + conn middleware.PoolConn repRole string slaveList []string result *mysql.Result @@ -38,7 +40,7 @@ func TestMySQLPool(t *testing.T) { asst.Nil(err, "get connection from pool failed.") // test connection - slaveList, err = conn.GetReplicationSlaveList() + slaveList, err = conn.(*PoolConn).GetReplicationSlaveList() asst.Nil(err, "get replication slave list failed.") t.Logf("replication slave list: %v", slaveList) @@ -47,19 +49,19 @@ func TestMySQLPool(t *testing.T) { conn, err = pool.Get() asst.Nil(err, "get connection from pool failed.") - result, err = conn.GetReplicationSlavesStatus() + result, err = conn.(*PoolConn).GetReplicationSlavesStatus() asst.Nil(err, "get replication slave status failed.") if result.RowNumber() > 0 { t.Logf("show slave status: %v", result.Values) } else { t.Logf("this is not a slave node.") } - repRole, err = conn.GetReplicationRole() + repRole, err = conn.(*PoolConn).GetReplicationRole() asst.Nil(err, "get replication role failed.") t.Logf("replication role: %s", repRole) // sleep to test maintain mechanism - time.Sleep(60 * time.Second) + time.Sleep(20 * time.Second) err = pool.Close() asst.Nil(err, "close pool failed.")