Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions server/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -1091,6 +1091,18 @@ func (c *clientConn) serve() error {
c.sendError("FATAL", "XX000", fmt.Sprintf("failed to initialize session database metadata: %v", err))
return err
}
if c.server.cfg.FilePersistence && !duckLakeAttached && !icebergAttached {
if _, err := c.executor.ExecContext(initCtx, "USE "+sqlcore.QuoteIdentifier(c.username)); err != nil {
Comment on lines +1094 to +1095

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Badge Guard file catalog selection on DataDir

When FilePersistence is true but DataDir is empty, DuckDBDSN/openBaseDB intentionally fall back to :memory: (covered by the existing empty-DataDir fallback test), so there is no catalog named after c.username. This new branch still runs USE <username> for that configuration and will turn every standalone connection into a FATAL failed to select file-backed catalog; gate this on an actual file-backed database, e.g. cfg.DataDir != "", so the in-memory fallback continues to work.

Useful? React with 👍 / 👎.

initCancel()
c.sendError("FATAL", "XX000", fmt.Sprintf("failed to select file-backed catalog: %v", err))
return err
}
if _, err := c.executor.ExecContext(initCtx, "SET search_path = 'main,memory.main'"); err != nil {
initCancel()
c.sendError("FATAL", "XX000", fmt.Sprintf("failed to initialize file-backed search_path: %v", err))
return err
}
}
initCancel()
// Keep c.database aligned with the real catalog so observability surfaces
// agree with current_database(); record the physical catalog so the
Expand Down
126 changes: 126 additions & 0 deletions server/file_persistence_session_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package server

import (
"database/sql"
"fmt"
"net"
"path/filepath"
"testing"
"time"

_ "github.com/duckdb/duckdb-go/v2"
_ "github.com/lib/pq"
)

func TestFilePersistenceStandaloneConnectionInitializesSession(t *testing.T) {
tmpDir := t.TempDir()
certFile := filepath.Join(tmpDir, "server.crt")
keyFile := filepath.Join(tmpDir, "server.key")
if err := EnsureCertificates(certFile, keyFile); err != nil {
t.Fatalf("generate certs: %v", err)
}

port := freeTCPPort(t)
srv, err := New(Config{
Host: "127.0.0.1",
Port: port,
DataDir: tmpDir,
FilePersistence: true,
TLSCertFile: certFile,
TLSKeyFile: keyFile,
Users: map[string]string{"testuser": "testpass"},
RateLimit: RateLimitConfig{
MaxConnections: 100,
},
})
if err != nil {
t.Fatalf("create server: %v", err)
}
t.Cleanup(func() { _ = srv.Close() })

errCh := make(chan error, 1)
go func() { errCh <- srv.ListenAndServe() }()

waitForTCP(t, port)

connStr := fmt.Sprintf("host=127.0.0.1 port=%d user=testuser password=testpass dbname=test sslmode=require connect_timeout=5", port)
db, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatalf("open postgres connection: %v", err)
}

if err := db.Ping(); err != nil {
_ = db.Close()
t.Fatalf("ping file-persistence session: %v", err)
}

var currentDB string
if err := db.QueryRow("SELECT current_database()").Scan(&currentDB); err != nil {
_ = db.Close()
t.Fatalf("query current_database(): %v", err)
}
if currentDB != "test" {
_ = db.Close()
t.Fatalf("current_database() = %q, want %q", currentDB, "test")
}

if _, err := db.Exec("CREATE TABLE fp_session_probe (id INTEGER)"); err != nil {
_ = db.Close()
t.Fatalf("create persisted table: %v", err)
}
if _, err := db.Exec("INSERT INTO fp_session_probe VALUES (42)"); err != nil {
_ = db.Close()
t.Fatalf("insert persisted row: %v", err)
}
_ = db.Close()

db2, err := sql.Open("postgres", connStr)
if err != nil {
t.Fatalf("open second postgres connection: %v", err)
}

