@@ -18,6 +18,45 @@ use tree_sitter::{Language, Node, Parser, Range, Tree};
1818
1919pub mod simple;
2020
21+ /// Trait abstracting over tree-sitter and yeast node types for extraction.
22+ trait AstNode {
23+ fn kind ( & self ) -> & str ;
24+ fn is_named ( & self ) -> bool ;
25+ fn is_missing ( & self ) -> bool ;
26+ fn is_error ( & self ) -> bool ;
27+ fn is_extra ( & self ) -> bool ;
28+ fn start_position ( & self ) -> tree_sitter:: Point ;
29+ fn end_position ( & self ) -> tree_sitter:: Point ;
30+ fn byte_range ( & self ) -> std:: ops:: Range < usize > ;
31+ fn start_byte ( & self ) -> usize { self . byte_range ( ) . start }
32+ fn end_byte ( & self ) -> usize { self . byte_range ( ) . end }
33+ /// For yeast nodes with synthetic content, return it. Otherwise None.
34+ fn opt_string_content ( & self ) -> Option < String > { None }
35+ }
36+
37+ impl < ' a > AstNode for Node < ' a > {
38+ fn kind ( & self ) -> & str { Node :: kind ( self ) }
39+ fn is_named ( & self ) -> bool { Node :: is_named ( self ) }
40+ fn is_missing ( & self ) -> bool { Node :: is_missing ( self ) }
41+ fn is_error ( & self ) -> bool { Node :: is_error ( self ) }
42+ fn is_extra ( & self ) -> bool { Node :: is_extra ( self ) }
43+ fn start_position ( & self ) -> tree_sitter:: Point { Node :: start_position ( self ) }
44+ fn end_position ( & self ) -> tree_sitter:: Point { Node :: end_position ( self ) }
45+ fn byte_range ( & self ) -> std:: ops:: Range < usize > { Node :: byte_range ( self ) }
46+ }
47+
48+ impl AstNode for yeast:: Node {
49+ fn kind ( & self ) -> & str { yeast:: Node :: kind ( self ) }
50+ fn is_named ( & self ) -> bool { yeast:: Node :: is_named ( self ) }
51+ fn is_missing ( & self ) -> bool { yeast:: Node :: is_missing ( self ) }
52+ fn is_error ( & self ) -> bool { yeast:: Node :: is_error ( self ) }
53+ fn is_extra ( & self ) -> bool { yeast:: Node :: is_extra ( self ) }
54+ fn start_position ( & self ) -> tree_sitter:: Point { yeast:: Node :: start_position ( self ) }
55+ fn end_position ( & self ) -> tree_sitter:: Point { yeast:: Node :: end_position ( self ) }
56+ fn byte_range ( & self ) -> std:: ops:: Range < usize > { yeast:: Node :: byte_range ( self ) }
57+ fn opt_string_content ( & self ) -> Option < String > { yeast:: Node :: opt_string_content ( self ) }
58+ }
59+
2160/// Sets the tracing level based on the environment variables
2261/// `RUST_LOG` and `CODEQL_VERBOSITY` (prioritized in that order),
2362/// falling back to `warn` if neither is set.
@@ -204,6 +243,9 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr
204243}
205244
206245/// Extracts the source file at `path`, which is assumed to be canonicalized.
246+ /// When `rules` is non-empty, the parsed tree is first transformed through
247+ /// yeast before TRAP extraction. If `output_schema` is provided, it is used
248+ /// by the yeast runner to resolve output-only node kinds and fields.
207249pub fn extract (
208250 language : & Language ,
209251 language_prefix : & str ,
@@ -214,6 +256,8 @@ pub fn extract(
214256 path : & Path ,
215257 source : & [ u8 ] ,
216258 ranges : & [ Range ] ,
259+ rules : Vec < yeast:: Rule > ,
260+ output_schema : Option < yeast:: schema:: Schema > ,
217261) {
218262 let path_str = file_paths:: normalize_and_transform_path ( path, transformer) ;
219263 let span = tracing:: span!(
@@ -236,13 +280,24 @@ pub fn extract(
236280 source,
237281 diagnostics_writer,
238282 trap_writer,
239- // TODO: should we handle path strings that are not valid UTF8 better?
240283 & path_str,
241284 file_label,
242285 language_prefix,
243286 schema,
244287 ) ;
245- traverse ( & tree, & mut visitor) ;
288+
289+ if rules. is_empty ( ) {
290+ traverse ( & tree, & mut visitor) ;
291+ } else {
292+ let runner = match output_schema {
293+ Some ( schema) => yeast:: Runner :: with_schema ( language. clone ( ) , schema, rules) ,
294+ None => yeast:: Runner :: new ( language. clone ( ) , rules) ,
295+ } ;
296+ let ast = runner
297+ . run_from_tree ( & tree)
298+ . unwrap_or_else ( |e| panic ! ( "Desugaring failed for {path_str}: {e}" ) ) ;
299+ traverse_yeast ( & ast, & mut visitor) ;
300+ }
246301
247302 parser. reset ( ) ;
248303}
@@ -329,11 +384,11 @@ impl<'a> Visitor<'a> {
329384 ) ;
330385 }
331386
332- fn record_parse_error_for_node (
387+ fn record_parse_error_for_node < N : AstNode > (
333388 & mut self ,
334389 message : & str ,
335390 args : & [ diagnostics:: MessageArg ] ,
336- node : Node ,
391+ node : & N ,
337392 status_page : bool ,
338393 ) {
339394 let loc = location_for ( self , self . file_label , node) ;
@@ -357,7 +412,7 @@ impl<'a> Visitor<'a> {
357412 self . record_parse_error ( loc_label, & mesg) ;
358413 }
359414
360- fn enter_node ( & mut self , node : Node ) -> bool {
415+ fn enter_node < N : AstNode > ( & mut self , node : & N ) -> bool {
361416 if node. is_missing ( ) {
362417 self . record_parse_error_for_node (
363418 "A parse error occurred (expected {} symbol). Check the syntax of the file. If the file is invalid, correct the error or {} the file from analysis." ,
@@ -383,7 +438,7 @@ impl<'a> Visitor<'a> {
383438 true
384439 }
385440
386- fn leave_node ( & mut self , field_name : Option < & ' static str > , node : Node ) {
441+ fn leave_node < N : AstNode > ( & mut self , field_name : Option < & ' static str > , node : & N ) {
387442 if node. is_error ( ) || node. is_missing ( ) {
388443 return ;
389444 }
@@ -434,7 +489,7 @@ impl<'a> Visitor<'a> {
434489 fields,
435490 name : table_name,
436491 } => {
437- if let Some ( args) = self . complex_node ( & node, fields, & child_nodes, id) {
492+ if let Some ( args) = self . complex_node ( node, fields, & child_nodes, id) {
438493 self . trap_writer . add_tuple (
439494 & self . ast_node_location_table_name ,
440495 vec ! [ trap:: Arg :: Label ( id) , trap:: Arg :: Label ( loc_label) ] ,
@@ -495,9 +550,9 @@ impl<'a> Visitor<'a> {
495550 }
496551 }
497552
498- fn complex_node (
553+ fn complex_node < N : AstNode > (
499554 & mut self ,
500- node : & Node ,
555+ node : & N ,
501556 fields : & [ Field ] ,
502557 child_nodes : & [ ChildNode ] ,
503558 parent_id : trap:: Label ,
@@ -529,7 +584,7 @@ impl<'a> Visitor<'a> {
529584 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , child_node. type_name) ) ,
530585 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , field. type_info) ) ,
531586 ] ,
532- * node,
587+ node,
533588 false ,
534589 ) ;
535590 }
@@ -541,7 +596,7 @@ impl<'a> Visitor<'a> {
541596 diagnostics:: MessageArg :: Code ( child_node. field_name . unwrap_or ( "child" ) ) ,
542597 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , child_node. type_name) ) ,
543598 ] ,
544- * node,
599+ node,
545600 false ,
546601 ) ;
547602 }
@@ -566,7 +621,7 @@ impl<'a> Visitor<'a> {
566621 node. kind( ) ,
567622 column_name
568623 ) ;
569- self . record_parse_error_for_node ( & error_message, & [ ] , * node, false ) ;
624+ self . record_parse_error_for_node ( & error_message, & [ ] , node, false ) ;
570625 }
571626 }
572627 Storage :: Table {
@@ -582,7 +637,7 @@ impl<'a> Visitor<'a> {
582637 diagnostics:: MessageArg :: Code ( node. kind ( ) ) ,
583638 diagnostics:: MessageArg :: Code ( table_name) ,
584639 ] ,
585- * node,
640+ node,
586641 false ,
587642 ) ;
588643 break ;
@@ -639,15 +694,17 @@ impl<'a> Visitor<'a> {
639694}
640695
641696// Emit a slice of a source file as an Arg.
642- fn sliced_source_arg ( source : & [ u8 ] , n : Node ) -> trap:: Arg {
643- let range = n. byte_range ( ) ;
644- trap:: Arg :: String ( String :: from_utf8_lossy ( & source[ range. start ..range. end ] ) . into_owned ( ) )
697+ fn sliced_source_arg < N : AstNode > ( source : & [ u8 ] , n : & N ) -> trap:: Arg {
698+ trap:: Arg :: String ( n. opt_string_content ( ) . unwrap_or_else ( || {
699+ let range = n. byte_range ( ) ;
700+ String :: from_utf8_lossy ( & source[ range. start ..range. end ] ) . into_owned ( )
701+ } ) )
645702}
646703
647704// Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated.
648705// The first is the location and label definition, and the second is the
649706// 'Located' entry.
650- fn location_for ( visitor : & mut Visitor , file_label : trap:: Label , n : Node ) -> trap:: Location {
707+ fn location_for < N : AstNode > ( visitor : & mut Visitor , file_label : trap:: Label , n : & N ) -> trap:: Location {
651708 // Tree-sitter row, column values are 0-based while CodeQL starts
652709 // counting at 1. In addition Tree-sitter's row and column for the
653710 // end position are exclusive while CodeQL's end positions are inclusive.
@@ -715,6 +772,28 @@ fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap
715772
716773fn traverse ( tree : & Tree , visitor : & mut Visitor ) {
717774 let cursor = & mut tree. walk ( ) ;
775+ visitor. enter_node ( & cursor. node ( ) ) ;
776+ let mut recurse = true ;
777+ loop {
778+ if recurse && cursor. goto_first_child ( ) {
779+ recurse = visitor. enter_node ( & cursor. node ( ) ) ;
780+ } else {
781+ visitor. leave_node ( cursor. field_name ( ) , & cursor. node ( ) ) ;
782+
783+ if cursor. goto_next_sibling ( ) {
784+ recurse = visitor. enter_node ( & cursor. node ( ) ) ;
785+ } else if cursor. goto_parent ( ) {
786+ recurse = false ;
787+ } else {
788+ break ;
789+ }
790+ }
791+ }
792+ }
793+
794+ fn traverse_yeast ( tree : & yeast:: Ast , visitor : & mut Visitor ) {
795+ use yeast:: Cursor ;
796+ let mut cursor = tree. walk ( ) ;
718797 visitor. enter_node ( cursor. node ( ) ) ;
719798 let mut recurse = true ;
720799 loop {
0 commit comments