diff --git a/pkg/dbconn/metadatalock.go b/pkg/dbconn/metadatalock.go
new file mode 100644
index 0000000..933348f
--- /dev/null
+++ b/pkg/dbconn/metadatalock.go
@@ -0,0 +1,107 @@
+package dbconn
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"time"
+
+	"github.com/siddontang/loggers"
+)
+
+var (
+	// getLockTimeout is the timeout for acquiring the GET_LOCK. We set it to 0
+	// because we want to return immediately if the lock is not available
+	getLockTimeout  = 0 * time.Second
+	refreshInterval = 1 * time.Minute
+)
+
+type MetadataLock struct {
+	cancel          context.CancelFunc
+	closeCh         chan error
+	refreshInterval time.Duration
+}
+
+func NewMetadataLock(ctx context.Context, dsn string, lockName string, logger loggers.Advanced, optionFns ...func(*MetadataLock)) (*MetadataLock, error) {
+	if len(lockName) == 0 {
+		return nil, errors.New("metadata lock name is empty")
+	}
+	if len(lockName) > 64 {
+		return nil, fmt.Errorf("metadata lock name is too long: %d, max length is 64", len(lockName))
+	}
+
+	mdl := &MetadataLock{
+		refreshInterval: refreshInterval,
+	}
+
+	// Apply option functions
+	for _, optionFn := range optionFns {
+		optionFn(mdl)
+	}
+
+	// Setup the dedicated connection for this lock
+	dbConfig := NewDBConfig()
+	dbConfig.MaxOpenConnections = 1
+	dbConn, err := New(dsn, dbConfig)
+	if err != nil {
+		return nil, err
+	}
+
+	// Function to acquire the lock
+	getLock := func() error {
+		// https://dev.mysql.com/doc/refman/8.0/en/locking-functions.html#function_get-lock
+		var answer int
+		if err := dbConn.QueryRowContext(ctx, "SELECT GET_LOCK(?, ?)", lockName, getLockTimeout.Seconds()).Scan(&answer); err != nil {
+			return fmt.Errorf("could not acquire metadata lock: %s", err)
+		}
+		if answer == 0 {
+			// 0 means the lock is held by another connection
+			// TODO: we could lookup the connection that holds the lock and report details about it
+			return fmt.Errorf("could not acquire metadata lock: %s, lock is held by another connection", lockName)
+		} else if answer != 1 {
+			// probably we never get here, but just in case
+			return fmt.Errorf("could not acquire metadata lock: %s, GET_LOCK returned: %d", lockName, answer)
+		}
+		return nil
+	}
+
+	// Acquire the lock or return an error immediately
+	logger.Infof("attempting to acquire metadata lock: %s", lockName)
+	if err = getLock(); err != nil {
+		return nil, err
+	}
+	logger.Infof("acquired metadata lock: %s", lockName)
+
+	// Setup background refresh runner
+	ctx, mdl.cancel = context.WithCancel(ctx)
+	mdl.closeCh = make(chan error)
+	go func() {
+		ticker := time.NewTicker(mdl.refreshInterval)
+		defer ticker.Stop()
+		for {
+			select {
+			case <-ctx.Done():
+				// Close the dedicated connection to release the lock
+				logger.Warnf("releasing metadata lock: %s", lockName)
+				mdl.closeCh <- dbConn.Close()
+				return
+			case <-ticker.C:
+				if err = getLock(); err != nil {
+					logger.Errorf("could not refresh metadata lock: %s", err)
+				} else {
+					logger.Infof("refreshed metadata lock: %s", lockName)
+				}
+			}
+		}
+	}()
+
+	return mdl, nil
+}
+
+func (m *MetadataLock) Close() error {
+	// Cancel the background refresh runner
+	m.cancel()
+
+	// Wait for the dedicated connection to be closed and return its error (if any)
+	return <-m.closeCh
+}
diff --git a/pkg/dbconn/metadatalock_test.go b/pkg/dbconn/metadatalock_test.go
new file mode 100644
index 0000000..f3b6cb6
--- /dev/null
+++ b/pkg/dbconn/metadatalock_test.go
@@ -0,0 +1,87 @@
+package dbconn
+
+import (
+	"context"
+	"testing"
+	"time"
+
+	"github.com/cashapp/spirit/pkg/testutils"
+	"github.com/sirupsen/logrus"
+	"github.com/stretchr/testify/assert"
+)
+
+func TestMetadataLock(t *testing.T) {
+	lockName := "test"
+	logger := logrus.New()
+	mdl, err := NewMetadataLock(context.Background(), testutils.DSN(), lockName, logger)
+	assert.NoError(t, err)
+	assert.NotNil(t, mdl)
+
+	// Confirm a second lock cannot be acquired
+	_, err = NewMetadataLock(context.Background(), testutils.DSN(), lockName, logger)
+	assert.ErrorContains(t, err, "lock is held by another connection")
+
+	// Close the original mdl
+	assert.NoError(t, mdl.Close())
+
+	// Confirm a new lock can be acquired
+	mdl3, err := NewMetadataLock(context.Background(), testutils.DSN(), lockName, logger)
+	assert.NoError(t, err)
+	assert.NoError(t, mdl3.Close())
+}
+
+func TestMetadataLockContextCancel(t *testing.T) {
+	lockName := "test-cancel"
+
+	logger := logrus.New()
+	ctx, cancel := context.WithCancel(context.Background())
+	mdl, err := NewMetadataLock(ctx, testutils.DSN(), lockName, logger)
+	assert.NoError(t, err)
+	assert.NotNil(t, mdl)
+
+	// Cancel the context
+	cancel()
+
+	// Wait for the lock to be released
+	<-mdl.closeCh
+
+	// Confirm the lock is released by acquiring a new one
+	mdl2, err := NewMetadataLock(context.Background(), testutils.DSN(), lockName, logger)
+	assert.NoError(t, err)
+	assert.NotNil(t, mdl2)
+	assert.NoError(t, mdl2.Close())
+}
+
+func TestMetadataLockRefresh(t *testing.T) {
+	lockName := "test-refresh"
+	logger := logrus.New()
+	mdl, err := NewMetadataLock(context.Background(), testutils.DSN(), lockName, logger, func(mdl *MetadataLock) {
+		// override the refresh interval for faster testing
+		mdl.refreshInterval = 2 * time.Second
+	})
+	assert.NoError(t, err)
+	assert.NotNil(t, mdl)
+
+	// wait for the refresh to happen
+	time.Sleep(5 * time.Second)
+
+	// Confirm the lock is still held
+	_, err = NewMetadataLock(context.Background(), testutils.DSN(), lockName, logger)
+	assert.ErrorContains(t, err, "lock is held by another connection")
+
+	// Close the lock
+	assert.NoError(t, mdl.Close())
+}
+
+func TestMetadataLockLength(t *testing.T) {
+	long := "thisisareallylongtablenamethisisareallylongtablenamethisisareallylongtablename"
+	empty := ""
+
+	logger := logrus.New()
+
+	_, err := NewMetadataLock(context.Background(), testutils.DSN(), long, logger)
+	assert.ErrorContains(t, err, "metadata lock name is too long")
+
+	_, err = NewMetadataLock(context.Background(), testutils.DSN(), empty, logger)
+	assert.ErrorContains(t, err, "metadata lock name is empty")
+}
diff --git a/pkg/migration/runner.go b/pkg/migration/runner.go
index b9d209f..a81cf8c 100644
--- a/pkg/migration/runner.go
+++ b/pkg/migration/runner.go
@@ -83,6 +83,7 @@ type Runner struct {
 	table           *table.TableInfo
 	newTable        *table.TableInfo
 	checkpointTable *table.TableInfo
+	metadataLock    *dbconn.MetadataLock
 
 	currentState migrationState // must use atomic to get/set
 	replClient   *repl.Client   // feed contains all binlog subscription activity.
@@ -190,6 +191,12 @@ func (r *Runner) Run(originalCtx context.Context) error {
 		return err
 	}
 
+	// Take a metadata lock to prevent other migrations from running concurrently.
+	r.metadataLock, err = dbconn.NewMetadataLock(ctx, r.dsn(), fmt.Sprintf("spirit_%s_%s", r.migration.Database, r.migration.Table), r.logger)
+	if err != nil {
+		return err
+	}
+
 	// Get Table Info
 	r.table = table.NewTableInfo(r.db, r.migration.Database, r.migration.Table)
 	if err := r.table.SetInfo(ctx); err != nil {
@@ -702,6 +709,12 @@ func (r *Runner) Close() error {
 			return err
 		}
 	}
+	if r.metadataLock != nil {
+		err := r.metadataLock.Close()
+		if err != nil {
+			return err
+		}
+	}
 	return nil
 }
 
diff --git a/pkg/migration/runner_test.go b/pkg/migration/runner_test.go
index a43ceb1..75f3c90 100644
--- a/pkg/migration/runner_test.go
+++ b/pkg/migration/runner_test.go
@@ -2853,6 +2853,7 @@ func TestIndexVisibility(t *testing.T) {
 
 	assert.NoError(t, err)
 	assert.True(t, m.usedInplaceDDL) // expected to count as safe.
+	assert.NoError(t, m.Close())
 
 	// Test again with visible
 	m, err = NewRunner(&Migration{
@@ -2868,6 +2869,7 @@
 	err = m.Run(context.Background())
 	assert.NoError(t, err)
 	assert.True(t, m.usedInplaceDDL) // expected to count as safe.
+	assert.NoError(t, m.Close())
 
 	// Test again but include an unsafe INPLACE change at the same time.
 	// This won't work by default.
@@ -2899,6 +2901,7 @@
 	assert.NoError(t, err)
 	err = m.Run(context.Background())
 	assert.NoError(t, err)
+	assert.NoError(t, m.Close())
 
 	// But even when force inplace is set, we won't be able to do an operation
 	// that requires a full copy. This is important because invisible should
@@ -2918,3 +2921,63 @@
 	assert.Error(t, err)
 	assert.NoError(t, m.Close()) // it's errored, we don't need to try again. We can close.
 }
+
+func TestPreventConcurrentRuns(t *testing.T) {
+	sentinelWaitLimit = 10 * time.Second
+
+	tableName := `prevent_concurrent_runs`
+	sentinelTableName := fmt.Sprintf("_%s_sentinel", tableName)
+	checkpointTableName := fmt.Sprintf("_%s_chkpnt", tableName)
+
+	dropStmt := `DROP TABLE IF EXISTS %s`
+	testutils.RunSQL(t, fmt.Sprintf(dropStmt, tableName))
+	testutils.RunSQL(t, fmt.Sprintf(dropStmt, sentinelTableName))
+	testutils.RunSQL(t, fmt.Sprintf(dropStmt, checkpointTableName))
+
+	table := fmt.Sprintf(`CREATE TABLE %s (id bigint unsigned not null auto_increment, primary key(id))`, tableName)
+
+	testutils.RunSQL(t, table)
+	testutils.RunSQL(t, fmt.Sprintf("insert into %s () values (),(),(),(),(),(),(),(),(),()", tableName))
+	testutils.RunSQL(t, fmt.Sprintf("insert into %s (id) select null from %s a, %s b, %s c limit 1000", tableName, tableName, tableName, tableName))
+
+	cfg, err := mysql.ParseDSN(testutils.DSN())
+	assert.NoError(t, err)
+	m, err := NewRunner(&Migration{
+		Host:                 cfg.Addr,
+		Username:             cfg.User,
+		Password:             cfg.Passwd,
+		Database:             cfg.DBName,
+		Threads:              4,
+		Table:                tableName,
+		Alter:                "ENGINE=InnoDB",
+		SkipDropAfterCutover: false,
+		DeferCutOver:         true,
+	})
+	assert.NoError(t, err)
+	wg := sync.WaitGroup{}
+	wg.Add(1)
+	go func() {
+		defer wg.Done()
+		err = m.Run(context.Background())
+		assert.Error(t, err)
+		assert.ErrorContains(t, err, "timed out waiting for sentinel table to be dropped")
+	}()
+
+	// While it's waiting, start another run and confirm it fails.
+	time.Sleep(1 * time.Second)
+	m2, err := NewRunner(&Migration{
+		Host:                 cfg.Addr,
+		Username:             cfg.User,
+		Password:             cfg.Passwd,
+		Database:             cfg.DBName,
+		Threads:              4,
+		Table:                tableName,
+		Alter:                "ENGINE=InnoDB",
+		SkipDropAfterCutover: false,
+		DeferCutOver:         false,
+	})
+	assert.NoError(t, err)
+	err = m2.Run(context.Background())
+	assert.Error(t, err)
+	assert.ErrorContains(t, err, "could not acquire metadata lock")
+}
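
A minimal usage sketch of the new dbconn.MetadataLock API, separate from the patch above and following the same pattern the runner uses in Run and Close. The DSN string and lock name here are placeholder assumptions; logrus.New() is used as the logger the same way the tests above rely on it satisfying loggers.Advanced.

	package main

	import (
		"context"
		"log"

		"github.com/cashapp/spirit/pkg/dbconn"
		"github.com/sirupsen/logrus"
	)

	func main() {
		// Placeholder values for illustration only.
		dsn := "msandbox:msandbox@tcp(127.0.0.1:3306)/test"
		lockName := "spirit_test_t1" // mirrors the runner's "spirit_<db>_<table>" naming

		// NewMetadataLock opens a dedicated single-connection pool and runs
		// SELECT GET_LOCK(lockName, 0); it fails immediately if another
		// connection already holds the lock.
		mdl, err := dbconn.NewMetadataLock(context.Background(), dsn, lockName, logrus.New())
		if err != nil {
			log.Fatalf("another spirit migration may already be running: %v", err)
		}

		// A background goroutine re-acquires the lock on the same connection
		// every refreshInterval until Close releases the connection.
		defer func() {
			if cerr := mdl.Close(); cerr != nil {
				log.Printf("error releasing metadata lock: %v", cerr)
			}
		}()

		// ... do work that must not run concurrently with another migration ...
	}

Holding the lock on a dedicated, single-connection pool is the point of the design: GET_LOCK is session-scoped, so closing that one connection is what releases the lock.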