var id int
if err := db2.QueryRow("SELECT id FROM fp_session_probe").Scan(&id); err != nil {
_ = db2.Close()
t.Fatalf("query persisted row on second connection: %v", err)
}
if id != 42 {
_ = db2.Close()
t.Fatalf("persisted row id = %d, want 42", id)
}
_ = db2.Close()

_ = srv.Close()
if err := <-errCh; err != nil {
t.Fatalf("server returned error: %v", err)
}
}

func freeTCPPort(t *testing.T) int {
t.Helper()

ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
t.Fatalf("listen on free port: %v", err)
}
defer func() { _ = ln.Close() }()
return ln.Addr().(*net.TCPAddr).Port
}

func waitForTCP(t *testing.T, port int) {
t.Helper()

deadline := time.Now().Add(2 * time.Second)
var lastErr error
for time.Now().Before(deadline) {
conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 50*time.Millisecond)
if err == nil {
_ = conn.Close()
return
}
lastErr = err
time.Sleep(20 * time.Millisecond)
}

t.Fatalf("server did not listen on port %d: %v", port, lastErr)
}
24 changes: 23 additions & 1 deletion server/sessionmeta/sessionmeta.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,18 @@ func InitSessionDatabaseMetadata(ctx context.Context, executor sqlcore.QueryExec
return fmt.Errorf("detect ducklake attachment: %w", err)
}

