Skip to content

Commit 61ee1bf

Browse files
authored
Merge pull request #247 from sarajmunjal/saraj/ctx
Add context support to Exec methods
2 parents 7084132 + fda37a1 commit 61ee1bf

2 files changed

Lines changed: 87 additions & 8 deletions

File tree

migrate.go

Lines changed: 49 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package migrate
22

33
import (
44
"bytes"
5+
"context"
56
"database/sql"
67
"errors"
78
"fmt"
@@ -429,12 +430,24 @@ type SqlExecutor interface {
429430
//
430431
// Returns the number of applied migrations.
431432
func Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
432-
return ExecMax(db, dialect, m, dir, 0)
433+
return ExecMaxContext(context.Background(), db, dialect, m, dir, 0)
433434
}
434435

435436
// Returns the number of applied migrations.
436437
func (ms MigrationSet) Exec(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
437-
return ms.ExecMax(db, dialect, m, dir, 0)
438+
return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, 0)
439+
}
440+
441+
// Execute a set of migrations with an input context.
442+
//
443+
// Returns the number of applied migrations.
444+
func ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
445+
return ExecMaxContext(ctx, db, dialect, m, dir, 0)
446+
}
447+
448+
// Returns the number of applied migrations.
449+
func (ms MigrationSet) ExecContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection) (int, error) {
450+
return ms.ExecMaxContext(ctx, db, dialect, m, dir, 0)
438451
}
439452

440453
// Execute a set of migrations
@@ -446,50 +459,78 @@ func ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirecti
446459
return migSet.ExecMax(db, dialect, m, dir, max)
447460
}
448461

462+
// Execute a set of migrations with an input context.
463+
//
464+
// Will apply at most `max` migrations. Pass 0 for no limit (or use Exec).
465+
//
466+
// Returns the number of applied migrations.
467+
func ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
468+
return migSet.ExecMaxContext(ctx, db, dialect, m, dir, max)
469+
}
470+
449471
// Execute a set of migrations
450472
//
451473
// Will apply at the target `version` of migration. Cannot be a negative value.
452474
//
453475
// Returns the number of applied migrations.
454476
func ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
477+
return ExecVersionContext(context.Background(), db, dialect, m, dir, version)
478+
}
479+
480+
// Execute a set of migrations with an input context.
481+
//
482+
// Will apply at the target `version` of migration. Cannot be a negative value.
483+
//
484+
// Returns the number of applied migrations.
485+
func ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
455486
if version < 0 {
456487
return 0, fmt.Errorf("target version %d should not be negative", version)
457488
}
458-
return migSet.ExecVersion(db, dialect, m, dir, version)
489+
return migSet.ExecVersionContext(ctx, db, dialect, m, dir, version)
459490
}
460491

461492
// Returns the number of applied migrations.
462493
func (ms MigrationSet) ExecMax(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
494+
return ms.ExecMaxContext(context.Background(), db, dialect, m, dir, max)
495+
}
496+
497+
// Returns the number of applied migrations, but applies with an input context.
498+
func (ms MigrationSet) ExecMaxContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, max int) (int, error) {
463499
migrations, dbMap, err := ms.PlanMigration(db, dialect, m, dir, max)
464500
if err != nil {
465501
return 0, err
466502
}
467-
return ms.applyMigrations(dir, migrations, dbMap)
503+
return ms.applyMigrations(ctx, dir, migrations, dbMap)
468504
}
469505

470506
// Returns the number of applied migrations.
471507
func (ms MigrationSet) ExecVersion(db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
508+
return ms.ExecVersionContext(context.Background(), db, dialect, m, dir, version)
509+
}
510+
511+
func (ms MigrationSet) ExecVersionContext(ctx context.Context, db *sql.DB, dialect string, m MigrationSource, dir MigrationDirection, version int64) (int, error) {
472512
migrations, dbMap, err := ms.PlanMigrationToVersion(db, dialect, m, dir, version)
473513
if err != nil {
474514
return 0, err
475515
}
476-
return ms.applyMigrations(dir, migrations, dbMap)
516+
return ms.applyMigrations(ctx, dir, migrations, dbMap)
477517
}
478518

479519
// Applies the planned migrations and returns the number of applied migrations.
480-
func (MigrationSet) applyMigrations(dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) {
520+
func (MigrationSet) applyMigrations(ctx context.Context, dir MigrationDirection, migrations []*PlannedMigration, dbMap *gorp.DbMap) (int, error) {
481521
applied := 0
482522
for _, migration := range migrations {
483523
var executor SqlExecutor
484524
var err error
485525

486526
if migration.DisableTransaction {
487-
executor = dbMap
527+
executor = dbMap.WithContext(ctx)
488528
} else {
489-
executor, err = dbMap.Begin()
529+
e, err := dbMap.Begin()
490530
if err != nil {
491531
return applied, newTxError(migration, err)
492532
}
533+
executor = e.WithContext(ctx)
493534
}
494535

495536
for _, stmt := range migration.Queries {

migrate_test.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
package migrate
22

33
import (
4+
"context"
45
"database/sql"
56
"net/http"
7+
"time"
68

79
"github.com/go-gorp/gorp/v3"
810
"github.com/gobuffalo/packr/v2"
@@ -757,3 +759,39 @@ func (s *SqliteMigrateSuite) TestGetMigrationDbMapWithDisableCreateTable(c *C) {
757759
_, err := migSet.getMigrationDbMap(s.Db, "postgres")
758760
c.Assert(err, IsNil)
759761
}
762+
763+
func (s *SqliteMigrateSuite) TestContextTimeout(c *C) {
764+
// This statement will run for a long time: 1,000,000 iterations of the fibonacci sequence
765+
fibonacciLoopStmt := `WITH RECURSIVE
766+
fibo (curr, next)
767+
AS
768+
( SELECT 1,1
769+
UNION ALL
770+
SELECT next, curr+next FROM fibo
771+
LIMIT 1000000 )
772+
SELECT group_concat(curr) FROM fibo;
773+
`
774+
migrations := &MemoryMigrationSource{
775+
Migrations: []*Migration{
776+
sqliteMigrations[0],
777+
sqliteMigrations[1],
778+
{
779+
Id: "125",
780+
Up: []string{fibonacciLoopStmt},
781+
Down: []string{}, // Not important here
782+
},
783+
{
784+
Id: "125",
785+
Up: []string{"INSERT INTO people (id, first_name) VALUES (1, 'Test')", "SELECT fail"},
786+
Down: []string{}, // Not important here
787+
},
788+
},
789+
}
790+
791+
// Should never run the insert
792+
ctx, cancelFunc := context.WithTimeout(context.Background(), 10*time.Millisecond)
793+
defer cancelFunc()
794+
n, err := ExecContext(ctx, s.Db, "sqlite3", migrations, Up)
795+
c.Assert(err, Not(IsNil))
796+
c.Assert(n, Equals, 2)
797+
}

0 commit comments

Comments
 (0)