diff --git a/tests/scenario/core/runner.go b/tests/scenario/core/runner.go index b5c7db4b..c2c3a4d9 100644 --- a/tests/scenario/core/runner.go +++ b/tests/scenario/core/runner.go @@ -19,6 +19,11 @@ type StepExecutor interface { ExecuteStep(context.Context, Step) error } +type ClassifiedError interface { + error + ErrorClass() string +} + type StepExecutorFunc func(context.Context, Step) error func (f StepExecutorFunc) ExecuteStep(ctx context.Context, step Step) error { @@ -147,7 +152,7 @@ func (r *Runner) runStep(ctx context.Context, runID string, step Step, statusByS } if err != nil { result.Status = StepStatusFailed - result.ErrorClass = "execution_error" + result.ErrorClass = classifyStepError(err) result.Error = err.Error() result.Err = err } @@ -190,6 +195,14 @@ func firstUnsuccessfulDependency(step Step, statusByStep map[string]StepStatus) return "", false } +func classifyStepError(err error) string { + var classified ClassifiedError + if errors.As(err, &classified) && classified.ErrorClass() != "" { + return classified.ErrorClass() + } + return "execution_error" +} + func defaultRunID(s Scenario, startedAt time.Time) string { prefix := s.RunIDPrefix if prefix == "" { diff --git a/tests/scenario/core/runner_test.go b/tests/scenario/core/runner_test.go index f929184c..1d53b9d3 100644 --- a/tests/scenario/core/runner_test.go +++ b/tests/scenario/core/runner_test.go @@ -189,6 +189,40 @@ steps: } } +func TestRunnerRecordsClassifiedExecutorError(t *testing.T) { + sentinel := classifiedTestError{class: "cleanup_timeout", message: "cleanup timed out"} + scenario, err := ParseScenario([]byte(` +name: classified +steps: + - id: cleanup + type: fake +`)) + if err != nil { + t.Fatalf("ParseScenario returned error: %v", err) + } + + runner := NewRunner(RunnerConfig{ + RunID: "run-classified", + Scenario: scenario, + Executor: StepExecutorFunc(func(context.Context, Step) error { + return sentinel + }), + Now: fixedClock(time.Unix(1700000000, 0)), + }) + + _, err = runner.Run(context.Background()) + if !errors.Is(err, sentinel) { + t.Fatalf("runner error = %v, want classified sentinel", err) + } + results := runner.Results() + if len(results) != 1 { + t.Fatalf("results = %+v, want one result", results) + } + if results[0].ErrorClass != "cleanup_timeout" { + t.Fatalf("error class = %q, want cleanup_timeout", results[0].ErrorClass) + } +} + func TestRunnerReturnsSuccessForAllSuccessfulSteps(t *testing.T) { scenario, err := ParseScenario([]byte(` name: success @@ -238,3 +272,16 @@ func equalStrings(a, b []string) bool { } return true } + +type classifiedTestError struct { + class string + message string +} + +func (e classifiedTestError) Error() string { + return e.message +} + +func (e classifiedTestError) ErrorClass() string { + return e.class +} diff --git a/tests/scenario/provision/client.go b/tests/scenario/provision/client.go new file mode 100644 index 00000000..2c62d275 --- /dev/null +++ b/tests/scenario/provision/client.go @@ -0,0 +1,268 @@ +package provision + +import ( + "bytes" + "context" + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "net/url" + "strings" + "time" +) + +const ( + WarehouseStatePending = "pending" + WarehouseStateProvisioning = "provisioning" + WarehouseStateReady = "ready" + WarehouseStateFailed = "failed" + WarehouseStateDeleting = "deleting" + WarehouseStateDeleted = "deleted" + + defaultPollInterval = 10 * time.Second +) + +type Config struct { + BaseURL string + InternalSecret string + HTTPClient *http.Client +} + +type Client struct { + baseURL string + internalSecret string + httpClient *http.Client +} + +type ProvisionResponse struct { + Status string `json:"status"` + Org string `json:"org"` + Username string `json:"username"` + Password string `json:"password"` + Bucket string `json:"bucket,omitempty"` +} + +type DeprovisionResponse struct { + Status string `json:"status"` + Org string `json:"org"` +} + +type WarehouseStatus struct { + OrgID string `json:"org_id"` + State string `json:"state"` + StatusMessage string `json:"status_message"` + S3State string `json:"s3_state"` + MetadataStoreState string `json:"metadata_store_state"` + IdentityState string `json:"identity_state"` + SecretsState string `json:"secrets_state"` + ReadyAt *time.Time `json:"ready_at,omitempty"` + FailedAt *time.Time `json:"failed_at,omitempty"` + Connection *ConnectionDetails `json:"connection,omitempty"` + Bucket string `json:"bucket,omitempty"` +} + +type ConnectionDetails struct { + Host string `json:"host"` + Port int `json:"port"` + Database string `json:"database"` + Username string `json:"username"` +} + +type WaitOptions struct { + PollInterval time.Duration + Timeout time.Duration + MaxAttempts int + AcceptNotFound bool + Sleep func(context.Context, time.Duration) error +} + +func NewClient(cfg Config) (*Client, error) { + baseURL := strings.TrimRight(strings.TrimSpace(cfg.BaseURL), "/") + if baseURL == "" { + return nil, classified(ErrorClassConfig, fmt.Errorf("provisioning API base URL is required")) + } + if _, err := url.ParseRequestURI(baseURL); err != nil { + return nil, classified(ErrorClassConfig, fmt.Errorf("parse provisioning API base URL: %w", err)) + } + httpClient := cfg.HTTPClient + if httpClient == nil { + httpClient = http.DefaultClient + } + return &Client{ + baseURL: baseURL, + internalSecret: cfg.InternalSecret, + httpClient: httpClient, + }, nil +} + +func (c *Client) Provision(ctx context.Context, orgID string, request map[string]any) (ProvisionResponse, error) { + var resp ProvisionResponse + path := orgPath(orgID, "provision") + if err := c.doJSON(ctx, http.MethodPost, path, request, &resp, http.StatusAccepted); err != nil { + return ProvisionResponse{}, err + } + return resp, nil +} + +func (c *Client) WarehouseStatus(ctx context.Context, orgID string) (WarehouseStatus, error) { + var resp WarehouseStatus + path := orgPath(orgID, "warehouse/status") + if err := c.doJSON(ctx, http.MethodGet, path, nil, &resp, http.StatusOK); err != nil { + return WarehouseStatus{}, err + } + return resp, nil +} + +func (c *Client) Deprovision(ctx context.Context, orgID string) (DeprovisionResponse, error) { + var resp DeprovisionResponse + path := orgPath(orgID, "deprovision") + if err := c.doJSON(ctx, http.MethodPost, path, nil, &resp, http.StatusAccepted); err != nil { + return DeprovisionResponse{}, err + } + return resp, nil +} + +func (c *Client) WaitWarehouseReady(ctx context.Context, orgID string, opts WaitOptions) (WarehouseStatus, error) { + return c.waitForState(ctx, orgID, WarehouseStateReady, opts) +} + +func (c *Client) WaitWarehouseDeleted(ctx context.Context, orgID string, opts WaitOptions) (WarehouseStatus, error) { + opts.AcceptNotFound = true + return c.waitForState(ctx, orgID, WarehouseStateDeleted, opts) +} + +func (c *Client) doJSON(ctx context.Context, method, path string, body any, out any, expectedStatus int) error { + var bodyReader io.Reader + if body != nil { + raw, err := json.Marshal(body) + if err != nil { + return classified(ErrorClassInvalidStepConfig, fmt.Errorf("encode request for %s %s: %w", method, path, err)) + } + bodyReader = bytes.NewReader(raw) + } + + req, err := http.NewRequestWithContext(ctx, method, c.baseURL+path, bodyReader) + if err != nil { + return classified(ErrorClassProvisionAPI, fmt.Errorf("create request for %s %s: %w", method, path, err)) + } + if body != nil { + req.Header.Set("Content-Type", "application/json") + } + req.Header.Set("Accept", "application/json") + if c.internalSecret != "" { + req.Header.Set("X-Duckgres-Internal-Secret", c.internalSecret) + } + + resp, err := c.httpClient.Do(req) + if err != nil { + return classified(ErrorClassProvisionAPI, fmt.Errorf("%s %s failed: %w", method, path, err)) + } + defer func() { + _ = resp.Body.Close() + }() + + if resp.StatusCode != expectedStatus { + raw, _ := io.ReadAll(io.LimitReader(resp.Body, 16*1024)) + return &APIError{ + Method: method, + Path: path, + StatusCode: resp.StatusCode, + Body: redactHTTPBody(raw), + } + } + if out == nil { + return nil + } + if err := json.NewDecoder(resp.Body).Decode(out); err != nil { + return classified(ErrorClassProvisionAPI, fmt.Errorf("decode response for %s %s: %w", method, path, err)) + } + return nil +} + +func (c *Client) waitForState(ctx context.Context, orgID, target string, opts WaitOptions) (WarehouseStatus, error) { + waitCtx, cancel := contextWithOptionalTimeout(ctx, opts.Timeout) + defer cancel() + + interval := opts.PollInterval + if interval <= 0 { + interval = defaultPollInterval + } + sleep := opts.Sleep + if sleep == nil { + sleep = func(ctx context.Context, d time.Duration) error { + timer := time.NewTimer(d) + defer timer.Stop() + select { + case <-ctx.Done(): + return ctx.Err() + case <-timer.C: + return nil + } + } + } + + var last WarehouseStatus + for attempts := 1; ; attempts++ { + status, err := c.WarehouseStatus(waitCtx, orgID) + if err != nil { + var apiErr *APIError + if opts.AcceptNotFound && errorsAs(err, &apiErr) && apiErr.NotFound() { + return WarehouseStatus{OrgID: orgID, State: WarehouseStateDeleted}, nil + } + if waitCtx.Err() != nil { + return last, waitContextError(waitCtx, target, orgID) + } + return last, err + } + last = status + if status.State == target { + return status, nil + } + if status.State == WarehouseStateFailed { + return status, classified(ErrorClassProvisionFailed, fmt.Errorf("%w while waiting for %s to reach %s: %s", ErrWarehouseFailed, orgID, target, status.StatusMessage)) + } + if opts.MaxAttempts > 0 && attempts >= opts.MaxAttempts { + return status, waitTimeoutError(target, orgID) + } + if err := sleep(waitCtx, interval); err != nil { + if waitCtx.Err() != nil { + return status, waitContextError(waitCtx, target, orgID) + } + if errors.Is(err, context.Canceled) { + return status, err + } + if errors.Is(err, context.DeadlineExceeded) { + return status, waitTimeoutError(target, orgID) + } + return status, classified(ErrorClassProvisionAPI, fmt.Errorf("sleep while waiting for %s to reach %s: %w", orgID, target, err)) + } + } +} + +func contextWithOptionalTimeout(ctx context.Context, timeout time.Duration) (context.Context, context.CancelFunc) { + if timeout <= 0 { + return ctx, func() {} + } + return context.WithTimeout(ctx, timeout) +} + +func waitTimeoutError(target, orgID string) error { + return classified(ErrorClassWaitTimeout, fmt.Errorf("%w: warehouse %s did not reach %s", ErrWaitTimeout, orgID, target)) +} + +func waitContextError(ctx context.Context, target, orgID string) error { + if errors.Is(ctx.Err(), context.DeadlineExceeded) { + return waitTimeoutError(target, orgID) + } + return ctx.Err() +} + +func orgPath(orgID, suffix string) string { + return "/api/v1/orgs/" + url.PathEscape(orgID) + "/" + suffix +} + +func errorsAs(err error, target any) bool { + return err != nil && errors.As(err, target) +} diff --git a/tests/scenario/provision/client_test.go b/tests/scenario/provision/client_test.go new file mode 100644 index 00000000..442f4d6f --- /dev/null +++ b/tests/scenario/provision/client_test.go @@ -0,0 +1,328 @@ +package provision + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "strings" + "testing" + "time" +) + +func TestClientProvisionSendsExpectedRequestAndPreservesPassword(t *testing.T) { + var gotRequest map[string]any + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodPost { + t.Fatalf("method = %s, want POST", r.Method) + } + if r.URL.Path != "/api/v1/orgs/scenario-org/provision" { + t.Fatalf("path = %s, want provision path", r.URL.Path) + } + if got := r.Header.Get("X-Duckgres-Internal-Secret"); got != "internal-secret" { + t.Fatalf("internal secret header = %q, want internal-secret", got) + } + if err := json.NewDecoder(r.Body).Decode(&gotRequest); err != nil { + t.Fatalf("decode request: %v", err) + } + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status": "provisioning started", + "org": "scenario-org", + "username": "root", + "password": "root-password", + "bucket": "scenario-bucket", + }) + })) + defer server.Close() + + client, err := NewClient(Config{ + BaseURL: server.URL, + InternalSecret: "internal-secret", + HTTPClient: server.Client(), + }) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + + resp, err := client.Provision(context.Background(), "scenario-org", map[string]any{ + "database_name": "scenario_db", + "metadata_store": map[string]any{ + "type": "cnpg-shard", + }, + "ducklake": map[string]any{ + "enabled": true, + }, + }) + if err != nil { + t.Fatalf("Provision returned error: %v", err) + } + if gotRequest["database_name"] != "scenario_db" { + t.Fatalf("database_name = %#v, want scenario_db", gotRequest["database_name"]) + } + metadataStore, ok := gotRequest["metadata_store"].(map[string]any) + if !ok || metadataStore["type"] != "cnpg-shard" { + t.Fatalf("metadata_store = %#v, want type cnpg-shard", gotRequest["metadata_store"]) + } + if resp.Password != "root-password" { + t.Fatalf("password = %q, want root-password", resp.Password) + } + if resp.Bucket != "scenario-bucket" { + t.Fatalf("bucket = %q, want scenario-bucket", resp.Bucket) + } +} + +func TestRedactForArtifactRemovesPasswordsAndSecrets(t *testing.T) { + payload := map[string]any{ + "provision_response": ProvisionResponse{ + Org: "scenario-org", + Username: "root", + Password: "root-password", + }, + "nested": map[string]any{ + "internal_secret": "internal-secret", + "password_aws_secret": "aws-secret-name", + "safe": "keep-me", + }, + "tokens": []any{ + map[string]any{"token": "api-token"}, + }, + } + + redacted := RedactForArtifact(payload) + raw, err := json.Marshal(redacted) + if err != nil { + t.Fatalf("marshal redacted payload: %v", err) + } + text := string(raw) + for _, secret := range []string{"root-password", "internal-secret", "aws-secret-name", "api-token"} { + if strings.Contains(text, secret) { + t.Fatalf("redacted payload still contains %q: %s", secret, text) + } + } + if !strings.Contains(text, "keep-me") { + t.Fatalf("redacted payload should retain safe values: %s", text) + } + if !strings.Contains(text, "redacted") { + t.Fatalf("redacted payload should include redaction marker %q: %s", RedactedValue, text) + } +} + +func TestClientWaitWarehouseReadyPollsUntilReady(t *testing.T) { + states := []string{WarehouseStatePending, WarehouseStateProvisioning, WarehouseStateReady} + polls := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.Method != http.MethodGet { + t.Fatalf("method = %s, want GET", r.Method) + } + if r.URL.Path != "/api/v1/orgs/scenario-org/warehouse/status" { + t.Fatalf("path = %s, want status path", r.URL.Path) + } + state := states[polls] + polls++ + response := map[string]any{ + "org_id": "scenario-org", + "state": state, + "status_message": "status " + state, + "s3_state": state, + "metadata_store_state": state, + "identity_state": state, + "secrets_state": state, + } + if state == WarehouseStateReady { + response["connection"] = map[string]any{ + "host": "warehouse.example", + "port": 5432, + "database": "scenario_db", + "username": "root", + } + } + _ = json.NewEncoder(w).Encode(response) + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + var sleeps []time.Duration + status, err := client.WaitWarehouseReady(context.Background(), "scenario-org", WaitOptions{ + PollInterval: 25 * time.Millisecond, + Sleep: func(_ context.Context, d time.Duration) error { + sleeps = append(sleeps, d) + return nil + }, + }) + if err != nil { + t.Fatalf("WaitWarehouseReady returned error: %v", err) + } + if status.State != WarehouseStateReady { + t.Fatalf("state = %q, want ready", status.State) + } + if status.Connection == nil || status.Connection.Host != "warehouse.example" { + t.Fatalf("connection = %+v, want warehouse.example", status.Connection) + } + if polls != 3 { + t.Fatalf("polls = %d, want 3", polls) + } + if len(sleeps) != 2 || sleeps[0] != 25*time.Millisecond || sleeps[1] != 25*time.Millisecond { + t.Fatalf("sleeps = %#v, want two 25ms sleeps", sleeps) + } +} + +func TestClientWaitWarehouseReadyFailsFastOnFailed(t *testing.T) { + polls := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + polls++ + _ = json.NewEncoder(w).Encode(map[string]any{ + "org_id": "scenario-org", + "state": WarehouseStateFailed, + "status_message": "composition failed", + }) + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + _, err = client.WaitWarehouseReady(context.Background(), "scenario-org", WaitOptions{ + Sleep: func(context.Context, time.Duration) error { + t.Fatal("sleep should not be called for failed state") + return nil + }, + }) + if !errors.Is(err, ErrWarehouseFailed) { + t.Fatalf("WaitWarehouseReady error = %v, want ErrWarehouseFailed", err) + } + if polls != 1 { + t.Fatalf("polls = %d, want 1", polls) + } + if !strings.Contains(err.Error(), "composition failed") { + t.Fatalf("error = %v, want status message", err) + } +} + +func TestClientDeprovisionSendsExpectedRequest(t *testing.T) { + called := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + called = true + if r.Method != http.MethodPost { + t.Fatalf("method = %s, want POST", r.Method) + } + if r.URL.Path != "/api/v1/orgs/scenario-org/deprovision" { + t.Fatalf("path = %s, want deprovision path", r.URL.Path) + } + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status": "deprovisioning started", + "org": "scenario-org", + }) + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + resp, err := client.Deprovision(context.Background(), "scenario-org") + if err != nil { + t.Fatalf("Deprovision returned error: %v", err) + } + if !called { + t.Fatal("expected deprovision endpoint to be called") + } + if resp.Org != "scenario-org" || resp.Status != "deprovisioning started" { + t.Fatalf("deprovision response = %+v", resp) + } +} + +func TestClientClassifiesHTTPFailuresAndRedactsBody(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + w.WriteHeader(http.StatusConflict) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "password root-password token api-token Authorization: Bearer bearer-token", + "password": "root-password", + }) + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + _, err = client.Provision(context.Background(), "scenario-org", map[string]any{ + "database_name": "scenario_db", + }) + if !errors.Is(err, ErrUnexpectedStatus) { + t.Fatalf("Provision error = %v, want ErrUnexpectedStatus", err) + } + var apiErr *APIError + if !errors.As(err, &apiErr) { + t.Fatalf("Provision error = %T %v, want APIError", err, err) + } + if apiErr.ErrorClass() != ErrorClassProvisionAPI { + t.Fatalf("error class = %q, want %q", apiErr.ErrorClass(), ErrorClassProvisionAPI) + } + for _, secret := range []string{"root-password", "api-token", "bearer-token"} { + if strings.Contains(err.Error(), secret) { + t.Fatalf("error leaked %q: %v", secret, err) + } + } +} + +func TestClientWaitPreservesContextCancellation(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "org_id": "scenario-org", + "state": WarehouseStateProvisioning, + }) + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + ctx, cancel := context.WithCancel(context.Background()) + _, err = client.WaitWarehouseReady(ctx, "scenario-org", WaitOptions{ + Sleep: func(context.Context, time.Duration) error { + cancel() + return context.Canceled + }, + }) + if !errors.Is(err, context.Canceled) { + t.Fatalf("WaitWarehouseReady error = %v, want context.Canceled", err) + } + if errors.Is(err, ErrWaitTimeout) { + t.Fatalf("WaitWarehouseReady error = %v, should not be ErrWaitTimeout", err) + } +} + +func TestClientWaitWarehouseDeletedReportsTimeout(t *testing.T) { + polls := 0 + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + polls++ + _ = json.NewEncoder(w).Encode(map[string]any{ + "org_id": "scenario-org", + "state": WarehouseStateDeleting, + }) + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + _, err = client.WaitWarehouseDeleted(context.Background(), "scenario-org", WaitOptions{ + MaxAttempts: 2, + Sleep: func(context.Context, time.Duration) error { return nil }, + }) + if !errors.Is(err, ErrWaitTimeout) { + t.Fatalf("WaitWarehouseDeleted error = %v, want ErrWaitTimeout", err) + } + if polls != 2 { + t.Fatalf("polls = %d, want 2", polls) + } +} diff --git a/tests/scenario/provision/errors.go b/tests/scenario/provision/errors.go new file mode 100644 index 00000000..86ea52ae --- /dev/null +++ b/tests/scenario/provision/errors.go @@ -0,0 +1,75 @@ +package provision + +import ( + "errors" + "fmt" + "net/http" +) + +const ( + ErrorClassConfig = "configuration_error" + ErrorClassProvisionAPI = "provision_api_error" + ErrorClassProvisionFailed = "provision_failed" + ErrorClassWaitTimeout = "wait_timeout" + ErrorClassCleanupError = "cleanup_error" + ErrorClassCleanupTimeout = "cleanup_timeout" + ErrorClassUnsupportedStep = "unsupported_step" + ErrorClassInvalidStepConfig = "invalid_step_config" + ErrorClassProvisionStepError = "provision_step_error" +) + +var ( + ErrUnexpectedStatus = errors.New("unexpected provisioning api status") + ErrWarehouseFailed = errors.New("warehouse failed") + ErrWaitTimeout = errors.New("warehouse wait timeout") +) + +type classifiedError struct { + class string + err error +} + +func (e classifiedError) Error() string { + return e.err.Error() +} + +func (e classifiedError) Unwrap() error { + return e.err +} + +func (e classifiedError) ErrorClass() string { + return e.class +} + +func classified(class string, err error) error { + if err == nil { + return nil + } + return classifiedError{class: class, err: err} +} + +type APIError struct { + Method string + Path string + StatusCode int + Body string +} + +func (e *APIError) Error() string { + if e.Body == "" { + return fmt.Sprintf("%s %s: %v: HTTP %d", e.Method, e.Path, ErrUnexpectedStatus, e.StatusCode) + } + return fmt.Sprintf("%s %s: %v: HTTP %d: %s", e.Method, e.Path, ErrUnexpectedStatus, e.StatusCode, e.Body) +} + +func (e *APIError) Unwrap() error { + return ErrUnexpectedStatus +} + +func (e *APIError) ErrorClass() string { + return ErrorClassProvisionAPI +} + +func (e *APIError) NotFound() bool { + return e.StatusCode == http.StatusNotFound +} diff --git a/tests/scenario/provision/redact.go b/tests/scenario/provision/redact.go new file mode 100644 index 00000000..2ced23ed --- /dev/null +++ b/tests/scenario/provision/redact.go @@ -0,0 +1,80 @@ +package provision + +import ( + "encoding/json" + "regexp" + "strings" +) + +const RedactedValue = "" + +var ( + authorizationPattern = regexp.MustCompile(`(?i)(authorization\s*[:=]\s*)(bearer\s+)?([^&\s",}]+)`) + sensitiveTextPattern = regexp.MustCompile(`(?i)(password|secret|token|authorization)(["'=:\s]+)([^&\s",}]+)`) +) + +func RedactForArtifact(v any) any { + raw, err := json.Marshal(v) + if err != nil { + return RedactedValue + } + var decoded any + if err := json.Unmarshal(raw, &decoded); err != nil { + return RedactedValue + } + return redactValue(decoded) +} + +func redactHTTPBody(raw []byte) string { + if len(raw) == 0 { + return "" + } + var decoded any + if err := json.Unmarshal(raw, &decoded); err == nil { + redacted, err := json.Marshal(RedactForArtifact(decoded)) + if err == nil { + return string(redacted) + } + } + return redactSensitiveText(string(raw)) +} + +func redactValue(v any) any { + switch typed := v.(type) { + case map[string]any: + out := make(map[string]any, len(typed)) + for k, value := range typed { + if isSensitiveKey(k) { + out[k] = RedactedValue + continue + } + out[k] = redactValue(value) + } + return out + case []any: + out := make([]any, len(typed)) + for i, value := range typed { + out[i] = redactValue(value) + } + return out + case string: + return redactSensitiveText(typed) + default: + return typed + } +} + +func isSensitiveKey(key string) bool { + normalized := strings.ToLower(strings.ReplaceAll(key, "-", "_")) + for _, marker := range []string{"password", "secret", "token", "authorization"} { + if strings.Contains(normalized, marker) { + return true + } + } + return false +} + +func redactSensitiveText(value string) string { + value = authorizationPattern.ReplaceAllString(value, `${1}${2}`+RedactedValue) + return sensitiveTextPattern.ReplaceAllString(value, `${1}${2}`+RedactedValue) +} diff --git a/tests/scenario/provision/steps.go b/tests/scenario/provision/steps.go new file mode 100644 index 00000000..13b8d296 --- /dev/null +++ b/tests/scenario/provision/steps.go @@ -0,0 +1,298 @@ +package provision + +import ( + "context" + "errors" + "fmt" + "strconv" + "sync" + "time" + + "github.com/posthog/duckgres/tests/scenario/core" +) + +const ( + StepTypeProvisionWarehouse = "provision_warehouse" + StepTypeWaitWarehouseReady = "wait_warehouse_ready" + StepTypeDeprovisionWarehouse = "deprovision_warehouse" +) + +type ExecutorConfig struct { + Client *Client + State *State + WaitOptions WaitOptions +} + +type Executor struct { + client *Client + state *State + waitOptions WaitOptions +} + +type State struct { + mu sync.Mutex + provisionResponses map[string]ProvisionResponse + statuses map[string]WarehouseStatus +} + +func NewExecutor(cfg ExecutorConfig) *Executor { + state := cfg.State + if state == nil { + state = NewState() + } + return &Executor{ + client: cfg.Client, + state: state, + waitOptions: cfg.WaitOptions, + } +} + +func NewState() *State { + return &State{ + provisionResponses: make(map[string]ProvisionResponse), + statuses: make(map[string]WarehouseStatus), + } +} + +func (s *State) StoreProvisionResponse(orgID string, resp ProvisionResponse) { + s.mu.Lock() + defer s.mu.Unlock() + s.provisionResponses[orgID] = resp +} + +func (s *State) ProvisionResponse(orgID string) (ProvisionResponse, bool) { + s.mu.Lock() + defer s.mu.Unlock() + resp, ok := s.provisionResponses[orgID] + return resp, ok +} + +func (s *State) StoreStatus(orgID string, status WarehouseStatus) { + s.mu.Lock() + defer s.mu.Unlock() + s.statuses[orgID] = status +} + +func (s *State) Status(orgID string) (WarehouseStatus, bool) { + s.mu.Lock() + defer s.mu.Unlock() + status, ok := s.statuses[orgID] + return status, ok +} + +func (e *Executor) ExecuteStep(ctx context.Context, step core.Step) error { + if e.client == nil { + return classified(ErrorClassConfig, fmt.Errorf("provision executor client is required")) + } + switch step.Type { + case StepTypeProvisionWarehouse: + return e.executeProvision(ctx, step) + case StepTypeWaitWarehouseReady: + return e.executeWaitReady(ctx, step) + case StepTypeDeprovisionWarehouse: + return e.executeDeprovision(ctx, step) + default: + return classified(ErrorClassUnsupportedStep, fmt.Errorf("unsupported provision step type %q", step.Type)) + } +} + +func (e *Executor) executeProvision(ctx context.Context, step core.Step) error { + orgID, err := requiredString(step, "org_id") + if err != nil { + return err + } + request, err := requestMap(step) + if err != nil { + return err + } + resp, err := e.client.Provision(ctx, orgID, request) + if err != nil { + return classified(ErrorClassProvisionStepError, err) + } + e.state.StoreProvisionResponse(orgID, resp) + return nil +} + +func (e *Executor) executeWaitReady(ctx context.Context, step core.Step) error { + orgID, err := requiredString(step, "org_id") + if err != nil { + return err + } + opts, err := e.waitOptionsForStep(step) + if err != nil { + return err + } + status, err := e.client.WaitWarehouseReady(ctx, orgID, opts) + if err != nil { + return err + } + e.state.StoreStatus(orgID, status) + return nil +} + +func (e *Executor) executeDeprovision(ctx context.Context, step core.Step) error { + orgID, err := requiredString(step, "org_id") + if err != nil { + return err + } + verifyDeleted, err := boolFromWith(step, "verify_deleted") + if err != nil { + return err + } + opts, err := e.waitOptionsForStep(step) + if err != nil { + return err + } + if _, err := e.client.Deprovision(ctx, orgID); err != nil { + var apiErr *APIError + if errors.As(err, &apiErr) && apiErr.NotFound() { + return nil + } + return classified(ErrorClassCleanupError, err) + } + if !verifyDeleted { + return nil + } + status, err := e.client.WaitWarehouseDeleted(ctx, orgID, opts) + if err != nil { + if errorsIs(err, ErrWaitTimeout) { + return classified(ErrorClassCleanupTimeout, err) + } + return classified(ErrorClassCleanupError, err) + } + e.state.StoreStatus(orgID, status) + return nil +} + +func (e *Executor) waitOptionsForStep(step core.Step) (WaitOptions, error) { + opts := e.waitOptions + if timeout, ok, err := durationFromWith(step, "timeout"); err != nil { + return WaitOptions{}, err + } else if ok { + opts.Timeout = timeout + } + if cleanupTimeout, ok, err := durationFromWith(step, "cleanup_timeout"); err != nil { + return WaitOptions{}, err + } else if ok { + opts.Timeout = cleanupTimeout + } + if interval, ok, err := durationFromWith(step, "poll_interval"); err != nil { + return WaitOptions{}, err + } else if ok { + opts.PollInterval = interval + } + if interval, ok, err := durationFromWith(step, "interval"); err != nil { + return WaitOptions{}, err + } else if ok { + opts.PollInterval = interval + } + if maxAttempts, ok, err := intFromWith(step, "max_attempts"); err != nil { + return WaitOptions{}, err + } else if ok { + opts.MaxAttempts = maxAttempts + } + return opts, nil +} + +func requiredString(step core.Step, key string) (string, error) { + value, ok := step.With[key] + if !ok { + return "", classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s requires with.%s", step.ID, key)) + } + text, ok := value.(string) + if !ok || text == "" { + return "", classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must be a non-empty string", step.ID, key)) + } + return text, nil +} + +func requestMap(step core.Step) (map[string]any, error) { + value, ok := step.With["request"] + if !ok { + return nil, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s requires with.request", step.ID)) + } + request, ok := value.(map[string]any) + if !ok { + return nil, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.request must be a map", step.ID)) + } + return request, nil +} + +func boolFromWith(step core.Step, key string) (bool, error) { + value, ok := step.With[key] + if !ok { + return false, nil + } + switch typed := value.(type) { + case bool: + return typed, nil + case string: + parsed, err := strconv.ParseBool(typed) + if err != nil { + return false, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must be a boolean", step.ID, key)) + } + return parsed, nil + default: + return false, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must be a boolean", step.ID, key)) + } +} + +func durationFromWith(step core.Step, key string) (time.Duration, bool, error) { + value, ok := step.With[key] + if !ok { + return 0, false, nil + } + switch typed := value.(type) { + case time.Duration: + if typed < 0 { + return 0, false, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must not be negative", step.ID, key)) + } + return typed, true, nil + case string: + parsed, err := time.ParseDuration(typed) + if err != nil { + return 0, false, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must be a Go duration: %w", step.ID, key, err)) + } + if parsed < 0 { + return 0, false, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must not be negative", step.ID, key)) + } + return parsed, true, nil + default: + return 0, false, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must be a Go duration string", step.ID, key)) + } +} + +func intFromWith(step core.Step, key string) (int, bool, error) { + value, ok := step.With[key] + if !ok { + return 0, false, nil + } + var parsed int + switch typed := value.(type) { + case int: + parsed = typed + case int64: + parsed = int(typed) + case float64: + if typed != float64(int(typed)) { + return 0, false, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must be an integer", step.ID, key)) + } + parsed = int(typed) + case string: + var err error + parsed, err = strconv.Atoi(typed) + if err != nil { + return 0, false, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must be an integer: %w", step.ID, key, err)) + } + default: + return 0, false, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must be an integer", step.ID, key)) + } + if parsed < 0 { + return 0, false, classified(ErrorClassInvalidStepConfig, fmt.Errorf("step %s with.%s must not be negative", step.ID, key)) + } + return parsed, true, nil +} + +func errorsIs(err, target error) bool { + return err != nil && target != nil && errors.Is(err, target) +} diff --git a/tests/scenario/provision/steps_test.go b/tests/scenario/provision/steps_test.go new file mode 100644 index 00000000..b7c2d0cc --- /dev/null +++ b/tests/scenario/provision/steps_test.go @@ -0,0 +1,307 @@ +package provision + +import ( + "context" + "encoding/json" + "errors" + "net/http" + "net/http/httptest" + "testing" + "time" + + "github.com/posthog/duckgres/tests/scenario/core" +) + +func TestExecutorProvisionStepStoresPasswordForLaterSteps(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/api/v1/orgs/scenario-org/provision" { + t.Fatalf("path = %s, want provision path", r.URL.Path) + } + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status": "provisioning started", + "org": "scenario-org", + "username": "root", + "password": "root-password", + }) + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + state := NewState() + executor := NewExecutor(ExecutorConfig{Client: client, State: state}) + + err = executor.ExecuteStep(context.Background(), core.Step{ + ID: "provision", + Type: StepTypeProvisionWarehouse, + With: map[string]any{ + "org_id": "scenario-org", + "request": map[string]any{ + "database_name": "scenario_db", + "metadata_store": map[string]any{"type": "cnpg-shard"}, + "ducklake": map[string]any{"enabled": true}, + }, + }, + }) + if err != nil { + t.Fatalf("ExecuteStep returned error: %v", err) + } + resp, ok := state.ProvisionResponse("scenario-org") + if !ok { + t.Fatal("expected provision response in state") + } + if resp.Password != "root-password" { + t.Fatalf("stored password = %q, want root-password", resp.Password) + } +} + +func TestExecutorRunsDeprovisionAfterWorkloadFailure(t *testing.T) { + deprovisionCalled := false + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/orgs/scenario-org/provision": + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status": "provisioning started", + "org": "scenario-org", + "username": "root", + "password": "root-password", + }) + case "/api/v1/orgs/scenario-org/deprovision": + deprovisionCalled = true + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status": "deprovisioning started", + "org": "scenario-org", + }) + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + provisionExecutor := NewExecutor(ExecutorConfig{Client: client}) + workloadErr := errors.New("workload failed") + scenario, err := core.ParseScenario([]byte(` +name: cleanup-after-failure +steps: + - id: provision + type: provision_warehouse + with: + org_id: scenario-org + request: + database_name: scenario_db + metadata_store: + type: cnpg-shard + ducklake: + enabled: true + - id: workload + type: fake_workload + - id: deprovision + type: deprovision_warehouse + depends_on: [workload] + always_run: true + with: + org_id: scenario-org +`)) + if err != nil { + t.Fatalf("ParseScenario returned error: %v", err) + } + runner := core.NewRunner(core.RunnerConfig{ + RunID: "run-cleanup", + Scenario: scenario, + Executor: core.StepExecutorFunc(func(ctx context.Context, step core.Step) error { + if step.Type == "fake_workload" { + return workloadErr + } + return provisionExecutor.ExecuteStep(ctx, step) + }), + }) + + _, err = runner.Run(context.Background()) + if !errors.Is(err, workloadErr) { + t.Fatalf("runner error = %v, want workload error", err) + } + if !deprovisionCalled { + t.Fatal("expected deprovision to run after workload failure") + } +} + +func TestExecutorTreatsCleanupNotFoundAsAlreadyCleanedUp(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/orgs/scenario-org/provision": + w.WriteHeader(http.StatusInternalServerError) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "provision failed before warehouse row was created", + }) + case "/api/v1/orgs/scenario-org/deprovision": + w.WriteHeader(http.StatusNotFound) + _ = json.NewEncoder(w).Encode(map[string]any{ + "error": "warehouse not found", + }) + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + provisionExecutor := NewExecutor(ExecutorConfig{Client: client}) + scenario, err := core.ParseScenario([]byte(` +name: cleanup-after-partial-provision +steps: + - id: provision + type: provision_warehouse + with: + org_id: scenario-org + request: + database_name: scenario_db + - id: deprovision + type: deprovision_warehouse + depends_on: [provision] + always_run: true + with: + org_id: scenario-org +`)) + if err != nil { + t.Fatalf("ParseScenario returned error: %v", err) + } + runner := core.NewRunner(core.RunnerConfig{ + RunID: "run-partial-provision", + Scenario: scenario, + Executor: provisionExecutor, + }) + + _, err = runner.Run(context.Background()) + if !errors.Is(err, ErrUnexpectedStatus) { + t.Fatalf("runner error = %v, want provision API failure", err) + } + results := runner.Results() + if len(results) != 2 { + t.Fatalf("results = %+v, want 2 results", results) + } + if results[1].Status != core.StepStatusOK { + t.Fatalf("cleanup result = %+v, want success", results[1]) + } +} + +func TestExecutorCleanupTimeoutIsDistinctFromWorkloadFailure(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + switch r.URL.Path { + case "/api/v1/orgs/scenario-org/deprovision": + w.WriteHeader(http.StatusAccepted) + _ = json.NewEncoder(w).Encode(map[string]any{ + "status": "deprovisioning started", + "org": "scenario-org", + }) + case "/api/v1/orgs/scenario-org/warehouse/status": + _ = json.NewEncoder(w).Encode(map[string]any{ + "org_id": "scenario-org", + "state": WarehouseStateDeleting, + }) + default: + t.Fatalf("unexpected path %s", r.URL.Path) + } + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + provisionExecutor := NewExecutor(ExecutorConfig{ + Client: client, + WaitOptions: WaitOptions{ + MaxAttempts: 2, + Sleep: func(context.Context, time.Duration) error { return nil }, + }, + }) + workloadErr := errors.New("workload failed") + scenario, err := core.ParseScenario([]byte(` +name: cleanup-timeout +steps: + - id: workload + type: fake_workload + - id: deprovision + type: deprovision_warehouse + always_run: true + with: + org_id: scenario-org + verify_deleted: true +`)) + if err != nil { + t.Fatalf("ParseScenario returned error: %v", err) + } + runner := core.NewRunner(core.RunnerConfig{ + RunID: "run-cleanup-timeout", + Scenario: scenario, + Executor: core.StepExecutorFunc(func(ctx context.Context, step core.Step) error { + if step.Type == "fake_workload" { + return workloadErr + } + return provisionExecutor.ExecuteStep(ctx, step) + }), + }) + + _, err = runner.Run(context.Background()) + if !errors.Is(err, workloadErr) { + t.Fatalf("runner error = %v, want workload failure", err) + } + if !errors.Is(err, ErrWaitTimeout) { + t.Fatalf("runner error = %v, want cleanup timeout", err) + } + results := runner.Results() + if len(results) != 2 || results[1].Status != core.StepStatusFailed { + t.Fatalf("deprovision result = %+v", results) + } + if results[1].ErrorClass != ErrorClassCleanupTimeout { + t.Fatalf("cleanup error class = %q, want %q", results[1].ErrorClass, ErrorClassCleanupTimeout) + } +} + +func TestExecutorRejectsInvalidWaitOptions(t *testing.T) { + server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, _ *http.Request) { + _ = json.NewEncoder(w).Encode(map[string]any{ + "org_id": "scenario-org", + "state": WarehouseStateReady, + }) + })) + defer server.Close() + + client, err := NewClient(Config{BaseURL: server.URL, HTTPClient: server.Client()}) + if err != nil { + t.Fatalf("NewClient returned error: %v", err) + } + executor := NewExecutor(ExecutorConfig{Client: client}) + err = executor.ExecuteStep(context.Background(), core.Step{ + ID: "wait", + Type: StepTypeWaitWarehouseReady, + With: map[string]any{ + "org_id": "scenario-org", + "timeout": "not-a-duration", + "max_attempts": "not-an-int", + "verify_deleted": "maybe", + }, + }) + if err == nil { + t.Fatal("expected invalid wait options to fail") + } + var classified core.ClassifiedError + if !errors.As(err, &classified) { + t.Fatalf("error = %T %v, want classified error", err, err) + } + if classified.ErrorClass() != ErrorClassInvalidStepConfig { + t.Fatalf("error class = %q, want %q", classified.ErrorClass(), ErrorClassInvalidStepConfig) + } +}