@@ -18,6 +18,85 @@ use tree_sitter::{Language, Node, Parser, Range, Tree};
1818
1919pub mod simple;
2020
21+ /// Trait abstracting over tree-sitter and yeast node types for extraction.
22+ trait AstNode {
23+ fn kind ( & self ) -> & str ;
24+ fn is_named ( & self ) -> bool ;
25+ fn is_missing ( & self ) -> bool ;
26+ fn is_error ( & self ) -> bool ;
27+ fn is_extra ( & self ) -> bool ;
28+ fn start_position ( & self ) -> tree_sitter:: Point ;
29+ fn end_position ( & self ) -> tree_sitter:: Point ;
30+ fn byte_range ( & self ) -> std:: ops:: Range < usize > ;
31+ fn start_byte ( & self ) -> usize {
32+ self . byte_range ( ) . start
33+ }
34+ fn end_byte ( & self ) -> usize {
35+ self . byte_range ( ) . end
36+ }
37+ /// For yeast nodes with synthetic content, return it. Otherwise None.
38+ fn opt_string_content ( & self ) -> Option < String > {
39+ None
40+ }
41+ }
42+
43+ impl < ' a > AstNode for Node < ' a > {
44+ fn kind ( & self ) -> & str {
45+ Node :: kind ( self )
46+ }
47+ fn is_named ( & self ) -> bool {
48+ Node :: is_named ( self )
49+ }
50+ fn is_missing ( & self ) -> bool {
51+ Node :: is_missing ( self )
52+ }
53+ fn is_error ( & self ) -> bool {
54+ Node :: is_error ( self )
55+ }
56+ fn is_extra ( & self ) -> bool {
57+ Node :: is_extra ( self )
58+ }
59+ fn start_position ( & self ) -> tree_sitter:: Point {
60+ Node :: start_position ( self )
61+ }
62+ fn end_position ( & self ) -> tree_sitter:: Point {
63+ Node :: end_position ( self )
64+ }
65+ fn byte_range ( & self ) -> std:: ops:: Range < usize > {
66+ Node :: byte_range ( self )
67+ }
68+ }
69+
70+ impl AstNode for yeast:: Node {
71+ fn kind ( & self ) -> & str {
72+ yeast:: Node :: kind ( self )
73+ }
74+ fn is_named ( & self ) -> bool {
75+ yeast:: Node :: is_named ( self )
76+ }
77+ fn is_missing ( & self ) -> bool {
78+ yeast:: Node :: is_missing ( self )
79+ }
80+ fn is_error ( & self ) -> bool {
81+ yeast:: Node :: is_error ( self )
82+ }
83+ fn is_extra ( & self ) -> bool {
84+ yeast:: Node :: is_extra ( self )
85+ }
86+ fn start_position ( & self ) -> tree_sitter:: Point {
87+ yeast:: Node :: start_position ( self )
88+ }
89+ fn end_position ( & self ) -> tree_sitter:: Point {
90+ yeast:: Node :: end_position ( self )
91+ }
92+ fn byte_range ( & self ) -> std:: ops:: Range < usize > {
93+ yeast:: Node :: byte_range ( self )
94+ }
95+ fn opt_string_content ( & self ) -> Option < String > {
96+ yeast:: Node :: opt_string_content ( self )
97+ }
98+ }
99+
21100/// Sets the tracing level based on the environment variables
22101/// `RUST_LOG` and `CODEQL_VERBOSITY` (prioritized in that order),
23102/// falling back to `warn` if neither is set.
@@ -204,6 +283,9 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr
204283}
205284
206285/// Extracts the source file at `path`, which is assumed to be canonicalized.
286+ /// When `rules` is non-empty, the parsed tree is first transformed through
287+ /// yeast before TRAP extraction. If `output_schema` is provided, it is used
288+ /// by the yeast runner to resolve output-only node kinds and fields.
207289pub fn extract (
208290 language : & Language ,
209291 language_prefix : & str ,
@@ -214,6 +296,8 @@ pub fn extract(
214296 path : & Path ,
215297 source : & [ u8 ] ,
216298 ranges : & [ Range ] ,
299+ rules : Vec < yeast:: Rule > ,
300+ output_schema : Option < yeast:: schema:: Schema > ,
217301) {
218302 let path_str = file_paths:: normalize_and_transform_path ( path, transformer) ;
219303 let span = tracing:: span!(
@@ -236,13 +320,24 @@ pub fn extract(
236320 source,
237321 diagnostics_writer,
238322 trap_writer,
239- // TODO: should we handle path strings that are not valid UTF8 better?
240323 & path_str,
241324 file_label,
242325 language_prefix,
243326 schema,
244327 ) ;
245- traverse ( & tree, & mut visitor) ;
328+
329+ if rules. is_empty ( ) {
330+ traverse ( & tree, & mut visitor) ;
331+ } else {
332+ let runner = match output_schema {
333+ Some ( schema) => yeast:: Runner :: with_schema ( language. clone ( ) , schema, rules) ,
334+ None => yeast:: Runner :: new ( language. clone ( ) , rules) ,
335+ } ;
336+ let ast = runner
337+ . run_from_tree ( & tree)
338+ . unwrap_or_else ( |e| panic ! ( "Desugaring failed for {path_str}: {e}" ) ) ;
339+ traverse_yeast ( & ast, & mut visitor) ;
340+ }
246341
247342 parser. reset ( ) ;
248343}
@@ -329,11 +424,11 @@ impl<'a> Visitor<'a> {
329424 ) ;
330425 }
331426
332- fn record_parse_error_for_node (
427+ fn record_parse_error_for_node < N : AstNode > (
333428 & mut self ,
334429 message : & str ,
335430 args : & [ diagnostics:: MessageArg ] ,
336- node : Node ,
431+ node : & N ,
337432 status_page : bool ,
338433 ) {
339434 let loc = location_for ( self , self . file_label , node) ;
@@ -357,7 +452,7 @@ impl<'a> Visitor<'a> {
357452 self . record_parse_error ( loc_label, & mesg) ;
358453 }
359454
360- fn enter_node ( & mut self , node : Node ) -> bool {
455+ fn enter_node < N : AstNode > ( & mut self , node : & N ) -> bool {
361456 if node. is_missing ( ) {
362457 self . record_parse_error_for_node (
363458 "A parse error occurred (expected {} symbol). Check the syntax of the file. If the file is invalid, correct the error or {} the file from analysis." ,
@@ -383,7 +478,7 @@ impl<'a> Visitor<'a> {
383478 true
384479 }
385480
386- fn leave_node ( & mut self , field_name : Option < & ' static str > , node : Node ) {
481+ fn leave_node < N : AstNode > ( & mut self , field_name : Option < & ' static str > , node : & N ) {
387482 if node. is_error ( ) || node. is_missing ( ) {
388483 return ;
389484 }
@@ -434,7 +529,7 @@ impl<'a> Visitor<'a> {
434529 fields,
435530 name : table_name,
436531 } => {
437- if let Some ( args) = self . complex_node ( & node, fields, & child_nodes, id) {
532+ if let Some ( args) = self . complex_node ( node, fields, & child_nodes, id) {
438533 self . trap_writer . add_tuple (
439534 & self . ast_node_location_table_name ,
440535 vec ! [ trap:: Arg :: Label ( id) , trap:: Arg :: Label ( loc_label) ] ,
@@ -495,9 +590,9 @@ impl<'a> Visitor<'a> {
495590 }
496591 }
497592
498- fn complex_node (
593+ fn complex_node < N : AstNode > (
499594 & mut self ,
500- node : & Node ,
595+ node : & N ,
501596 fields : & [ Field ] ,
502597 child_nodes : & [ ChildNode ] ,
503598 parent_id : trap:: Label ,
@@ -529,7 +624,7 @@ impl<'a> Visitor<'a> {
529624 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , child_node. type_name) ) ,
530625 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , field. type_info) ) ,
531626 ] ,
532- * node,
627+ node,
533628 false ,
534629 ) ;
535630 }
@@ -541,7 +636,7 @@ impl<'a> Visitor<'a> {
541636 diagnostics:: MessageArg :: Code ( child_node. field_name . unwrap_or ( "child" ) ) ,
542637 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , child_node. type_name) ) ,
543638 ] ,
544- * node,
639+ node,
545640 false ,
546641 ) ;
547642 }
@@ -566,7 +661,7 @@ impl<'a> Visitor<'a> {
566661 node. kind( ) ,
567662 column_name
568663 ) ;
569- self . record_parse_error_for_node ( & error_message, & [ ] , * node, false ) ;
664+ self . record_parse_error_for_node ( & error_message, & [ ] , node, false ) ;
570665 }
571666 }
572667 Storage :: Table {
@@ -582,7 +677,7 @@ impl<'a> Visitor<'a> {
582677 diagnostics:: MessageArg :: Code ( node. kind ( ) ) ,
583678 diagnostics:: MessageArg :: Code ( table_name) ,
584679 ] ,
585- * node,
680+ node,
586681 false ,
587682 ) ;
588683 break ;
@@ -639,15 +734,21 @@ impl<'a> Visitor<'a> {
639734}
640735
641736// Emit a slice of a source file as an Arg.
642- fn sliced_source_arg ( source : & [ u8 ] , n : Node ) -> trap:: Arg {
643- let range = n. byte_range ( ) ;
644- trap:: Arg :: String ( String :: from_utf8_lossy ( & source[ range. start ..range. end ] ) . into_owned ( ) )
737+ fn sliced_source_arg < N : AstNode > ( source : & [ u8 ] , n : & N ) -> trap:: Arg {
738+ trap:: Arg :: String ( n. opt_string_content ( ) . unwrap_or_else ( || {
739+ let range = n. byte_range ( ) ;
740+ String :: from_utf8_lossy ( & source[ range. start ..range. end ] ) . into_owned ( )
741+ } ) )
645742}
646743
647744// Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated.
648745// The first is the location and label definition, and the second is the
649746// 'Located' entry.
650- fn location_for ( visitor : & mut Visitor , file_label : trap:: Label , n : Node ) -> trap:: Location {
747+ fn location_for < N : AstNode > (
748+ visitor : & mut Visitor ,
749+ file_label : trap:: Label ,
750+ n : & N ,
751+ ) -> trap:: Location {
651752 // Tree-sitter row, column values are 0-based while CodeQL starts
652753 // counting at 1. In addition Tree-sitter's row and column for the
653754 // end position are exclusive while CodeQL's end positions are inclusive.
@@ -715,6 +816,28 @@ fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap
715816
716817fn traverse ( tree : & Tree , visitor : & mut Visitor ) {
717818 let cursor = & mut tree. walk ( ) ;
819+ visitor. enter_node ( & cursor. node ( ) ) ;
820+ let mut recurse = true ;
821+ loop {
822+ if recurse && cursor. goto_first_child ( ) {
823+ recurse = visitor. enter_node ( & cursor. node ( ) ) ;
824+ } else {
825+ visitor. leave_node ( cursor. field_name ( ) , & cursor. node ( ) ) ;
826+
827+ if cursor. goto_next_sibling ( ) {
828+ recurse = visitor. enter_node ( & cursor. node ( ) ) ;
829+ } else if cursor. goto_parent ( ) {
830+ recurse = false ;
831+ } else {
832+ break ;
833+ }
834+ }
835+ }
836+ }
837+
838+ fn traverse_yeast ( tree : & yeast:: Ast , visitor : & mut Visitor ) {
839+ use yeast:: Cursor ;
840+ let mut cursor = tree. walk ( ) ;
718841 visitor. enter_node ( cursor. node ( ) ) ;
719842 let mut recurse = true ;
720843 loop {
0 commit comments