Skip to content

Commit 55d374d

Browse files
tausbnCopilot
andcommitted
yeast: Integrate yeast with shared tree-sitter extractor
extract() gains a rules parameter. When empty, uses tree-sitter native traversal (no behavior change). When non-empty, runs yeast desugaring and extracts via traverse_yeast. Adds AstNode trait abstracting over tree_sitter::Node and yeast::Node, with minimal changes to existing Visitor methods (Node -> &N in 6 signatures). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
1 parent 5eb6cfd commit 55d374d

6 files changed

Lines changed: 112 additions & 18 deletions

File tree

ruby/extractor/src/extractor.rs

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -123,6 +123,8 @@ pub fn run(options: Options) -> std::io::Result<()> {
123123
&path,
124124
&source,
125125
&[],
126+
vec![],
127+
None,
126128
);
127129

128130
let (ranges, line_breaks) = scan_erb(
@@ -211,6 +213,8 @@ pub fn run(options: Options) -> std::io::Result<()> {
211213
&path,
212214
&source,
213215
&code_ranges,
216+
vec![],
217+
None,
214218
);
215219
std::fs::create_dir_all(src_archive_file.parent().unwrap())?;
216220
if needs_conversion {

shared/tree-sitter-extractor/Cargo.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ serde_json = "1.0"
2020
chrono = { version = "0.4.42", features = ["serde"] }
2121
num_cpus = "1.17.0"
2222
zstd = "0.13.3"
23+
yeast = { path = "../yeast" }
2324

2425
[dev-dependencies]
2526
tree-sitter-ql = "0.23.1"

shared/tree-sitter-extractor/src/extractor/mod.rs

Lines changed: 96 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,45 @@ use tree_sitter::{Language, Node, Parser, Range, Tree};
1818

1919
pub mod simple;
2020

21+
/// Trait abstracting over tree-sitter and yeast node types for extraction.
22+
trait AstNode {
23+
fn kind(&self) -> &str;
24+
fn is_named(&self) -> bool;
25+
fn is_missing(&self) -> bool;
26+
fn is_error(&self) -> bool;
27+
fn is_extra(&self) -> bool;
28+
fn start_position(&self) -> tree_sitter::Point;
29+
fn end_position(&self) -> tree_sitter::Point;
30+
fn byte_range(&self) -> std::ops::Range<usize>;
31+
fn start_byte(&self) -> usize { self.byte_range().start }
32+
fn end_byte(&self) -> usize { self.byte_range().end }
33+
/// For yeast nodes with synthetic content, return it. Otherwise None.
34+
fn opt_string_content(&self) -> Option<String> { None }
35+
}
36+
37+
impl<'a> AstNode for Node<'a> {
38+
fn kind(&self) -> &str { Node::kind(self) }
39+
fn is_named(&self) -> bool { Node::is_named(self) }
40+
fn is_missing(&self) -> bool { Node::is_missing(self) }
41+
fn is_error(&self) -> bool { Node::is_error(self) }
42+
fn is_extra(&self) -> bool { Node::is_extra(self) }
43+
fn start_position(&self) -> tree_sitter::Point { Node::start_position(self) }
44+
fn end_position(&self) -> tree_sitter::Point { Node::end_position(self) }
45+
fn byte_range(&self) -> std::ops::Range<usize> { Node::byte_range(self) }
46+
}
47+
48+
impl AstNode for yeast::Node {
49+
fn kind(&self) -> &str { yeast::Node::kind(self) }
50+
fn is_named(&self) -> bool { yeast::Node::is_named(self) }
51+
fn is_missing(&self) -> bool { yeast::Node::is_missing(self) }
52+
fn is_error(&self) -> bool { yeast::Node::is_error(self) }
53+
fn is_extra(&self) -> bool { yeast::Node::is_extra(self) }
54+
fn start_position(&self) -> tree_sitter::Point { yeast::Node::start_position(self) }
55+
fn end_position(&self) -> tree_sitter::Point { yeast::Node::end_position(self) }
56+
fn byte_range(&self) -> std::ops::Range<usize> { yeast::Node::byte_range(self) }
57+
fn opt_string_content(&self) -> Option<String> { yeast::Node::opt_string_content(self) }
58+
}
59+
2160
/// Sets the tracing level based on the environment variables
2261
/// `RUST_LOG` and `CODEQL_VERBOSITY` (prioritized in that order),
2362
/// falling back to `warn` if neither is set.
@@ -204,6 +243,9 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr
204243
}
205244

206245
/// Extracts the source file at `path`, which is assumed to be canonicalized.
246+
/// When `rules` is non-empty, the parsed tree is first transformed through
247+
/// yeast before TRAP extraction. If `output_schema` is provided, it is used
248+
/// by the yeast runner to resolve output-only node kinds and fields.
207249
pub fn extract(
208250
language: &Language,
209251
language_prefix: &str,
@@ -214,6 +256,8 @@ pub fn extract(
214256
path: &Path,
215257
source: &[u8],
216258
ranges: &[Range],
259+
rules: Vec<yeast::Rule>,
260+
output_schema: Option<yeast::schema::Schema>,
217261
) {
218262
let path_str = file_paths::normalize_and_transform_path(path, transformer);
219263
let span = tracing::span!(
@@ -236,13 +280,24 @@ pub fn extract(
236280
source,
237281
diagnostics_writer,
238282
trap_writer,
239-
// TODO: should we handle path strings that are not valid UTF8 better?
240283
&path_str,
241284
file_label,
242285
language_prefix,
243286
schema,
244287
);
245-
traverse(&tree, &mut visitor);
288+
289+
if rules.is_empty() {
290+
traverse(&tree, &mut visitor);
291+
} else {
292+
let runner = match output_schema {
293+
Some(schema) => yeast::Runner::with_schema(language.clone(), schema, rules),
294+
None => yeast::Runner::new(language.clone(), rules),
295+
};
296+
let ast = runner
297+
.run_from_tree(&tree)
298+
.unwrap_or_else(|e| panic!("Desugaring failed for {path_str}: {e}"));
299+
traverse_yeast(&ast, &mut visitor);
300+
}
246301

247302
parser.reset();
248303
}
@@ -329,11 +384,11 @@ impl<'a> Visitor<'a> {
329384
);
330385
}
331386

332-
fn record_parse_error_for_node(
387+
fn record_parse_error_for_node<N: AstNode>(
333388
&mut self,
334389
message: &str,
335390
args: &[diagnostics::MessageArg],
336-
node: Node,
391+
node: &N,
337392
status_page: bool,
338393
) {
339394
let loc = location_for(self, self.file_label, node);
@@ -357,7 +412,7 @@ impl<'a> Visitor<'a> {
357412
self.record_parse_error(loc_label, &mesg);
358413
}
359414

360-
fn enter_node(&mut self, node: Node) -> bool {
415+
fn enter_node<N: AstNode>(&mut self, node: &N) -> bool {
361416
if node.is_missing() {
362417
self.record_parse_error_for_node(
363418
"A parse error occurred (expected {} symbol). Check the syntax of the file. If the file is invalid, correct the error or {} the file from analysis.",
@@ -383,7 +438,7 @@ impl<'a> Visitor<'a> {
383438
true
384439
}
385440

386-
fn leave_node(&mut self, field_name: Option<&'static str>, node: Node) {
441+
fn leave_node<N: AstNode>(&mut self, field_name: Option<&'static str>, node: &N) {
387442
if node.is_error() || node.is_missing() {
388443
return;
389444
}
@@ -434,7 +489,7 @@ impl<'a> Visitor<'a> {
434489
fields,
435490
name: table_name,
436491
} => {
437-
if let Some(args) = self.complex_node(&node, fields, &child_nodes, id) {
492+
if let Some(args) = self.complex_node(node, fields, &child_nodes, id) {
438493
self.trap_writer.add_tuple(
439494
&self.ast_node_location_table_name,
440495
vec![trap::Arg::Label(id), trap::Arg::Label(loc_label)],
@@ -495,9 +550,9 @@ impl<'a> Visitor<'a> {
495550
}
496551
}
497552

498-
fn complex_node(
553+
fn complex_node<N: AstNode>(
499554
&mut self,
500-
node: &Node,
555+
node: &N,
501556
fields: &[Field],
502557
child_nodes: &[ChildNode],
503558
parent_id: trap::Label,
@@ -529,7 +584,7 @@ impl<'a> Visitor<'a> {
529584
diagnostics::MessageArg::Code(&format!("{:?}", child_node.type_name)),
530585
diagnostics::MessageArg::Code(&format!("{:?}", field.type_info)),
531586
],
532-
*node,
587+
node,
533588
false,
534589
);
535590
}
@@ -541,7 +596,7 @@ impl<'a> Visitor<'a> {
541596
diagnostics::MessageArg::Code(child_node.field_name.unwrap_or("child")),
542597
diagnostics::MessageArg::Code(&format!("{:?}", child_node.type_name)),
543598
],
544-
*node,
599+
node,
545600
false,
546601
);
547602
}
@@ -566,7 +621,7 @@ impl<'a> Visitor<'a> {
566621
node.kind(),
567622
column_name
568623
);
569-
self.record_parse_error_for_node(&error_message, &[], *node, false);
624+
self.record_parse_error_for_node(&error_message, &[], node, false);
570625
}
571626
}
572627
Storage::Table {
@@ -582,7 +637,7 @@ impl<'a> Visitor<'a> {
582637
diagnostics::MessageArg::Code(node.kind()),
583638
diagnostics::MessageArg::Code(table_name),
584639
],
585-
*node,
640+
node,
586641
false,
587642
);
588643
break;
@@ -639,15 +694,17 @@ impl<'a> Visitor<'a> {
639694
}
640695

641696
// Emit a slice of a source file as an Arg.
642-
fn sliced_source_arg(source: &[u8], n: Node) -> trap::Arg {
643-
let range = n.byte_range();
644-
trap::Arg::String(String::from_utf8_lossy(&source[range.start..range.end]).into_owned())
697+
fn sliced_source_arg<N: AstNode>(source: &[u8], n: &N) -> trap::Arg {
698+
trap::Arg::String(n.opt_string_content().unwrap_or_else(|| {
699+
let range = n.byte_range();
700+
String::from_utf8_lossy(&source[range.start..range.end]).into_owned()
701+
}))
645702
}
646703

647704
// Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated.
648705
// The first is the location and label definition, and the second is the
649706
// 'Located' entry.
650-
fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap::Location {
707+
fn location_for<N: AstNode>(visitor: &mut Visitor, file_label: trap::Label, n: &N) -> trap::Location {
651708
// Tree-sitter row, column values are 0-based while CodeQL starts
652709
// counting at 1. In addition Tree-sitter's row and column for the
653710
// end position are exclusive while CodeQL's end positions are inclusive.
@@ -715,6 +772,28 @@ fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap
715772

716773
fn traverse(tree: &Tree, visitor: &mut Visitor) {
717774
let cursor = &mut tree.walk();
775+
visitor.enter_node(&cursor.node());
776+
let mut recurse = true;
777+
loop {
778+
if recurse && cursor.goto_first_child() {
779+
recurse = visitor.enter_node(&cursor.node());
780+
} else {
781+
visitor.leave_node(cursor.field_name(), &cursor.node());
782+
783+
if cursor.goto_next_sibling() {
784+
recurse = visitor.enter_node(&cursor.node());
785+
} else if cursor.goto_parent() {
786+
recurse = false;
787+
} else {
788+
break;
789+
}
790+
}
791+
}
792+
}
793+
794+
fn traverse_yeast(tree: &yeast::Ast, visitor: &mut Visitor) {
795+
use yeast::Cursor;
796+
let mut cursor = tree.walk();
718797
visitor.enter_node(cursor.node());
719798
let mut recurse = true;
720799
loop {

shared/tree-sitter-extractor/src/extractor/simple.rs

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,10 @@ pub struct LanguageSpec {
1212
pub prefix: &'static str,
1313
pub ts_language: tree_sitter::Language,
1414
pub node_types: &'static str,
15+
/// If set, the extractor validates TRAP output against these node types
16+
/// instead of `node_types`. Use when desugaring produces an AST that
17+
/// differs from the tree-sitter grammar.
18+
pub output_node_types: Option<&'static str>,
1519
pub file_globs: Vec<String>,
1620
}
1721

@@ -86,7 +90,8 @@ impl Extractor {
8690

8791
let mut schemas = vec![];
8892
for lang in &self.languages {
89-
let schema = node_types::read_node_types_str(lang.prefix, lang.node_types)?;
93+
let effective_node_types = lang.output_node_types.unwrap_or(lang.node_types);
94+
let schema = node_types::read_node_types_str(lang.prefix, effective_node_types)?;
9095
schemas.push(schema);
9196
}
9297

@@ -162,6 +167,8 @@ impl Extractor {
162167
&path,
163168
&source,
164169
&[],
170+
vec![],
171+
None,
165172
);
166173
std::fs::create_dir_all(src_archive_file.parent().unwrap())?;
167174
std::fs::copy(&path, &src_archive_file)?;

shared/tree-sitter-extractor/tests/integration_test.rs

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@ fn simple_extractor() {
1313
prefix: "ql",
1414
ts_language: tree_sitter_ql::LANGUAGE.into(),
1515
node_types: tree_sitter_ql::NODE_TYPES,
16+
output_node_types: None,
1617
file_globs: vec!["*.qll".into()],
1718
};
1819

shared/tree-sitter-extractor/tests/multiple_languages.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,14 @@ fn multiple_language_extractor() {
1313
prefix: "ql",
1414
ts_language: tree_sitter_ql::LANGUAGE.into(),
1515
node_types: tree_sitter_ql::NODE_TYPES,
16+
output_node_types: None,
1617
file_globs: vec!["*.qll".into()],
1718
};
1819
let lang_json = simple::LanguageSpec {
1920
prefix: "json",
2021
ts_language: tree_sitter_json::LANGUAGE.into(),
2122
node_types: tree_sitter_json::NODE_TYPES,
23+
output_node_types: None,
2224
file_globs: vec!["*.json".into(), "*Jsonfile".into()],
2325
};
2426

0 commit comments

Comments
 (0)