From c8a20ccfbeebf9997660c0b5b1f4d596c0ede772 Mon Sep 17 00:00:00 2001 From: Sueun Cho Date: Thu, 4 Jun 2026 10:57:52 -0400 Subject: [PATCH] Fix file persistence session initialization --- server/conn.go | 12 +++ server/file_persistence_session_test.go | 126 ++++++++++++++++++++++++ server/sessionmeta/sessionmeta.go | 24 ++++- server/sessionmeta/sessionmeta_test.go | 93 +++++++++++++++++ 4 files changed, 254 insertions(+), 1 deletion(-) create mode 100644 server/file_persistence_session_test.go diff --git a/server/conn.go b/server/conn.go index c94deb8d..3da10aa1 100644 --- a/server/conn.go +++ b/server/conn.go @@ -1091,6 +1091,18 @@ func (c *clientConn) serve() error { c.sendError("FATAL", "XX000", fmt.Sprintf("failed to initialize session database metadata: %v", err)) return err } + if c.server.cfg.FilePersistence && !duckLakeAttached && !icebergAttached { + if _, err := c.executor.ExecContext(initCtx, "USE "+sqlcore.QuoteIdentifier(c.username)); err != nil { + initCancel() + c.sendError("FATAL", "XX000", fmt.Sprintf("failed to select file-backed catalog: %v", err)) + return err + } + if _, err := c.executor.ExecContext(initCtx, "SET search_path = 'main,memory.main'"); err != nil { + initCancel() + c.sendError("FATAL", "XX000", fmt.Sprintf("failed to initialize file-backed search_path: %v", err)) + return err + } + } initCancel() // Keep c.database aligned with the real catalog so observability surfaces // agree with current_database(); record the physical catalog so the diff --git a/server/file_persistence_session_test.go b/server/file_persistence_session_test.go new file mode 100644 index 00000000..98ce2665 --- /dev/null +++ b/server/file_persistence_session_test.go @@ -0,0 +1,126 @@ +package server + +import ( + "database/sql" + "fmt" + "net" + "path/filepath" + "testing" + "time" + + _ "github.com/duckdb/duckdb-go/v2" + _ "github.com/lib/pq" +) + +func TestFilePersistenceStandaloneConnectionInitializesSession(t *testing.T) { + tmpDir := t.TempDir() + certFile := filepath.Join(tmpDir, "server.crt") + keyFile := filepath.Join(tmpDir, "server.key") + if err := EnsureCertificates(certFile, keyFile); err != nil { + t.Fatalf("generate certs: %v", err) + } + + port := freeTCPPort(t) + srv, err := New(Config{ + Host: "127.0.0.1", + Port: port, + DataDir: tmpDir, + FilePersistence: true, + TLSCertFile: certFile, + TLSKeyFile: keyFile, + Users: map[string]string{"testuser": "testpass"}, + RateLimit: RateLimitConfig{ + MaxConnections: 100, + }, + }) + if err != nil { + t.Fatalf("create server: %v", err) + } + t.Cleanup(func() { _ = srv.Close() }) + + errCh := make(chan error, 1) + go func() { errCh <- srv.ListenAndServe() }() + + waitForTCP(t, port) + + connStr := fmt.Sprintf("host=127.0.0.1 port=%d user=testuser password=testpass dbname=test sslmode=require connect_timeout=5", port) + db, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatalf("open postgres connection: %v", err) + } + + if err := db.Ping(); err != nil { + _ = db.Close() + t.Fatalf("ping file-persistence session: %v", err) + } + + var currentDB string + if err := db.QueryRow("SELECT current_database()").Scan(¤tDB); err != nil { + _ = db.Close() + t.Fatalf("query current_database(): %v", err) + } + if currentDB != "test" { + _ = db.Close() + t.Fatalf("current_database() = %q, want %q", currentDB, "test") + } + + if _, err := db.Exec("CREATE TABLE fp_session_probe (id INTEGER)"); err != nil { + _ = db.Close() + t.Fatalf("create persisted table: %v", err) + } + if _, err := db.Exec("INSERT INTO fp_session_probe VALUES (42)"); err != nil { + _ = db.Close() + t.Fatalf("insert persisted row: %v", err) + } + _ = db.Close() + + db2, err := sql.Open("postgres", connStr) + if err != nil { + t.Fatalf("open second postgres connection: %v", err) + } + + var id int + if err := db2.QueryRow("SELECT id FROM fp_session_probe").Scan(&id); err != nil { + _ = db2.Close() + t.Fatalf("query persisted row on second connection: %v", err) + } + if id != 42 { + _ = db2.Close() + t.Fatalf("persisted row id = %d, want 42", id) + } + _ = db2.Close() + + _ = srv.Close() + if err := <-errCh; err != nil { + t.Fatalf("server returned error: %v", err) + } +} + +func freeTCPPort(t *testing.T) int { + t.Helper() + + ln, err := net.Listen("tcp", "127.0.0.1:0") + if err != nil { + t.Fatalf("listen on free port: %v", err) + } + defer func() { _ = ln.Close() }() + return ln.Addr().(*net.TCPAddr).Port +} + +func waitForTCP(t *testing.T, port int) { + t.Helper() + + deadline := time.Now().Add(2 * time.Second) + var lastErr error + for time.Now().Before(deadline) { + conn, err := net.DialTimeout("tcp", fmt.Sprintf("127.0.0.1:%d", port), 50*time.Millisecond) + if err == nil { + _ = conn.Close() + return + } + lastErr = err + time.Sleep(20 * time.Millisecond) + } + + t.Fatalf("server did not listen on port %d: %v", port, lastErr) +} diff --git a/server/sessionmeta/sessionmeta.go b/server/sessionmeta/sessionmeta.go index 72188eb6..3f4f0953 100644 --- a/server/sessionmeta/sessionmeta.go +++ b/server/sessionmeta/sessionmeta.go @@ -47,8 +47,18 @@ func InitSessionDatabaseMetadata(ctx context.Context, executor sqlcore.QueryExec return fmt.Errorf("detect ducklake attachment: %w", err) } + attachedMemoryCatalog := false if _, err := executor.ExecContext(ctx, "USE memory"); err != nil { - return fmt.Errorf("switch to memory catalog: %w", err) + if !isMissingMemoryCatalogError(err) { + return fmt.Errorf("switch to memory catalog: %w", err) + } + if _, attachErr := executor.ExecContext(ctx, "ATTACH ':memory:' AS memory"); attachErr != nil { + return fmt.Errorf("attach memory catalog: %w", attachErr) + } + attachedMemoryCatalog = true + if _, useErr := executor.ExecContext(ctx, "USE memory"); useErr != nil { + return fmt.Errorf("switch to attached memory catalog: %w", useErr) + } } defer func() { // Leave the session in a real catalog (we entered `memory` to install the @@ -59,6 +69,9 @@ func InitSessionDatabaseMetadata(ctx context.Context, executor sqlcore.QueryExec if duckLakeAttached { _, _ = executor.ExecContext(context.Background(), "USE ducklake") _, _ = executor.ExecContext(context.Background(), "SET search_path = 'main,memory.main'") + } else if attachedMemoryCatalog && !strings.EqualFold(catalog, "memory") { + _, _ = executor.ExecContext(context.Background(), "USE "+sqlcore.QuoteIdentifier(catalog)) + _, _ = executor.ExecContext(context.Background(), "SET search_path = 'main,memory.main'") } }() @@ -69,6 +82,15 @@ func InitSessionDatabaseMetadata(ctx context.Context, executor sqlcore.QueryExec return nil } +func isMissingMemoryCatalogError(err error) bool { + if err == nil { + return false + } + msg := err.Error() + return strings.Contains(msg, `No catalog + schema named "memory" found`) || + strings.Contains(msg, `Catalog with name "memory" does not exist`) +} + func HasAttachedCatalog(ctx context.Context, executor sqlcore.QueryExecutor, catalog string) (bool, error) { query := fmt.Sprintf( "SELECT COUNT(*) FROM duckdb_databases() WHERE database_name = %s", diff --git a/server/sessionmeta/sessionmeta_test.go b/server/sessionmeta/sessionmeta_test.go index 17d74889..2613ca63 100644 --- a/server/sessionmeta/sessionmeta_test.go +++ b/server/sessionmeta/sessionmeta_test.go @@ -4,6 +4,7 @@ import ( "context" "database/sql" "errors" + "path/filepath" "strings" "testing" @@ -79,6 +80,68 @@ func (r *singleIntRow) RowsAffected() (int64, error) { return 0, nil } func (r *singleIntRow) LastInsertId() (int64, error) { return 0, nil } func (r *singleIntRow) LastProfilingOutput() string { return "" } +type duckDBTestExecutor struct { + db *sql.DB +} + +func (e *duckDBTestExecutor) QueryContext(ctx context.Context, query string, args ...any) (sqlcore.RowSet, error) { + rows, err := e.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + return &duckDBTestRows{rows: rows}, nil +} + +func (e *duckDBTestExecutor) ExecContext(ctx context.Context, query string, args ...any) (sqlcore.ExecResult, error) { + return e.db.ExecContext(ctx, query, args...) +} + +func (e *duckDBTestExecutor) Query(query string, args ...any) (sqlcore.RowSet, error) { + rows, err := e.db.Query(query, args...) + if err != nil { + return nil, err + } + return &duckDBTestRows{rows: rows}, nil +} + +func (e *duckDBTestExecutor) Exec(query string, args ...any) (sqlcore.ExecResult, error) { + return e.db.Exec(query, args...) +} + +func (e *duckDBTestExecutor) ConnContext(ctx context.Context) (sqlcore.RawConn, error) { + return e.db.Conn(ctx) +} + +func (e *duckDBTestExecutor) PingContext(ctx context.Context) error { return e.db.PingContext(ctx) } +func (e *duckDBTestExecutor) Close() error { return e.db.Close() } +func (e *duckDBTestExecutor) LastProfilingOutput() string { return "" } + +type duckDBTestRows struct { + rows *sql.Rows +} + +func (r *duckDBTestRows) Columns() ([]string, error) { return r.rows.Columns() } + +func (r *duckDBTestRows) ColumnTypes() ([]sqlcore.ColumnTyper, error) { + colTypes, err := r.rows.ColumnTypes() + if err != nil { + return nil, err + } + result := make([]sqlcore.ColumnTyper, len(colTypes)) + for i, ct := range colTypes { + result[i] = ct + } + return result, nil +} + +func (r *duckDBTestRows) Next() bool { return r.rows.Next() } +func (r *duckDBTestRows) Scan(dest ...any) error { return r.rows.Scan(dest...) } +func (r *duckDBTestRows) Close() error { return r.rows.Close() } +func (r *duckDBTestRows) Err() error { return r.rows.Err() } +func (r *duckDBTestRows) RowsAffected() (int64, error) { return 0, nil } +func (r *duckDBTestRows) LastInsertId() (int64, error) { return 0, nil } +func (r *duckDBTestRows) LastProfilingOutput() string { return "" } + func TestInitSessionDatabaseMetadataBatchesPerSessionStatements(t *testing.T) { exec := &countingExecutor{ queryRows: &singleIntRow{v: 1}, // pretend ducklake is attached for HasAttachedCatalog @@ -102,6 +165,36 @@ func TestInitSessionDatabaseMetadataBatchesPerSessionStatements(t *testing.T) { } } +func TestInitSessionDatabaseMetadataWorksOnFileBackedDatabase(t *testing.T) { + dbPath := filepath.Join(t.TempDir(), "alice.duckdb") + db, err := sql.Open("duckdb", dbPath+"?allow_unsigned_extensions=true") + if err != nil { + t.Fatalf("open file-backed duckdb: %v", err) + } + defer func() { _ = db.Close() }() + + executor := &duckDBTestExecutor{db: db} + if err := InitSessionDatabaseMetadata(context.Background(), executor, "alice"); err != nil { + t.Fatalf("InitSessionDatabaseMetadata on file-backed DB failed: %v", err) + } + + var currentDB string + if err := db.QueryRow("SELECT current_database()").Scan(¤tDB); err != nil { + t.Fatalf("query current_database(): %v", err) + } + if currentDB != "alice" { + t.Fatalf("current_database() = %q, want %q", currentDB, "alice") + } + + var datname string + if err := db.QueryRow("SELECT datname FROM pg_database WHERE datname = current_database()").Scan(&datname); err != nil { + t.Fatalf("query pg_database: %v", err) + } + if datname != "alice" { + t.Fatalf("pg_database datname = %q, want %q", datname, "alice") + } +} + func TestBuildSessionMetadataSQLContainsAllExpectedStatements(t *testing.T) { got := buildSessionMetadataSQL("analytics")