Skip to content

Commit

Permalink
Database locking
Browse files Browse the repository at this point in the history
* Added driver.Lockable interface, which can be optionally implmented
  in order to make it safe to use concurrently.
  • Loading branch information
josephbuchma committed Dec 27, 2017
1 parent 01e24aa commit 0bb43df
Show file tree
Hide file tree
Showing 4 changed files with 198 additions and 82 deletions.
25 changes: 25 additions & 0 deletions driver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,31 @@ type Driver interface {
Execute(statement string) error
}

// Lockable represents driver that supports database locking.
// Implement if possible to make it safe to run migrations concurrently.
// NOTE: Probably better to move into Driver interface to make sure it's not
// dismissed when locking is possible to implement.
type Lockable interface {
Lock() error
Unlock() error
}

// Lock calls Lock method if driver implements Lockable
func Lock(d Driver) error {
if d, ok := d.(Lockable); ok {
return d.Lock()
}
return nil
}

// Unlock calls Unlock method if driver implements Lockable
func Unlock(d Driver) error {
if d, ok := d.(Lockable); ok {
return d.Unlock()
}
return nil
}

// New returns Driver and calls Initialize on it.
func New(url string) (Driver, error) {
u, err := neturl.Parse(url)
Expand Down
184 changes: 129 additions & 55 deletions migrate.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ import (
type Handle struct {
drv driver.Driver
migrationsPath string
locked bool
fatalErr error
}

// Open Handle instance
Expand All @@ -31,96 +33,119 @@ func Open(url, migrationsPath string) (*Handle, error) {

// Up applies all available migrations.
func (m *Handle) Up(ctx context.Context) error {
files, versions, err := m.readMigrationFilesAndGetVersions()
if err != nil {
return err
}
applyMigrationFiles, err := files.Pending(versions)
if err != nil {
return err
}
for _, f := range applyMigrationFiles {
err = m.drvMigrate(ctx, f)
return m.locking(ctx, func() error {
files, versions, err := m.readMigrationFilesAndGetVersions()
if err != nil {
return err
}
}
return nil
applyMigrationFiles, err := files.Pending(versions)
if err != nil {
return err
}
for _, f := range applyMigrationFiles {
err = m.drvMigrate(ctx, f)
if err != nil {
return err
}
}
return nil
})
}

// Down rolls back all migrations.
func (m *Handle) Down(ctx context.Context) error {
files, versions, err := m.readMigrationFilesAndGetVersions()
if err != nil {
return err
}
applyMigrationFiles, err := files.Applied(versions)
if err != nil {
return err
}

for _, f := range applyMigrationFiles {
err = m.drvMigrate(ctx, f)
return m.locking(ctx, func() error {
files, versions, err := m.readMigrationFilesAndGetVersions()
if err != nil {
return err
}
applyMigrationFiles, err := files.Applied(versions)
if err != nil {
break
return err
}
}
return err

for _, f := range applyMigrationFiles {
err = m.drvMigrate(ctx, f)
if err != nil {
break
}
}
return err
})
}

// Redo rolls back the most recently applied migration, then runs it again.
func (m *Handle) Redo(ctx context.Context) error {
err := m.Migrate(ctx, -1)
if err != nil {
return err
}
return m.Migrate(ctx, +1)
return m.locking(ctx, func() error {
err := m.Migrate(ctx, -1)
if err != nil {
return err
}
return m.Migrate(ctx, +1)
})
}

// Reset runs the Down and Up migration function.
func (m *Handle) Reset(ctx context.Context) error {
err := m.Down(ctx)
if err != nil {
return err
}
return m.Up(ctx)
return m.locking(ctx, func() error {
err := m.Down(ctx)
if err != nil {
return err
}
return m.Up(ctx)
})
}

// Migrate applies relative +n/-n migrations.
func (m *Handle) Migrate(ctx context.Context, relativeN int) error {
// TODO: add Lock/Unlock methods to Driver interface
// for now it's not safe for concurrent execution
files, versions, err := m.readMigrationFilesAndGetVersions()
if err != nil {
return err
}

applyMigrationFiles, err := files.Relative(relativeN, versions)
if err != nil {
return err
}
return m.locking(ctx, func() error {
files, versions, err := m.readMigrationFilesAndGetVersions()
if err != nil {
return err
}

for _, f := range applyMigrationFiles {
err = m.drvMigrate(ctx, f)
applyMigrationFiles, err := files.Relative(relativeN, versions)
if err != nil {
break
return err
}
}
return err

for _, f := range applyMigrationFiles {
err = m.drvMigrate(ctx, f)
if err != nil {
break
}
}
return err
})
}

// Version returns the current migration version.
func (m *Handle) Version() (version file.Version, err error) {
func (m *Handle) Version(ctx context.Context) (version file.Version, err error) {
unlock, err := m.lock(ctx)
if err != nil {
return 0, err
}
defer unlock()
return m.drv.Version()
}

// Versions returns applied versions.
func (m *Handle) Versions() (versions file.Versions, err error) {
func (m *Handle) Versions(ctx context.Context) (versions file.Versions, err error) {
unlock, err := m.lock(ctx)
if err != nil {
return nil, err
}
defer unlock()
return m.drv.Versions()
}

// PendingMigrations returns list of pending migration files
func (m *Handle) PendingMigrations() (file.Files, error) {
func (m *Handle) PendingMigrations(ctx context.Context) (file.Files, error) {
unlock, err := m.lock(ctx)
if err != nil {
return nil, err
}
defer unlock()
files, versions, err := m.readMigrationFilesAndGetVersions()
if err != nil {
return nil, err
Expand Down Expand Up @@ -183,6 +208,55 @@ func (m *Handle) Close() error {
return m.drv.Close()
}

func drvLockChan(drv driver.Driver) <-chan error {
ret := make(chan error)
go func() {
if err := driver.Lock(drv); err != nil {
ret <- err
}
close(ret)
}()
return ret
}

func (m *Handle) lock(ctx context.Context) (unlock func(), err error) {
if m.fatalErr != nil {
return nil, m.fatalErr
}
if m.locked {
return func() {}, nil
}
select {
case err := <-drvLockChan(m.drv):
if err != nil {
return nil, err
}
case <-ctx.Done():
return nil, ctx.Err()
}
m.locked = true
return m.unlock, nil
}

func (m *Handle) unlock() {
err := driver.Unlock(m.drv)
if err == nil {
m.locked = false
return
}
m.Close()
m.fatalErr = fmt.Errorf("connection closed, this handle is no longer usable - failed to unlock database after last session: %s", err)
}

func (m *Handle) locking(ctx context.Context, f func() error) error {
unlock, err := m.lock(ctx)
if err != nil {
return err
}
defer unlock()
return f()
}

func (m *Handle) drvMigrate(ctx context.Context, f file.File) error {
select {
case <-ctx.Done():
Expand Down
42 changes: 29 additions & 13 deletions migrate_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestCreate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpdir)

file1, err := Create(driverUrl, tmpdir, "test_migration")
if err != nil {
Expand Down Expand Up @@ -64,16 +65,18 @@ func TestCreate(t *testing.T) {
if file1.Version == file2.Version {
t.Errorf("files can't same version: %d", file1.Version)
}
ensureClean(t, tmpdir, driverUrl)
}
}

func TestReset(t *testing.T) {
for _, driverUrl := range driverUrls {
t.Logf("Test driver: %s", driverUrl)
tmpdir, err := ioutil.TempDir("/", "migrate-test")
tmpdir, err := ioutil.TempDir("/tmp", "migrate-test")
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpdir)

_, err = Create(driverUrl, tmpdir, "migration1")
if err != nil {
Expand All @@ -98,10 +101,7 @@ func TestReset(t *testing.T) {
t.Fatalf("Expected version %d, got %v", file.Version, version)
}

err = Down(driverUrl, tmpdir)
if err != nil {
t.Fatal(err)
}
ensureClean(t, tmpdir, driverUrl)
}
}

Expand All @@ -112,6 +112,7 @@ func TestDown(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpdir)

Create(driverUrl, tmpdir, "migration1")
file, _ := Create(driverUrl, tmpdir, "migration2")
Expand Down Expand Up @@ -149,6 +150,7 @@ func TestUp(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpdir)

Create(driverUrl, tmpdir, "migration1")
file, _ := Create(driverUrl, tmpdir, "migration2")
Expand Down Expand Up @@ -177,10 +179,7 @@ func TestUp(t *testing.T) {
t.Fatalf("Expected version %d, got %v", file.Version, version)
}

err = Down(driverUrl, tmpdir)
if err != nil {
t.Fatal(err)
}
ensureClean(t, tmpdir, driverUrl)
}
}

Expand All @@ -191,6 +190,7 @@ func TestRedo(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpdir)

Create(driverUrl, tmpdir, "migration1")
file, _ := Create(driverUrl, tmpdir, "migration2")
Expand Down Expand Up @@ -218,9 +218,7 @@ func TestRedo(t *testing.T) {
if version != file.Version {
t.Fatalf("Expected version %d, got %v", file.Version, version)
}
if err := Down(driverUrl, tmpdir); err != nil {
t.Fatal(err)
}
ensureClean(t, tmpdir, driverUrl)
}
}

Expand All @@ -231,6 +229,7 @@ func TestMigrate(t *testing.T) {
if err != nil {
t.Fatal(err)
}
defer os.Remove(tmpdir)

file1, err := Create(driverUrl, tmpdir, "migration1")
if err != nil {
Expand Down Expand Up @@ -304,6 +303,20 @@ func TestMigrate(t *testing.T) {
t.Errorf("Expected versions to be: %v, got: %v", expectedVersions, versions)
}

ensureClean(t, tmpdir, driverUrl)
}
}

func ensureClean(t *testing.T, tmpdir, driverUrl string) {
if err := Down(driverUrl, tmpdir); err != nil {
t.Fatal(err)
}
version, err := Version(driverUrl, tmpdir)
if err != nil {
t.Fatal(err)
}
if version != 0 {
t.Fatalf("Expected version 0, got %v", version)
}
}

Expand Down Expand Up @@ -335,5 +348,8 @@ func createOldMigrationFile(url, migrationsPath string) error {
}

err = ioutil.WriteFile(path.Join(mfile.UpFile.Path, mfile.UpFile.FileName), mfile.UpFile.Content, 0644)
return err
if err != nil {
return err
}
return ioutil.WriteFile(path.Join(mfile.DownFile.Path, mfile.DownFile.FileName), mfile.DownFile.Content, 0644)
}
Loading

0 comments on commit 0bb43df

Please sign in to comment.