diff --git a/Cargo.lock b/Cargo.lock index 9bb9f12..5440545 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -221,9 +221,9 @@ checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782" [[package]] name = "tree-sitter-postgres" -version = "1.2.2" +version = "1.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81b6b5beaa4ef890a0ccf39eee2644752cef73669b336b8723899c08c382ed0c" +checksum = "a0d4968176b02e91b0521a1bf115bce33106343afc85af58c4bee7f1216099ac" dependencies = [ "cc", "tree-sitter-language", diff --git a/Cargo.toml b/Cargo.toml index 5e30d7e..2dd302c 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ categories = ["text-processing", "development-tools"] [dependencies] tree-sitter = "0.26" -tree-sitter-postgres = "1.2.2" +tree-sitter-postgres = "1.2.3" [dev-dependencies] pretty_assertions = "1" diff --git a/src/formatter/plpgsql.rs b/src/formatter/plpgsql.rs index 42db297..dfa4b7d 100644 --- a/src/formatter/plpgsql.rs +++ b/src/formatter/plpgsql.rs @@ -54,6 +54,21 @@ impl<'a> Formatter<'a> { .map(|n| self.text(n).trim().to_string()) .unwrap_or_default(); + // Alias declaration: name ALIAS FOR target ; + if decl.has_child("kw_alias") { + let target = decl + .named_children(&mut decl.walk()) + .last() + .map(|n| self.text(n).trim().to_string()) + .unwrap_or_default(); + lines.push(format!( + "{indent}{var_name} {} {} {target};", + self.kw("ALIAS"), + self.kw("FOR") + )); + return; + } + let mut parts = vec![var_name]; // Constant? @@ -108,15 +123,24 @@ impl<'a> Formatter<'a> { "stmt_foreach_a" => self.format_stmt_foreach(child, indent_level, lines), "stmt_case" => self.format_stmt_case(child, indent_level, lines), "stmt_return" => { + let mut parts = vec![self.kw("RETURN")]; + // RETURN NEXT [expr] / RETURN QUERY [EXECUTE] ... + if child.has_child("kw_next") { + parts.push(self.kw("NEXT")); + } else if child.has_child("kw_query") { + parts.push(self.kw("QUERY")); + if child.has_child("kw_execute") { + parts.push(self.kw("EXECUTE")); + } + } let expr = child .find_child("sql_expression") .map(|n| self.text(n).trim()) .unwrap_or(""); - if expr.is_empty() { - lines.push(format!("{indent}{};", self.kw("RETURN"))); - } else { - lines.push(format!("{indent}{} {expr};", self.kw("RETURN"))); + if !expr.is_empty() { + parts.push(expr.to_string()); } + lines.push(format!("{indent}{};", parts.join(" "))); } "stmt_raise" => self.format_stmt_raise(child, indent_level, lines), "stmt_null" => { @@ -261,24 +285,25 @@ impl<'a> Formatter<'a> { .map(|n| self.text(n).trim()) .unwrap_or(""); - // Determine if it's a FOR ... IN range or FOR ... IN query. - let in_clause = if let Some(range) = node.find_child("for_integer_range") { - self.text(range).trim().to_string() - } else if let Some(query) = node.find_child("for_control") { - self.text(query).trim().to_string() - } else { - // Fallback: reconstruct from source. - let text = self.text(node); - if let Some(start) = text.find("IN") { - if let Some(end) = text.find("LOOP") { - text[start + 2..end].trim().to_string() - } else { - String::new() - } - } else { - String::new() - } - }; + // The IN clause is one of the for_* variants (integer range, query, + // cursor, or dynamic EXECUTE). Each ends with a nested LOOP keyword, so + // take the variant's text up to that keyword. + let mut cursor2 = node.walk(); + let in_clause = node + .named_children(&mut cursor2) + .find(|c| { + matches!( + c.kind(), + "for_integer_range" | "for_query" | "for_cursor" | "for_dynamic" + ) + }) + .map(|variant| match variant.find_child("kw_loop") { + Some(loop_kw) => self.source[variant.start_byte()..loop_kw.start_byte()] + .trim() + .to_string(), + None => self.text(variant).trim().to_string(), + }) + .unwrap_or_default(); let for_kw = self.kw("FOR"); let in_kw = self.kw("IN"); diff --git a/tests/plpgsql_test.rs b/tests/plpgsql_test.rs index 9a50fdb..bd22e4a 100644 --- a/tests/plpgsql_test.rs +++ b/tests/plpgsql_test.rs @@ -31,3 +31,44 @@ BEGIN END;"; assert_eq!(result, expected, "\nGot:\n{result}"); } + +// Regression: declarations using multi-word type names, DEFAULT, and ALIAS FOR +// a positional parameter previously failed to parse (grammar gaps), and the +// ALIAS form dropped its target when formatting. +#[test] +fn declarations_types_default_alias() { + let body = "DECLARE\n a character varying(50);\n b double precision;\n c timestamp with time zone;\n d integer DEFAULT 0;\n username ALIAS FOR $1;\nBEGIN\n NULL;\nEND"; + let result = format_plpgsql(body, Style::Aweber).unwrap(); + let expected = "\ +DECLARE + a character varying(50); + b double precision; + c timestamp with time zone; + d integer DEFAULT 0; + username ALIAS FOR $1; +BEGIN + NULL; +END;"; + assert_eq!(result, expected, "\nGot:\n{result}"); +} + +// Regression: `RETURN NEXT` (bare) failed to parse, and the formatter dropped +// the NEXT keyword. +#[test] +fn return_next_bare() { + let body = "BEGIN\n RETURN NEXT;\nEND"; + let result = format_plpgsql(body, Style::Aweber).unwrap(); + assert_eq!(result, "BEGIN\n RETURN NEXT;\nEND;", "\nGot:\n{result}"); +} + +// Regression: a FOR loop over a query dropped the query text after IN. +#[test] +fn for_over_query_keeps_query() { + let body = "BEGIN\n FOR r IN SELECT id FROM t LOOP\n RETURN NEXT r;\n END LOOP;\nEND"; + let result = format_plpgsql(body, Style::Aweber).unwrap(); + assert!( + result.contains("FOR r IN SELECT id FROM t LOOP"), + "query dropped from FOR clause:\n{result}" + ); + assert!(result.contains("RETURN NEXT r;"), "\nGot:\n{result}"); +}