diff --git a/Cargo.lock b/Cargo.lock index 94f5c8b..d434907 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -221,9 +221,9 @@ checksum = "009994f150cc0cd50ff54917d5bc8bffe8cad10ca10d81c34da2ec421ae61782" [[package]] name = "tree-sitter-postgres" -version = "1.2.1" +version = "1.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162cad52646c3f61a49a539486841d67f23fcb6ca5c01958403fc28b985e551" +checksum = "81b6b5beaa4ef890a0ccf39eee2644752cef73669b336b8723899c08c382ed0c" dependencies = [ "cc", "tree-sitter-language", diff --git a/Cargo.toml b/Cargo.toml index 6d6ef92..516f26d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -11,7 +11,7 @@ categories = ["text-processing", "development-tools"] [dependencies] tree-sitter = "0.26" -tree-sitter-postgres = "1.2" +tree-sitter-postgres = "1.2.2" [dev-dependencies] pretty_assertions = "1" diff --git a/examples/dump_plpgsql.rs b/examples/dump_plpgsql.rs new file mode 100644 index 0000000..a7dcd8b --- /dev/null +++ b/examples/dump_plpgsql.rs @@ -0,0 +1,28 @@ +use tree_sitter::Parser; +use tree_sitter_postgres::LANGUAGE_PLPGSQL; + +fn print_tree(node: tree_sitter::Node, source: &str, indent: usize) { + let kind = node.kind(); + let text = &source[node.byte_range()]; + let short = if text.len() > 60 { &text[..60] } else { text }; + let short = short.replace('\n', "\\n"); + let pad = " ".repeat(indent); + if node.is_named() { + println!("{pad}{kind}: {short:?}"); + } else { + println!("{pad}[{kind}]: {short:?}"); + } + let mut cursor = node.walk(); + for child in node.children(&mut cursor) { + print_tree(child, source, indent + 1); + } +} + +fn main() { + let path = std::env::args().nth(1).expect("usage: dump_plpgsql "); + let sql = std::fs::read_to_string(&path).unwrap(); + let mut parser = Parser::new(); + parser.set_language(&LANGUAGE_PLPGSQL.into()).unwrap(); + let tree = parser.parse(sql.trim(), None).unwrap(); + print_tree(tree.root_node(), sql.trim(), 0); +} diff --git a/examples/format_plpgsql_test.rs b/examples/format_plpgsql_test.rs new file mode 100644 index 0000000..9953d60 --- /dev/null +++ b/examples/format_plpgsql_test.rs @@ -0,0 +1,16 @@ +use libpgfmt::{format_plpgsql, style::Style}; +fn main() { + let sql = std::fs::read_to_string(std::env::args().nth(1).unwrap()).unwrap(); + let style: Style = std::env::args() + .nth(2) + .unwrap_or("aweber".to_string()) + .parse() + .unwrap(); + match format_plpgsql(sql.trim(), style) { + Ok(f) => println!("{f}"), + Err(e) => { + eprintln!("Error: {e}"); + std::process::exit(1); + } + } +} diff --git a/src/formatter/plpgsql.rs b/src/formatter/plpgsql.rs index 31fd8d2..42db297 100644 --- a/src/formatter/plpgsql.rs +++ b/src/formatter/plpgsql.rs @@ -169,8 +169,9 @@ impl<'a> Formatter<'a> { self.kw("THEN") )); i += 3; // skip cond and THEN + } else { + i += 1; // closing IF (END IF) } - // Closing IF (END IF). } "sql_expression" => { i += 1; // handled with IF/ELSIF diff --git a/tests/plpgsql_test.rs b/tests/plpgsql_test.rs new file mode 100644 index 0000000..9a50fdb --- /dev/null +++ b/tests/plpgsql_test.rs @@ -0,0 +1,33 @@ +use libpgfmt::{format_plpgsql, style::Style}; + +// Regression: a plain `IF ... THEN ... END IF` previously hung forever because +// format_stmt_if never advanced past the closing `kw_if` of `END IF`. +#[test] +fn if_then_end_if_terminates() { + let body = "BEGIN\n IF x = 1\n THEN\n v := y;\n END IF;\nEND"; + let result = format_plpgsql(body, Style::Aweber).unwrap(); + let expected = "\ +BEGIN + IF x = 1 THEN + v := y; + END IF; +END;"; + assert_eq!(result, expected, "\nGot:\n{result}"); +} + +#[test] +fn if_elsif_else_terminates() { + let body = "BEGIN\n IF a THEN\n v := 1;\n ELSIF b THEN\n v := 2;\n ELSE\n v := 3;\n END IF;\nEND"; + let result = format_plpgsql(body, Style::Aweber).unwrap(); + let expected = "\ +BEGIN + IF a THEN + v := 1; + ELSIF b THEN + v := 2; + ELSE + v := 3; + END IF; +END;"; + assert_eq!(result, expected, "\nGot:\n{result}"); +}