From 895e194e0929d01aad3402f90ca6fe5fcefc5740 Mon Sep 17 00:00:00 2001 From: Bill Guowei Yang Date: Thu, 25 Jun 2026 15:42:22 -0400 Subject: [PATCH] Add scenario SQL smoke runner --- docs/runbooks/scenario-runner.md | 64 +++ justfile | 18 +- scripts/scenario_run.sh | 106 +++++ tests/scenario/runner_test.go | 247 +++++++++++ tests/scenario/scenarios/provision_smoke.yaml | 34 +- tests/scenario/script_test.go | 31 ++ tests/scenario/sql/connection.go | 81 ++++ tests/scenario/sql/connection_test.go | 86 ++++ tests/scenario/sql/driver.go | 82 ++++ tests/scenario/sql/errors.go | 44 ++ tests/scenario/sql/steps.go | 407 ++++++++++++++++++ tests/scenario/sql/steps_test.go | 260 +++++++++++ 12 files changed, 1456 insertions(+), 4 deletions(-) create mode 100644 docs/runbooks/scenario-runner.md create mode 100755 scripts/scenario_run.sh create mode 100644 tests/scenario/runner_test.go create mode 100644 tests/scenario/script_test.go create mode 100644 tests/scenario/sql/connection.go create mode 100644 tests/scenario/sql/connection_test.go create mode 100644 tests/scenario/sql/driver.go create mode 100644 tests/scenario/sql/errors.go create mode 100644 tests/scenario/sql/steps.go create mode 100644 tests/scenario/sql/steps_test.go diff --git a/docs/runbooks/scenario-runner.md b/docs/runbooks/scenario-runner.md new file mode 100644 index 00000000..a04d6f35 --- /dev/null +++ b/docs/runbooks/scenario-runner.md @@ -0,0 +1,64 @@ +# Duckgres Scenario Runner + +## Scope + +The scenario runner executes end-to-end managed-warehouse flows against a configured dev environment. The first smoke scenario provisions a warehouse, waits for readiness, runs `SELECT 1` over PGWire with managed-hostname SNI, then deprovisions and verifies cleanup. + +## Required Environment + +Set these before running a real scenario: + +```bash +export DUCKGRES_SCENARIO_API_BASE="" +export DUCKGRES_SCENARIO_INTERNAL_SECRET="" +export DUCKGRES_SCENARIO_PG_HOST="" +export DUCKGRES_SCENARIO_SNI_SUFFIX="" +``` + +`DUCKGRES_SCENARIO_PG_HOST` is used as libpq `hostaddr`; the runner separately sets `host=` so TLS SNI carries the managed warehouse identity. + +Optional: + +```bash +export DUCKGRES_SCENARIO_OUTPUT_BASE="artifacts/scenario" +export DUCKGRES_SCENARIO_RUN_ID="scenario-smoke-manual" +export DUCKGRES_SCENARIO_PG_PORT="5432" +export DUCKGRES_SCENARIO_PG_CONNECT_TIMEOUT="10" +export DUCKGRES_SCENARIO_MAX_RUNTIME="30m" +export DUCKGRES_SCENARIO_GO_TEST_TIMEOUT="60m" +``` + +Do not commit concrete dev endpoints, secrets, org IDs, or private bucket names. + +## Run + +Validate configuration without running: + +```bash +./scripts/scenario_run.sh --check-env +``` + +Run the dev smoke: + +```bash +just scenario-smoke +``` + +Run a specific scenario file: + +```bash +just scenario scenario=tests/scenario/scenarios/provision_smoke.yaml +``` + +Artifacts are written under `artifacts/scenario//`. + +## Leaked Dev Warehouse Recovery + +The smoke scenario has an `always_run` deprovision step, but an interrupted process can still leave dev resources behind. To clean up: + +1. Identify the scenario org ID from the scenario file and artifact directory. +2. Call the control-plane deprovision endpoint with the internal secret. +3. Poll `/warehouse/status` until the state is `deleted` or the warehouse returns `404`. +4. If deletion does not complete, inspect the dev control-plane logs and the managed warehouse deprovision runbook. + +Use placeholders in shared notes and PRs; keep concrete dev values local. diff --git a/justfile b/justfile index 74c9d5aa..4bd640b4 100644 --- a/justfile +++ b/justfile @@ -252,6 +252,7 @@ format: [group('test')] test: just test-unit + just test-scenario just test-integration just test-controlplane just test-configstore-integration @@ -262,6 +263,11 @@ test: test-unit: go test -v -p 1 . ./configresolve/... ./duckdbservice/... ./server/... ./transpiler/... ./internal/... ./tests/manifests/... +# Run scenario runner unit tests +[group('test')] +test-scenario: + go test -v -count=1 ./tests/scenario/... + # Run cache-proxy tests [group('test')] test-cache-proxy: @@ -358,6 +364,16 @@ perf-smoke: perf-nightly: ./scripts/perf_nightly.sh +# Run a Duckgres scenario file against a configured dev environment +[group('test')] +scenario scenario="tests/scenario/scenarios/provision_smoke.yaml": + ./scripts/scenario_run.sh {{scenario}} + +# Run the dev provision smoke scenario +[group('test')] +scenario-smoke: + ./scripts/scenario_run.sh tests/scenario/scenarios/provision_smoke.yaml + # Lint (matches CI — uses golangci-lint, not go vet) [group('test')] lint: @@ -365,7 +381,7 @@ lint: # Run what CI runs locally (excluding kind-backed K8s integration) [group('test')] -ci: lint test-unit test-cache-proxy test-integration test-controlplane test-configstore-integration test-controlplane-k8s +ci: lint test-unit test-scenario test-cache-proxy test-integration test-controlplane test-configstore-integration test-controlplane-k8s # === Metrics === diff --git a/scripts/scenario_run.sh b/scripts/scenario_run.sh new file mode 100755 index 00000000..a479ccd1 --- /dev/null +++ b/scripts/scenario_run.sh @@ -0,0 +1,106 @@ +#!/usr/bin/env bash +set -euo pipefail + +usage() { + cat <<'USAGE' +Usage: scripts/scenario_run.sh [--check-env] [--scenario-file PATH] [--output-base DIR] [--run-id ID] [PATH] + +Runs a Duckgres scenario through the Go test entry point. + +Required environment: + DUCKGRES_SCENARIO_API_BASE + DUCKGRES_SCENARIO_INTERNAL_SECRET + DUCKGRES_SCENARIO_PG_HOST (libpq hostaddr / direct TCP address) + DUCKGRES_SCENARIO_SNI_SUFFIX + +Optional environment: + DUCKGRES_SCENARIO_OUTPUT_BASE + DUCKGRES_SCENARIO_RUN_ID + DUCKGRES_SCENARIO_PG_PORT + DUCKGRES_SCENARIO_PG_CONNECT_TIMEOUT + DUCKGRES_SCENARIO_MAX_RUNTIME + DUCKGRES_SCENARIO_GO_TEST_TIMEOUT +USAGE +} + +scenario_file="${DUCKGRES_SCENARIO_FILE:-tests/scenario/scenarios/provision_smoke.yaml}" +output_base="${DUCKGRES_SCENARIO_OUTPUT_BASE:-artifacts/scenario}" +run_id="${DUCKGRES_SCENARIO_RUN_ID:-}" +max_runtime="${DUCKGRES_SCENARIO_MAX_RUNTIME:-30m}" +go_test_timeout="${DUCKGRES_SCENARIO_GO_TEST_TIMEOUT:-60m}" +check_env_only=0 + +while [ "$#" -gt 0 ]; do + case "$1" in + --check-env) + check_env_only=1 + shift + ;; + --scenario-file) + scenario_file="${2:?--scenario-file requires a path}" + shift 2 + ;; + --output-base) + output_base="${2:?--output-base requires a directory}" + shift 2 + ;; + --run-id) + run_id="${2:?--run-id requires an id}" + shift 2 + ;; + -h|--help) + usage + exit 0 + ;; + --*) + echo "Unknown option: $1" >&2 + usage >&2 + exit 2 + ;; + *) + scenario_file="$1" + shift + ;; + esac +done + +required=( + DUCKGRES_SCENARIO_API_BASE + DUCKGRES_SCENARIO_INTERNAL_SECRET + DUCKGRES_SCENARIO_PG_HOST + DUCKGRES_SCENARIO_SNI_SUFFIX +) +missing=() +for name in "${required[@]}"; do + if [ -z "${!name:-}" ]; then + missing+=("$name") + fi +done + +if [ "${#missing[@]}" -ne 0 ]; then + echo "Missing required Duckgres scenario environment:" >&2 + for name in "${missing[@]}"; do + echo " - $name" >&2 + done + exit 2 +fi + +if [ "$check_env_only" -eq 1 ]; then + echo "Duckgres scenario environment is configured." + exit 0 +fi + +args=( + go test -count=1 ./tests/scenario + -timeout "$go_test_timeout" + -run TestScenarioRunner + -scenario-run + -scenario-file "$scenario_file" + -scenario-output-base "$output_base" + -scenario-max-runtime "$max_runtime" +) +if [ -n "$run_id" ]; then + args+=(-scenario-run-id "$run_id") +fi + +"${args[@]}" diff --git a/tests/scenario/runner_test.go b/tests/scenario/runner_test.go new file mode 100644 index 00000000..0e946027 --- /dev/null +++ b/tests/scenario/runner_test.go @@ -0,0 +1,247 @@ +package scenario + +import ( + "context" + "flag" + "fmt" + "os" + "path/filepath" + "strconv" + "strings" + "testing" + "time" + + "github.com/posthog/duckgres/tests/scenario/core" + "github.com/posthog/duckgres/tests/scenario/provision" + scenariosql "github.com/posthog/duckgres/tests/scenario/sql" +) + +var ( + scenarioRun = flag.Bool("scenario-run", false, "run a real scenario file") + scenarioFile = flag.String("scenario-file", "", "scenario YAML file to run") + scenarioOutputBase = flag.String("scenario-output-base", "artifacts/scenario", "base directory for scenario artifacts") + scenarioRunID = flag.String("scenario-run-id", "", "scenario run id") + scenarioMaxRuntime = flag.Duration("scenario-max-runtime", 30*time.Minute, "maximum scenario runtime") +) + +func TestScenarioRunner(t *testing.T) { + if !*scenarioRun { + t.Skip("set -scenario-run to execute a real scenario") + } + if *scenarioFile == "" { + t.Fatal("-scenario-file is required") + } + + loaded, err := core.LoadScenario(*scenarioFile) + if err != nil { + t.Fatalf("load scenario: %v", err) + } + runID := *scenarioRunID + if runID == "" { + runID = defaultRunID(loaded) + } + loaded = resolveRunTemplates(loaded, runID) + + provisionClient, err := provision.NewClient(provision.Config{ + BaseURL: mustEnv(t, "DUCKGRES_SCENARIO_API_BASE"), + InternalSecret: mustEnv(t, "DUCKGRES_SCENARIO_INTERNAL_SECRET"), + }) + if err != nil { + t.Fatalf("create provision client: %v", err) + } + provisionState := provision.NewState() + provisionExecutor := provision.NewExecutor(provision.ExecutorConfig{ + Client: provisionClient, + State: provisionState, + WaitOptions: provision.WaitOptions{ + PollInterval: 10 * time.Second, + Timeout: 15 * time.Minute, + }, + }) + + sqlExecutor := scenariosql.NewExecutor(scenariosql.ExecutorConfig{ + ProvisionState: provisionState, + Connection: scenariosql.ConnectionConfig{ + HostAddr: mustEnv(t, "DUCKGRES_SCENARIO_PG_HOST"), + SNISuffix: mustEnv(t, "DUCKGRES_SCENARIO_SNI_SUFFIX"), + Port: intEnv(t, "DUCKGRES_SCENARIO_PG_PORT", 5432), + SSLMode: "require", + ConnectTimeout: intEnv(t, "DUCKGRES_SCENARIO_PG_CONNECT_TIMEOUT", 10), + ApplicationName: "duckgres-scenario-runner", + }, + }) + + ctx, cancel := context.WithTimeout(context.Background(), *scenarioMaxRuntime) + defer cancel() + runner := core.NewRunner(core.RunnerConfig{ + RunID: runID, + Scenario: loaded, + Executor: dispatchExecutor{provision: provisionExecutor, sql: sqlExecutor}, + OutputDir: filepath.Join(*scenarioOutputBase, runID), + WriteFiles: true, + CleanupTimeout: 15 * time.Minute, + }) + if summary, err := runner.Run(ctx); err != nil { + t.Fatalf("scenario failed: %+v: %v", summary, err) + } +} + +func TestProvisionSmokeScenarioUsesRunUniqueSupportedSteps(t *testing.T) { + scenario, err := core.LoadScenario(filepath.Join("scenarios", "provision_smoke.yaml")) + if err != nil { + t.Fatalf("load provision smoke: %v", err) + } + resolved := resolveRunTemplates(scenario, "scenario-smoke-20260102t030405z") + for _, step := range resolved.Steps { + if !dispatchSupports(step.Type) { + t.Fatalf("step %s has unsupported type %q", step.ID, step.Type) + } + if containsTemplate(step.With) { + t.Fatalf("step %s still contains unresolved template values: %#v", step.ID, step.With) + } + } + provisionStep := resolved.Steps[0] + orgID, _ := provisionStep.With["org_id"].(string) + if orgID == "scenario-smoke" || !strings.Contains(orgID, "20260102t030405z") { + t.Fatalf("org_id = %q, want run-unique templated org", orgID) + } + request, ok := provisionStep.With["request"].(map[string]any) + if !ok { + t.Fatalf("provision request = %#v, want map", provisionStep.With["request"]) + } + databaseName, _ := request["database_name"].(string) + if databaseName == "scenario_smoke" || !strings.Contains(databaseName, "20260102t030405z") { + t.Fatalf("database_name = %q, want run-unique templated database", databaseName) + } +} + +type dispatchExecutor struct { + provision *provision.Executor + sql *scenariosql.Executor +} + +func (e dispatchExecutor) ExecuteStep(ctx context.Context, step core.Step) error { + switch step.Type { + case provision.StepTypeProvisionWarehouse, provision.StepTypeWaitWarehouseReady, provision.StepTypeDeprovisionWarehouse: + return e.provision.ExecuteStep(ctx, step) + case scenariosql.StepTypeSQL, scenariosql.StepTypeSQLCatalog: + return e.sql.ExecuteStep(ctx, step) + default: + return fmt.Errorf("unsupported scenario step type %q", step.Type) + } +} + +func dispatchSupports(stepType string) bool { + switch stepType { + case provision.StepTypeProvisionWarehouse, provision.StepTypeWaitWarehouseReady, provision.StepTypeDeprovisionWarehouse: + return true + case scenariosql.StepTypeSQL, scenariosql.StepTypeSQLCatalog: + return true + default: + return false + } +} + +func mustEnv(t *testing.T, key string) string { + t.Helper() + value := os.Getenv(key) + if value == "" { + t.Fatalf("%s is required", key) + } + return value +} + +func intEnv(t *testing.T, key string, fallback int) int { + t.Helper() + value := os.Getenv(key) + if value == "" { + return fallback + } + parsed, err := strconv.Atoi(value) + if err != nil { + t.Fatalf("%s must be an integer: %v", key, err) + } + return parsed +} + +func defaultRunID(s core.Scenario) string { + prefix := s.RunIDPrefix + if prefix == "" { + prefix = "scenario" + } + return fmt.Sprintf("%s-%s", prefix, time.Now().UTC().Format("20060102t150405z")) +} + +func resolveRunTemplates(s core.Scenario, runID string) core.Scenario { + vars := map[string]string{ + "run_id": runID, + "run_id_compact": compactRunID(runID), + } + out := s + out.Steps = make([]core.Step, len(s.Steps)) + for i, step := range s.Steps { + if step.With != nil { + step.With = resolveTemplateValue(step.With, vars).(map[string]any) + } + out.Steps[i] = step + } + return out +} + +func compactRunID(runID string) string { + var b strings.Builder + for _, r := range strings.ToLower(runID) { + if r >= 'a' && r <= 'z' || r >= '0' && r <= '9' { + b.WriteRune(r) + } + } + if b.Len() == 0 { + return "scenario" + } + return b.String() +} + +func resolveTemplateValue(value any, vars map[string]string) any { + switch typed := value.(type) { + case map[string]any: + out := make(map[string]any, len(typed)) + for k, v := range typed { + out[k] = resolveTemplateValue(v, vars) + } + return out + case []any: + out := make([]any, len(typed)) + for i, v := range typed { + out[i] = resolveTemplateValue(v, vars) + } + return out + case string: + out := typed + for k, v := range vars { + out = strings.ReplaceAll(out, "${"+k+"}", v) + } + return out + default: + return typed + } +} + +func containsTemplate(value any) bool { + switch typed := value.(type) { + case map[string]any: + for _, v := range typed { + if containsTemplate(v) { + return true + } + } + case []any: + for _, v := range typed { + if containsTemplate(v) { + return true + } + } + case string: + return strings.Contains(typed, "${") + } + return false +} diff --git a/tests/scenario/scenarios/provision_smoke.yaml b/tests/scenario/scenarios/provision_smoke.yaml index fa9a6490..d18d8197 100644 --- a/tests/scenario/scenarios/provision_smoke.yaml +++ b/tests/scenario/scenarios/provision_smoke.yaml @@ -2,9 +2,37 @@ name: provision-smoke run_id_prefix: scenario-smoke steps: - id: provision - type: fake + type: provision_warehouse + with: + org_id: scenario-smoke-${run_id_compact} + request: + database_name: scenario_smoke_${run_id_compact} + metadata_store: + type: cnpg-shard + ducklake: + enabled: true + data_store: + type: s3bucket + - id: wait_ready + type: wait_warehouse_ready + with: + org_id: scenario-smoke-${run_id_compact} + timeout: 15m + poll_interval: 10s - id: select_one - type: fake + type: sql + with: + org_id: scenario-smoke-${run_id_compact} + catalog: ducklake + sql: SELECT 1 + max_attempts: 12 + retry_interval: 10s - id: deprovision - type: fake + type: deprovision_warehouse + depends_on: [select_one] always_run: true + with: + org_id: scenario-smoke-${run_id_compact} + verify_deleted: true + cleanup_timeout: 15m + poll_interval: 10s diff --git a/tests/scenario/script_test.go b/tests/scenario/script_test.go new file mode 100644 index 00000000..d07e796c --- /dev/null +++ b/tests/scenario/script_test.go @@ -0,0 +1,31 @@ +package scenario + +import ( + "os" + "os/exec" + "path/filepath" + "strings" + "testing" +) + +func TestScenarioRunScriptValidatesRequiredEnvVars(t *testing.T) { + script := filepath.Join("..", "..", "scripts", "scenario_run.sh") + cmd := exec.Command("bash", script, "--check-env") + cmd.Env = []string{"PATH=" + os.Getenv("PATH")} + + out, err := cmd.CombinedOutput() + if err == nil { + t.Fatal("expected script to fail without required env vars") + } + text := string(out) + for _, name := range []string{ + "DUCKGRES_SCENARIO_API_BASE", + "DUCKGRES_SCENARIO_INTERNAL_SECRET", + "DUCKGRES_SCENARIO_PG_HOST", + "DUCKGRES_SCENARIO_SNI_SUFFIX", + } { + if !strings.Contains(text, name) { + t.Fatalf("script output %q missing %s", text, name) + } + } +} diff --git a/tests/scenario/sql/connection.go b/tests/scenario/sql/connection.go new file mode 100644 index 00000000..6c3ef189 --- /dev/null +++ b/tests/scenario/sql/connection.go @@ -0,0 +1,81 @@ +package sql + +import ( + "fmt" + "strconv" + "strings" +) + +type ConnectionConfig struct { + OrgID string + SNISuffix string + HostAddr string + Port int + Database string + Username string + Password string + SSLMode string + ConnectTimeout int + ApplicationName string +} + +func (c ConnectionConfig) DSN() (string, error) { + if c.OrgID == "" { + return "", classified(ErrorClassConfig, fmt.Errorf("org id is required for SQL connection")) + } + if c.SNISuffix == "" { + return "", classified(ErrorClassConfig, fmt.Errorf("SNI suffix is required for SQL connection")) + } + if !strings.HasPrefix(c.SNISuffix, ".") { + return "", classified(ErrorClassConfig, fmt.Errorf("SNI suffix must start with a dot")) + } + if c.HostAddr == "" { + return "", classified(ErrorClassConfig, fmt.Errorf("PG host address is required for SQL connection")) + } + if c.Port == 0 { + c.Port = 5432 + } + if c.Database == "" { + c.Database = "ducklake" + } + if c.Username == "" { + c.Username = "root" + } + if c.SSLMode == "" { + c.SSLMode = "require" + } + + values := [][2]string{ + {"host", c.OrgID + c.SNISuffix}, + {"hostaddr", c.HostAddr}, + {"port", strconv.Itoa(c.Port)}, + {"user", c.Username}, + {"password", c.Password}, + {"dbname", c.Database}, + {"sslmode", c.SSLMode}, + } + if c.ConnectTimeout > 0 { + values = append(values, [2]string{"connect_timeout", strconv.Itoa(c.ConnectTimeout)}) + } + if c.ApplicationName != "" { + values = append(values, [2]string{"application_name", c.ApplicationName}) + } + + parts := make([]string, 0, len(values)) + for _, kv := range values { + parts = append(parts, kv[0]+"="+quoteConninfoValue(kv[1])) + } + return strings.Join(parts, " "), nil +} + +func quoteConninfoValue(value string) string { + if value == "" { + return "''" + } + if !strings.ContainsAny(value, " \t\n\r'\\") { + return value + } + escaped := strings.ReplaceAll(value, `\`, `\\`) + escaped = strings.ReplaceAll(escaped, `'`, `\'`) + return "'" + escaped + "'" +} diff --git a/tests/scenario/sql/connection_test.go b/tests/scenario/sql/connection_test.go new file mode 100644 index 00000000..7947d5c1 --- /dev/null +++ b/tests/scenario/sql/connection_test.go @@ -0,0 +1,86 @@ +package sql + +import ( + "strings" + "testing" +) + +func TestConnectionConfigDSNUsesSNIHostAndTCPHostAddr(t *testing.T) { + dsn, err := ConnectionConfig{ + OrgID: "scenario-org", + SNISuffix: ".dev.example", + HostAddr: "10.0.0.10", + Port: 5432, + Database: "ducklake", + Username: "root", + Password: "root password", + SSLMode: "require", + ConnectTimeout: 10, + ApplicationName: "duckgres-scenario", + }.DSN() + if err != nil { + t.Fatalf("DSN returned error: %v", err) + } + + for _, want := range []string{ + "host=scenario-org.dev.example", + "hostaddr=10.0.0.10", + "port=5432", + "user=root", + "password='root password'", + "dbname=ducklake", + "sslmode=require", + "connect_timeout=10", + "application_name=duckgres-scenario", + } { + if !strings.Contains(dsn, want) { + t.Fatalf("DSN %q missing %q", dsn, want) + } + } +} + +func TestConnectionConfigDSNRequiresSNIFields(t *testing.T) { + _, err := ConnectionConfig{ + OrgID: "scenario-org", + HostAddr: "10.0.0.10", + Port: 5432, + Database: "ducklake", + Username: "root", + Password: "root-password", + }.DSN() + if err == nil { + t.Fatal("expected missing SNI suffix to fail") + } +} + +func TestConnectionConfigDSNDefaultsDatabaseAndEscapesPassword(t *testing.T) { + dsn, err := ConnectionConfig{ + OrgID: "scenario-org", + SNISuffix: ".dev.example", + HostAddr: "10.0.0.10", + Username: "root", + Password: `pa'ss\word`, + }.DSN() + if err != nil { + t.Fatalf("DSN returned error: %v", err) + } + if !strings.Contains(dsn, "dbname=ducklake") { + t.Fatalf("DSN %q missing default ducklake catalog", dsn) + } + if !strings.Contains(dsn, `password='pa\'ss\\word'`) { + t.Fatalf("DSN %q did not escape password", dsn) + } +} + +func TestConnectionConfigDSNRejectsSuffixWithoutLeadingDot(t *testing.T) { + _, err := ConnectionConfig{ + OrgID: "scenario-org", + SNISuffix: "dev.example", + HostAddr: "10.0.0.10", + Username: "root", + Password: "root-password", + }.DSN() + if err == nil { + t.Fatal("expected SNI suffix without leading dot to fail") + } +} diff --git a/tests/scenario/sql/driver.go b/tests/scenario/sql/driver.go new file mode 100644 index 00000000..bfa84a94 --- /dev/null +++ b/tests/scenario/sql/driver.go @@ -0,0 +1,82 @@ +package sql + +import ( + "context" + stdsql "database/sql" + "fmt" + "time" + + _ "github.com/lib/pq" +) + +type Driver interface { + Execute(context.Context, QueryRequest) (QueryResult, error) +} + +type QueryRequest struct { + StepID string + QueryID string + OrgID string + Catalog string + SQL string + DSN string +} + +type QueryResult struct { + Rows int64 + Duration time.Duration +} + +type DatabaseDriver struct{} + +func NewDatabaseDriver() *DatabaseDriver { + return &DatabaseDriver{} +} + +func (d *DatabaseDriver) Execute(ctx context.Context, req QueryRequest) (QueryResult, error) { + started := time.Now() + db, err := stdsql.Open("postgres", req.DSN) + if err != nil { + return QueryResult{}, fmt.Errorf("open pgwire connection: %w", err) + } + defer func() { + _ = db.Close() + }() + + rows, err := db.QueryContext(ctx, req.SQL) + if err == nil { + defer func() { + _ = rows.Close() + }() + var count int64 + cols, err := rows.Columns() + if err != nil { + return QueryResult{}, err + } + values := make([]any, len(cols)) + ptrs := make([]any, len(cols)) + for i := range values { + ptrs[i] = &values[i] + } + for rows.Next() { + if err := rows.Scan(ptrs...); err != nil { + return QueryResult{}, err + } + count++ + } + if err := rows.Err(); err != nil { + return QueryResult{}, err + } + return QueryResult{Rows: count, Duration: time.Since(started)}, nil + } + + res, execErr := db.ExecContext(ctx, req.SQL) + if execErr != nil { + return QueryResult{}, execErr + } + affected, affectedErr := res.RowsAffected() + if affectedErr != nil { + affected = 0 + } + return QueryResult{Rows: affected, Duration: time.Since(started)}, nil +} diff --git a/tests/scenario/sql/errors.go b/tests/scenario/sql/errors.go new file mode 100644 index 00000000..5d9a4921 --- /dev/null +++ b/tests/scenario/sql/errors.go @@ -0,0 +1,44 @@ +package sql + +import ( + "errors" + "fmt" +) + +const ( + ErrorClassConfig = "sql_configuration_error" + ErrorClassInvalidStepConfig = "invalid_step_config" + ErrorClassSQL = "sql_error" + ErrorClassTransientTimeout = "transient_sql_timeout" + ErrorClassUnsupportedStep = "unsupported_step" +) + +var ErrTransientRetriesExhausted = errors.New("transient SQL retries exhausted") + +type classifiedError struct { + class string + err error +} + +func (e classifiedError) Error() string { + return e.err.Error() +} + +func (e classifiedError) Unwrap() error { + return e.err +} + +func (e classifiedError) ErrorClass() string { + return e.class +} + +func classified(class string, err error) error { + if err == nil { + return nil + } + return classifiedError{class: class, err: err} +} + +func invalidStep(stepID, format string, args ...any) error { + return classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s: %s", stepID, fmt.Sprintf(format, args...))) +} diff --git a/tests/scenario/sql/steps.go b/tests/scenario/sql/steps.go new file mode 100644 index 00000000..3e9dfe00 --- /dev/null +++ b/tests/scenario/sql/steps.go @@ -0,0 +1,407 @@ +package sql + +import ( + "context" + "fmt" + "os" + "strconv" + "strings" + "sync" + "time" + + "github.com/posthog/duckgres/tests/scenario/core" + "github.com/posthog/duckgres/tests/scenario/provision" +) + +const ( + StepTypeSQL = "sql" + StepTypeSQLCatalog = "sql_catalog" +) + +type RetryConfig struct { + MaxAttempts int + RetryBackoff time.Duration + Sleep func(context.Context, time.Duration) error +} + +type ExecutorConfig struct { + ProvisionState *provision.State + Connection ConnectionConfig + Driver Driver + Retry RetryConfig + State *State +} + +type Executor struct { + provisionState *provision.State + connection ConnectionConfig + driver Driver + retry RetryConfig + state *State +} + +type State struct { + mu sync.Mutex + results map[string]StepResult +} + +type StepResult struct { + StepID string + QueryID string + Rows int64 + Attempts int + Duration time.Duration +} + +type querySpec struct { + ID string + SQL string + Catalog string +} + +func NewExecutor(cfg ExecutorConfig) *Executor { + driver := cfg.Driver + if driver == nil { + driver = NewDatabaseDriver() + } + state := cfg.State + if state == nil { + state = NewState() + } + retry := cfg.Retry + if retry.MaxAttempts == 0 { + retry.MaxAttempts = 12 + } + if retry.RetryBackoff == 0 { + retry.RetryBackoff = 10 * time.Second + } + if retry.Sleep == nil { + retry.Sleep = sleepContext + } + return &Executor{ + provisionState: cfg.ProvisionState, + connection: cfg.Connection, + driver: driver, + retry: retry, + state: state, + } +} + +func NewState() *State { + return &State{results: make(map[string]StepResult)} +} + +func (e *Executor) State() *State { + return e.state +} + +func (s *State) StoreResult(result StepResult) { + s.mu.Lock() + defer s.mu.Unlock() + s.results[result.StepID] = result +} + +func (s *State) Result(stepID string) (StepResult, bool) { + s.mu.Lock() + defer s.mu.Unlock() + result, ok := s.results[stepID] + return result, ok +} + +func (e *Executor) ExecuteStep(ctx context.Context, step core.Step) error { + switch step.Type { + case StepTypeSQL: + spec, err := parseSQLStep(step) + if err != nil { + return err + } + return e.executeQuery(ctx, step, spec, step.ID) + case StepTypeSQLCatalog: + specs, err := parseCatalogStep(step) + if err != nil { + return err + } + for _, spec := range specs { + resultID := step.ID + "/" + spec.ID + if err := e.executeQuery(ctx, step, spec, resultID); err != nil { + return err + } + } + return nil + default: + return classified(ErrorClassUnsupportedStep, fmt.Errorf("unsupported SQL step type %q", step.Type)) + } +} + +func (e *Executor) executeQuery(ctx context.Context, step core.Step, spec querySpec, resultID string) error { + req, err := e.queryRequest(step, spec, resultID) + if err != nil { + return err + } + retry, err := retryConfigForStep(step, e.retry) + if err != nil { + return err + } + + var attempts int + var lastErr error + for attempts = 1; attempts <= retry.MaxAttempts; attempts++ { + result, err := e.driver.Execute(ctx, req) + if err == nil { + e.state.StoreResult(StepResult{ + StepID: resultID, + QueryID: spec.ID, + Rows: result.Rows, + Attempts: attempts, + Duration: result.Duration, + }) + return nil + } + lastErr = err + if !IsTransientStartupError(err) { + return classified(ErrorClassSQL, fmt.Errorf("execute SQL step %s query %s: %w", step.ID, spec.ID, err)) + } + if attempts == retry.MaxAttempts { + break + } + if sleepErr := retry.Sleep(ctx, retry.RetryBackoff); sleepErr != nil { + return sleepErr + } + } + return classified(ErrorClassTransientTimeout, fmt.Errorf("%w for SQL step %s query %s after %d attempts: %w", ErrTransientRetriesExhausted, step.ID, spec.ID, attempts, lastErr)) +} + +func (e *Executor) queryRequest(step core.Step, spec querySpec, resultID string) (QueryRequest, error) { + orgID, err := requiredString(step, "org_id") + if err != nil { + return QueryRequest{}, err + } + username := stringFromWith(step, "username", "root") + password := stringFromWith(step, "password", "") + if password == "" { + if e.provisionState == nil { + return QueryRequest{}, invalidStep(step.ID, "provision state is required when with.password is omitted") + } + resp, ok := e.provisionState.ProvisionResponse(orgID) + if !ok { + return QueryRequest{}, invalidStep(step.ID, "no provision response found for org %q", orgID) + } + if resp.Username != "" { + username = resp.Username + } + password = resp.Password + } + + cfg := e.connection + cfg.OrgID = orgID + cfg.Database = spec.Catalog + cfg.Username = username + cfg.Password = password + dsn, err := cfg.DSN() + if err != nil { + return QueryRequest{}, err + } + return QueryRequest{ + StepID: resultID, + QueryID: spec.ID, + OrgID: orgID, + Catalog: spec.Catalog, + SQL: spec.SQL, + DSN: dsn, + }, nil +} + +func parseSQLStep(step core.Step) (querySpec, error) { + sqlText, err := sqlFromStep(step) + if err != nil { + return querySpec{}, err + } + return querySpec{ + ID: stringFromWith(step, "query_id", step.ID), + SQL: sqlText, + Catalog: stringFromWith(step, "catalog", "ducklake"), + }, nil +} + +func parseCatalogStep(step core.Step) ([]querySpec, error) { + raw, ok := step.With["queries"] + if !ok { + return nil, invalidStep(step.ID, "with.queries is required") + } + items, ok := raw.([]any) + if !ok || len(items) == 0 { + return nil, invalidStep(step.ID, "with.queries must be a non-empty list") + } + specs := make([]querySpec, 0, len(items)) + for i, item := range items { + m, ok := item.(map[string]any) + if !ok { + return nil, invalidStep(step.ID, "with.queries[%d] must be a map", i) + } + id, ok := m["id"].(string) + if !ok || id == "" { + return nil, invalidStep(step.ID, "with.queries[%d].id must be a non-empty string", i) + } + sqlText, ok := m["sql"].(string) + if !ok || strings.TrimSpace(sqlText) == "" { + return nil, invalidStep(step.ID, "with.queries[%d].sql must be a non-empty string", i) + } + catalog := stringFromMap(m, "catalog", stringFromWith(step, "catalog", "ducklake")) + specs = append(specs, querySpec{ID: id, SQL: sqlText, Catalog: catalog}) + } + return specs, nil +} + +func sqlFromStep(step core.Step) (string, error) { + if sqlText := stringFromWith(step, "sql", ""); strings.TrimSpace(sqlText) != "" { + return sqlText, nil + } + file := stringFromWith(step, "file", "") + if file == "" { + return "", invalidStep(step.ID, "with.sql or with.file is required") + } + raw, err := os.ReadFile(file) + if err != nil { + return "", invalidStep(step.ID, "read SQL file %s: %v", file, err) + } + if strings.TrimSpace(string(raw)) == "" { + return "", invalidStep(step.ID, "SQL file %s is empty", file) + } + return string(raw), nil +} + +func retryConfigForStep(step core.Step, base RetryConfig) (RetryConfig, error) { + if maxAttempts, ok, err := intFromWith(step, "max_attempts"); err != nil { + return RetryConfig{}, err + } else if ok { + base.MaxAttempts = maxAttempts + } + if retryInterval, ok, err := durationFromWith(step, "retry_interval"); err != nil { + return RetryConfig{}, err + } else if ok { + base.RetryBackoff = retryInterval + } + if base.MaxAttempts <= 0 { + return RetryConfig{}, invalidStep(step.ID, "max_attempts must be greater than zero") + } + return base, nil +} + +func IsTransientStartupError(err error) bool { + if err == nil { + return false + } + msg := strings.ToLower(err.Error()) + for _, marker := range []string{ + "capacity exhausted", + "no duckgres worker", + "still provisioning", + "failed to initialize session", + "timed out waiting for an available worker", + "failed to start", + "spawn sized worker", + "failed to detect attached catalogs", + } { + if strings.Contains(msg, marker) { + return true + } + } + return false +} + +func requiredString(step core.Step, key string) (string, error) { + value, ok := step.With[key] + if !ok { + return "", invalidStep(step.ID, "with.%s is required", key) + } + text, ok := value.(string) + if !ok || text == "" { + return "", invalidStep(step.ID, "with.%s must be a non-empty string", key) + } + return text, nil +} + +func stringFromWith(step core.Step, key, fallback string) string { + value, ok := step.With[key] + if !ok { + return fallback + } + text, ok := value.(string) + if !ok || text == "" { + return fallback + } + return text +} + +func stringFromMap(m map[string]any, key, fallback string) string { + value, ok := m[key] + if !ok { + return fallback + } + text, ok := value.(string) + if !ok || text == "" { + return fallback + } + return text +} + +func durationFromWith(step core.Step, key string) (time.Duration, bool, error) { + value, ok := step.With[key] + if !ok { + return 0, false, nil + } + text, ok := value.(string) + if !ok { + return 0, false, invalidStep(step.ID, "with.%s must be a Go duration string", key) + } + parsed, err := time.ParseDuration(text) + if err != nil { + return 0, false, invalidStep(step.ID, "with.%s must be a Go duration: %v", key, err) + } + if parsed < 0 { + return 0, false, invalidStep(step.ID, "with.%s must not be negative", key) + } + return parsed, true, nil +} + +func intFromWith(step core.Step, key string) (int, bool, error) { + value, ok := step.With[key] + if !ok { + return 0, false, nil + } + var parsed int + switch typed := value.(type) { + case int: + parsed = typed + case int64: + parsed = int(typed) + case float64: + if typed != float64(int(typed)) { + return 0, false, invalidStep(step.ID, "with.%s must be an integer", key) + } + parsed = int(typed) + case string: + var err error + parsed, err = strconv.Atoi(typed) + if err != nil { + return 0, false, invalidStep(step.ID, "with.%s must be an integer: %v", key, err) + } + default: + return 0, false, invalidStep(step.ID, "with.%s must be an integer", key) + } + if parsed < 0 { + return 0, false, invalidStep(step.ID, "with.%s must not be negative", key) + } + return parsed, true, nil +} + +func sleepContext(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } +} diff --git a/tests/scenario/sql/steps_test.go b/tests/scenario/sql/steps_test.go new file mode 100644 index 00000000..a378a0ec --- /dev/null +++ b/tests/scenario/sql/steps_test.go @@ -0,0 +1,260 @@ +package sql + +import ( + "context" + "errors" + "strings" + "testing" + "time" + + "github.com/posthog/duckgres/tests/scenario/core" + "github.com/posthog/duckgres/tests/scenario/provision" +) + +func TestExecutorRetriesTransientStartupErrors(t *testing.T) { + provisionState := provision.NewState() + provisionState.StoreProvisionResponse("scenario-org", provision.ProvisionResponse{ + Org: "scenario-org", + Username: "root", + Password: "root-password", + }) + var attempts int + driver := &fakeDriver{ + executeFunc: func(context.Context, QueryRequest) (QueryResult, error) { + attempts++ + if attempts < 3 { + return QueryResult{}, errors.New("no Duckgres worker is currently available; retry in about 45 seconds") + } + return QueryResult{Rows: 1}, nil + }, + } + var sleeps []time.Duration + executor := NewExecutor(ExecutorConfig{ + ProvisionState: provisionState, + Connection: ConnectionConfig{ + HostAddr: "10.0.0.10", + SNISuffix: ".dev.example", + Port: 5432, + SSLMode: "require", + }, + Driver: driver, + Retry: RetryConfig{ + MaxAttempts: 3, + RetryBackoff: 50 * time.Millisecond, + Sleep: func(_ context.Context, d time.Duration) error { + sleeps = append(sleeps, d) + return nil + }, + }, + }) + + err := executor.ExecuteStep(context.Background(), core.Step{ + ID: "select_one", + Type: StepTypeSQL, + With: map[string]any{ + "org_id": "scenario-org", + "catalog": "ducklake", + "sql": "SELECT 1", + }, + }) + if err != nil { + t.Fatalf("ExecuteStep returned error: %v", err) + } + if attempts != 3 { + t.Fatalf("attempts = %d, want 3", attempts) + } + if len(sleeps) != 2 || sleeps[0] != 50*time.Millisecond || sleeps[1] != 50*time.Millisecond { + t.Fatalf("sleeps = %#v, want two 50ms sleeps", sleeps) + } + result, ok := executor.State().Result("select_one") + if !ok { + t.Fatal("expected SQL result to be recorded") + } + if result.Rows != 1 || result.Attempts != 3 { + t.Fatalf("result = %+v, want rows=1 attempts=3", result) + } +} + +func TestExecutorDoesNotRetryNonTransientSQLErrors(t *testing.T) { + provisionState := provision.NewState() + provisionState.StoreProvisionResponse("scenario-org", provision.ProvisionResponse{ + Org: "scenario-org", + Username: "root", + Password: "root-password", + }) + var attempts int + driver := &fakeDriver{ + executeFunc: func(context.Context, QueryRequest) (QueryResult, error) { + attempts++ + return QueryResult{}, errors.New("syntax error at or near \"SELEC\"") + }, + } + executor := NewExecutor(ExecutorConfig{ + ProvisionState: provisionState, + Connection: ConnectionConfig{ + HostAddr: "10.0.0.10", + SNISuffix: ".dev.example", + Port: 5432, + SSLMode: "require", + }, + Driver: driver, + Retry: RetryConfig{MaxAttempts: 5}, + }) + + err := executor.ExecuteStep(context.Background(), core.Step{ + ID: "bad_sql", + Type: StepTypeSQL, + With: map[string]any{ + "org_id": "scenario-org", + "catalog": "ducklake", + "sql": "SELEC 1", + }, + }) + if err == nil { + t.Fatal("expected non-transient SQL error") + } + if attempts != 1 { + t.Fatalf("attempts = %d, want 1", attempts) + } + var classified core.ClassifiedError + if !errors.As(err, &classified) || classified.ErrorClass() != ErrorClassSQL { + t.Fatalf("error = %T %v, want class %q", err, err, ErrorClassSQL) + } +} + +func TestExecutorRunsInlineSQLCatalog(t *testing.T) { + provisionState := provision.NewState() + provisionState.StoreProvisionResponse("scenario-org", provision.ProvisionResponse{ + Org: "scenario-org", + Username: "root", + Password: "root-password", + }) + var queries []string + driver := &fakeDriver{ + executeFunc: func(_ context.Context, req QueryRequest) (QueryResult, error) { + queries = append(queries, req.SQL) + return QueryResult{Rows: 1}, nil + }, + } + executor := NewExecutor(ExecutorConfig{ + ProvisionState: provisionState, + Connection: ConnectionConfig{ + HostAddr: "10.0.0.10", + SNISuffix: ".dev.example", + Port: 5432, + SSLMode: "require", + }, + Driver: driver, + }) + + err := executor.ExecuteStep(context.Background(), core.Step{ + ID: "catalog", + Type: StepTypeSQLCatalog, + With: map[string]any{ + "org_id": "scenario-org", + "catalog": "ducklake", + "queries": []any{ + map[string]any{"id": "one", "sql": "SELECT 1"}, + map[string]any{"id": "two", "sql": "SELECT 2"}, + }, + }, + }) + if err != nil { + t.Fatalf("ExecuteStep returned error: %v", err) + } + if len(queries) != 2 || queries[0] != "SELECT 1" || queries[1] != "SELECT 2" { + t.Fatalf("queries = %#v, want SELECT 1 and SELECT 2", queries) + } + if _, ok := executor.State().Result("catalog/one"); !ok { + t.Fatal("expected first catalog query result") + } + if _, ok := executor.State().Result("catalog/two"); !ok { + t.Fatal("expected second catalog query result") + } +} + +func TestExecutorUsesProvisionCredentialsInDSN(t *testing.T) { + provisionState := provision.NewState() + provisionState.StoreProvisionResponse("scenario-org", provision.ProvisionResponse{ + Org: "scenario-org", + Username: "custom-root", + Password: "root password", + }) + var gotDSN string + driver := &fakeDriver{ + executeFunc: func(_ context.Context, req QueryRequest) (QueryResult, error) { + gotDSN = req.DSN + return QueryResult{Rows: 1}, nil + }, + } + executor := NewExecutor(ExecutorConfig{ + ProvisionState: provisionState, + Connection: ConnectionConfig{ + HostAddr: "10.0.0.10", + SNISuffix: ".dev.example", + Port: 5432, + SSLMode: "require", + }, + Driver: driver, + }) + + err := executor.ExecuteStep(context.Background(), core.Step{ + ID: "select_one", + Type: StepTypeSQL, + With: map[string]any{ + "org_id": "scenario-org", + "catalog": "iceberg", + "sql": "SELECT 1", + }, + }) + if err != nil { + t.Fatalf("ExecuteStep returned error: %v", err) + } + for _, want := range []string{"user=custom-root", "password='root password'", "dbname=iceberg"} { + if !strings.Contains(gotDSN, want) { + t.Fatalf("DSN %q missing %q", gotDSN, want) + } + } +} + +func TestExecutorFailsWithoutProvisionState(t *testing.T) { + executor := NewExecutor(ExecutorConfig{ + Connection: ConnectionConfig{ + HostAddr: "10.0.0.10", + SNISuffix: ".dev.example", + Port: 5432, + SSLMode: "require", + }, + Driver: &fakeDriver{ + executeFunc: func(context.Context, QueryRequest) (QueryResult, error) { + t.Fatal("driver should not run without provision state") + return QueryResult{}, nil + }, + }, + }) + + err := executor.ExecuteStep(context.Background(), core.Step{ + ID: "select_one", + Type: StepTypeSQL, + With: map[string]any{ + "org_id": "scenario-org", + "catalog": "ducklake", + "sql": "SELECT 1", + }, + }) + if err == nil { + t.Fatal("expected missing provision state to fail") + } + var classified core.ClassifiedError + if !errors.As(err, &classified) || classified.ErrorClass() != ErrorClassInvalidStepConfig { + t.Fatalf("error = %T %v, want class %q", err, err, ErrorClassInvalidStepConfig) + } +} + +type fakeDriver struct { + executeFunc func(context.Context, QueryRequest) (QueryResult, error) +} + +func (d *fakeDriver) Execute(ctx context.Context, req QueryRequest) (QueryResult, error) { + return d.executeFunc(ctx, req) +}