Skip to content

Commit 263cc2b

Browse files
committed
Add Driver() API to wrap driver.Driver
1 parent 1793407 commit 263cc2b

5 files changed

Lines changed: 20 additions & 15 deletions

File tree

README.md

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ import (
3535
"database/sql"
3636

3737
redis "github.com/go-redis/redis/v7"
38-
"github.com/ngrok/sqlmw"
38+
"github.com/jackc/pgx/v4/stdlib"
3939
"github.com/prashanthpai/sqlcache"
4040
)
4141

@@ -51,8 +51,8 @@ func main() {
5151
})
5252
...
5353

54-
// wrap pgx driver with the interceptor and register it
55-
sql.Register("pgx-with-cache", sqlmw.Driver(stdlib.GetDefaultDriver(), interceptor))
54+
// wrap pgx driver with cache interceptor and register it
55+
sql.Register("pgx-sqlcache", interceptor.Driver(stdlib.GetDefaultDriver()))
5656

5757
// open the database using the wrapped driver
5858
db, err := sql.Open("pgx-with-cache", dsn)

doc.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ Usage:
1111
1212
redis "github.com/go-redis/redis/v7"
1313
"github.com/prashanthpai/sqlcache"
14-
"github.com/ngrok/sqlmw"
14+
"github.com/jackc/pgx/v4/stdlib"
1515
)
1616
1717
func main() {
@@ -27,7 +27,7 @@ Usage:
2727
...
2828
2929
// wrap pgx driver with the interceptor and register it
30-
sql.Register("pgx-with-cache", sqlmw.Driver(stdlib.GetDefaultDriver(), interceptor))
30+
sql.Register("pgx-sqlcache", interceptor.Driver(stdlib.GetDefaultDriver()))
3131
3232
// open the database using the wrapped driver
3333
db, err := sql.Open("pgx-with-cache", dsn)

example/main.go

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"github.com/dgraph-io/ristretto"
1414
redis "github.com/go-redis/redis/v7"
1515
"github.com/jackc/pgx/v4/stdlib"
16-
"github.com/ngrok/sqlmw"
1716
)
1817

1918
const (
@@ -71,7 +70,7 @@ func main() {
7170
}()
7271

7372
// install the wrapper which wraps pgx driver
74-
sql.Register("pgx-sqlcache", sqlmw.Driver(stdlib.GetDefaultDriver(), interceptor))
73+
sql.Register("pgx-sqlcache", interceptor.Driver(stdlib.GetDefaultDriver()))
7574

7675
if err := run(); err != nil {
7776
log.Fatalf("run() failed: %v", err)

interceptor.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,13 @@ func NewInterceptor(config *Config) (*Interceptor, error) {
6464
}, nil
6565
}
6666

67+
// Driver returns the supplied driver.Driver with a new object that has
68+
// all of its calls intercepted by the sqlcache.Interceptor. Any DB call
69+
// without a context passed will not be intercepted.
70+
func (i *Interceptor) Driver(d driver.Driver) driver.Driver {
71+
return sqlmw.Driver(d, i)
72+
}
73+
6774
// Enable enables the interceptor. Interceptor instance is enabled by default
6875
// on creation.
6976
func (i *Interceptor) Enable() {

interceptor_test.go

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,6 @@ import (
1313
"github.com/prashanthpai/sqlcache/mocks"
1414

1515
sqlmock "github.com/DATA-DOG/go-sqlmock"
16-
"github.com/ngrok/sqlmw"
1716
"github.com/stretchr/testify/mock"
1817
"github.com/stretchr/testify/require"
1918
)
@@ -131,7 +130,7 @@ func TestAttrs(t *testing.T) {
131130
})
132131

133132
driverName := fmt.Sprintf("mockdriver:%s", t.Name())
134-
sql.Register(driverName, sqlmw.Driver(mockDB.Driver(), ic))
133+
sql.Register(driverName, ic.Driver(mockDB.Driver()))
135134

136135
db, err := sql.Open(driverName, dsn)
137136
assert.Nil(err)
@@ -159,7 +158,7 @@ func TestCacheMiss(t *testing.T) {
159158
})
160159

161160
driverName := fmt.Sprintf("mockdriver:%s", t.Name())
162-
sql.Register(driverName, sqlmw.Driver(mockDB.Driver(), ic))
161+
sql.Register(driverName, ic.Driver(mockDB.Driver()))
163162

164163
db, err := sql.Open(driverName, dsn)
165164
assert.Nil(err)
@@ -219,7 +218,7 @@ func TestCacheHit(t *testing.T) {
219218
})
220219

221220
driverName := fmt.Sprintf("mockdriver:%s", t.Name())
222-
sql.Register(driverName, sqlmw.Driver(mockDB.Driver(), ic))
221+
sql.Register(driverName, ic.Driver(mockDB.Driver()))
223222

224223
db, err := sql.Open(driverName, dsn)
225224
assert.Nil(err)
@@ -263,7 +262,7 @@ func TestDisabled(t *testing.T) {
263262
})
264263

265264
driverName := fmt.Sprintf("mockdriver:%s", t.Name())
266-
sql.Register(driverName, sqlmw.Driver(mockDB.Driver(), ic))
265+
sql.Register(driverName, ic.Driver(mockDB.Driver()))
267266

268267
db, err := sql.Open(driverName, dsn)
269268
assert.Nil(err)
@@ -314,7 +313,7 @@ func TestMaxRows(t *testing.T) {
314313
})
315314

316315
driverName := fmt.Sprintf("mockdriver:%s", t.Name())
317-
sql.Register(driverName, sqlmw.Driver(mockDB.Driver(), ic))
316+
sql.Register(driverName, ic.Driver(mockDB.Driver()))
318317

319318
db, err := sql.Open(driverName, dsn)
320319
assert.Nil(err)
@@ -363,7 +362,7 @@ func TestHashFuncErr(t *testing.T) {
363362
})
364363

365364
driverName := fmt.Sprintf("mockdriver:%s", t.Name())
366-
sql.Register(driverName, sqlmw.Driver(mockDB.Driver(), ic))
365+
sql.Register(driverName, ic.Driver(mockDB.Driver()))
367366

368367
db, err := sql.Open(driverName, dsn)
369368
assert.Nil(err)
@@ -412,7 +411,7 @@ func TestCacheSetErr(t *testing.T) {
412411
})
413412

414413
driverName := fmt.Sprintf("mockdriver:%s", t.Name())
415-
sql.Register(driverName, sqlmw.Driver(mockDB.Driver(), ic))
414+
sql.Register(driverName, ic.Driver(mockDB.Driver()))
416415

417416
db, err := sql.Open(driverName, dsn)
418417
assert.Nil(err)

0 commit comments

Comments
 (0)