attachedMemoryCatalog := false
if _, err := executor.ExecContext(ctx, "USE memory"); err != nil {
return fmt.Errorf("switch to memory catalog: %w", err)
if !isMissingMemoryCatalogError(err) {
return fmt.Errorf("switch to memory catalog: %w", err)
}
if _, attachErr := executor.ExecContext(ctx, "ATTACH ':memory:' AS memory"); attachErr != nil {
return fmt.Errorf("attach memory catalog: %w", attachErr)
}
attachedMemoryCatalog = true
if _, useErr := executor.ExecContext(ctx, "USE memory"); useErr != nil {
return fmt.Errorf("switch to attached memory catalog: %w", useErr)
}
}
defer func() {
// Leave the session in a real catalog (we entered `memory` to install the
Expand All @@ -59,6 +69,9 @@ func InitSessionDatabaseMetadata(ctx context.Context, executor sqlcore.QueryExec
if duckLakeAttached {
_, _ = executor.ExecContext(context.Background(), "USE ducklake")
_, _ = executor.ExecContext(context.Background(), "SET search_path = 'main,memory.main'")
} else if attachedMemoryCatalog && !strings.EqualFold(catalog, "memory") {
_, _ = executor.ExecContext(context.Background(), "USE "+sqlcore.QuoteIdentifier(catalog))
_, _ = executor.ExecContext(context.Background(), "SET search_path = 'main,memory.main'")
}
}()

Expand All @@ -69,6 +82,15 @@ func InitSessionDatabaseMetadata(ctx context.Context, executor sqlcore.QueryExec
return nil
}

func isMissingMemoryCatalogError(err error) bool {
if err == nil {
return false
}
msg := err.Error()
return strings.Contains(msg, `No catalog + schema named "memory" found`) ||
strings.Contains(msg, `Catalog with name "memory" does not exist`)
}

func HasAttachedCatalog(ctx context.Context, executor sqlcore.QueryExecutor, catalog string) (bool, error) {
query := fmt.Sprintf(
"SELECT COUNT(*) FROM duckdb_databases() WHERE database_name = %s",
Expand Down
93 changes: 93 additions & 0 deletions server/sessionmeta/sessionmeta_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"context"
"database/sql"
"errors"
"path/filepath"
"strings"
"testing"

Expand Down Expand Up @@ -79,6 +80,68 @@ func (r *singleIntRow) RowsAffected() (int64, error) { return 0, nil }
func (r *singleIntRow) LastInsertId() (int64, error) { return 0, nil }
func (r *singleIntRow) LastProfilingOutput() string { return "" }

type duckDBTestExecutor struct {
db *sql.DB
}

func (e *duckDBTestExecutor) QueryContext(ctx context.Context, query string, args ...any) (sqlcore.RowSet, error) {
rows, err := e.db.QueryContext(ctx, query, args...)
if err != nil {
return nil, err
}
return &duckDBTestRows{rows: rows}, nil
}

func (e *duckDBTestExecutor) ExecContext(ctx context.Context, query string, args ...any) (sqlcore.ExecResult, error) {
return e.db.ExecContext(ctx, query, args...)
}

func (e *duckDBTestExecutor) Query(query string, args ...any) (sqlcore.RowSet, error) {
rows, err := e.db.Query(query, args...)
if err != nil {
return nil, err
}
return &duckDBTestRows{rows: rows}, nil
}

func (e *duckDBTestExecutor) Exec(query string, args ...any) (sqlcore.ExecResult, error) {
return e.db.Exec(query, args...)
}

func (e *duckDBTestExecutor) ConnContext(ctx context.Context) (sqlcore.RawConn, error) {
return e.db.Conn(ctx)
}

func (e *duckDBTestExecutor) PingContext(ctx context.Context) error { return e.db.PingContext(ctx) }
func (e *duckDBTestExecutor) Close() error { return e.db.Close() }
func (e *duckDBTestExecutor) LastProfilingOutput() string { return "" }

type duckDBTestRows struct {
rows *sql.Rows
}

func (r *duckDBTestRows) Columns() ([]string, error) { return r.rows.Columns() }

func (r *duckDBTestRows) ColumnTypes() ([]sqlcore.ColumnTyper, error) {
colTypes, err := r.rows.ColumnTypes()
if err != nil {
return nil, err
}
result := make([]sqlcore.ColumnTyper, len(colTypes))
for i, ct := range colTypes {
result[i] = ct
}
return result, nil
}

func (r *duckDBTestRows) Next() bool { return r.rows.Next() }
func (r *duckDBTestRows) Scan(dest ...any) error { return r.rows.Scan(dest...) }
func (r *duckDBTestRows) Close() error { return r.rows.Close() }
func (r *duckDBTestRows) Err() error { return r.rows.Err() }
func (r *duckDBTestRows) RowsAffected() (int64, error) { return 0, nil }
func (r *duckDBTestRows) LastInsertId() (int64, error) { return 0, nil }
func (r *duckDBTestRows) LastProfilingOutput() string { return "" }

func TestInitSessionDatabaseMetadataBatchesPerSessionStatements(t *testing.T) {
exec := &countingExecutor{
queryRows: &singleIntRow{v: 1}, // pretend ducklake is attached for HasAttachedCatalog
Expand All @@ -102,6 +165,36 @@ func TestInitSessionDatabaseMetadataBatchesPerSessionStatements(t *testing.T) {
}
}

func TestInitSessionDatabaseMetadataWorksOnFileBackedDatabase(t *testing.T) {
dbPath := filepath.Join(t.TempDir(), "alice.duckdb")
db, err := sql.Open("duckdb", dbPath+"?allow_unsigned_extensions=true")
if err != nil {
t.Fatalf("open file-backed duckdb: %v", err)
}
defer func() { _ = db.Close() }()

executor := &duckDBTestExecutor{db: db}
if err := InitSessionDatabaseMetadata(context.Background(), executor, "alice"); err != nil {
t.Fatalf("InitSessionDatabaseMetadata on file-backed DB failed: %v", err)
}

var currentDB string
if err := db.QueryRow("SELECT current_database()").Scan(&currentDB); err != nil {
t.Fatalf("query current_database(): %v", err)
}
if currentDB != "alice" {
t.Fatalf("current_database() = %q, want %q", currentDB, "alice")
}

var datname string
if err := db.QueryRow("SELECT datname FROM pg_database WHERE datname = current_database()").Scan(&datname); err != nil {
t.Fatalf("query pg_database: %v", err)
}
if datname != "alice" {
t.Fatalf("pg_database datname = %q, want %q", datname, "alice")
}
}

func TestBuildSessionMetadataSQLContainsAllExpectedStatements(t *testing.T) {
got := buildSessionMetadataSQL("analytics")

Expand Down