@@ -18,6 +18,82 @@ use tree_sitter::{Language, Node, Parser, Range, Tree};
1818
1919pub mod simple;
2020
21+ /// Trait abstracting over tree-sitter and yeast node types for extraction.
22+ trait AstNode {
23+ fn kind ( & self ) -> & str ;
24+ fn is_named ( & self ) -> bool ;
25+ fn is_missing ( & self ) -> bool ;
26+ fn is_error ( & self ) -> bool ;
27+ fn is_extra ( & self ) -> bool ;
28+ fn start_position ( & self ) -> tree_sitter:: Point ;
29+ fn end_position ( & self ) -> tree_sitter:: Point ;
30+ fn byte_range ( & self ) -> std:: ops:: Range < usize > ;
31+ fn end_byte ( & self ) -> usize {
32+ self . byte_range ( ) . end
33+ }
34+ /// For yeast nodes with synthetic content, return it. Otherwise None.
35+ fn opt_string_content ( & self ) -> Option < String > {
36+ None
37+ }
38+ }
39+
40+ impl < ' a > AstNode for Node < ' a > {
41+ fn kind ( & self ) -> & str {
42+ Node :: kind ( self )
43+ }
44+ fn is_named ( & self ) -> bool {
45+ Node :: is_named ( self )
46+ }
47+ fn is_missing ( & self ) -> bool {
48+ Node :: is_missing ( self )
49+ }
50+ fn is_error ( & self ) -> bool {
51+ Node :: is_error ( self )
52+ }
53+ fn is_extra ( & self ) -> bool {
54+ Node :: is_extra ( self )
55+ }
56+ fn start_position ( & self ) -> tree_sitter:: Point {
57+ Node :: start_position ( self )
58+ }
59+ fn end_position ( & self ) -> tree_sitter:: Point {
60+ Node :: end_position ( self )
61+ }
62+ fn byte_range ( & self ) -> std:: ops:: Range < usize > {
63+ Node :: byte_range ( self )
64+ }
65+ }
66+
67+ impl AstNode for yeast:: Node {
68+ fn kind ( & self ) -> & str {
69+ yeast:: Node :: kind ( self )
70+ }
71+ fn is_named ( & self ) -> bool {
72+ yeast:: Node :: is_named ( self )
73+ }
74+ fn is_missing ( & self ) -> bool {
75+ yeast:: Node :: is_missing ( self )
76+ }
77+ fn is_error ( & self ) -> bool {
78+ yeast:: Node :: is_error ( self )
79+ }
80+ fn is_extra ( & self ) -> bool {
81+ yeast:: Node :: is_extra ( self )
82+ }
83+ fn start_position ( & self ) -> tree_sitter:: Point {
84+ yeast:: Node :: start_position ( self )
85+ }
86+ fn end_position ( & self ) -> tree_sitter:: Point {
87+ yeast:: Node :: end_position ( self )
88+ }
89+ fn byte_range ( & self ) -> std:: ops:: Range < usize > {
90+ yeast:: Node :: byte_range ( self )
91+ }
92+ fn opt_string_content ( & self ) -> Option < String > {
93+ yeast:: Node :: opt_string_content ( self )
94+ }
95+ }
96+
2197/// Sets the tracing level based on the environment variables
2298/// `RUST_LOG` and `CODEQL_VERBOSITY` (prioritized in that order),
2399/// falling back to `warn` if neither is set.
@@ -204,6 +280,10 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr
204280}
205281
206282/// Extracts the source file at `path`, which is assumed to be canonicalized.
283+ /// When `rules` is non-empty, the parsed tree is first transformed through
284+ /// yeast before TRAP extraction. If `output_schema` is provided, it is used
285+ /// by the yeast runner to resolve output-only node kinds and fields.
286+ #[ allow( clippy:: too_many_arguments) ]
207287pub fn extract (
208288 language : & Language ,
209289 language_prefix : & str ,
@@ -214,6 +294,8 @@ pub fn extract(
214294 path : & Path ,
215295 source : & [ u8 ] ,
216296 ranges : & [ Range ] ,
297+ rules : Vec < yeast:: Rule > ,
298+ output_schema : Option < yeast:: schema:: Schema > ,
217299) {
218300 let path_str = file_paths:: normalize_and_transform_path ( path, transformer) ;
219301 let span = tracing:: span!(
@@ -236,13 +318,24 @@ pub fn extract(
236318 source,
237319 diagnostics_writer,
238320 trap_writer,
239- // TODO: should we handle path strings that are not valid UTF8 better?
240321 & path_str,
241322 file_label,
242323 language_prefix,
243324 schema,
244325 ) ;
245- traverse ( & tree, & mut visitor) ;
326+
327+ if rules. is_empty ( ) {
328+ traverse ( & tree, & mut visitor) ;
329+ } else {
330+ let runner = match output_schema {
331+ Some ( schema) => yeast:: Runner :: with_schema ( language. clone ( ) , schema, rules) ,
332+ None => yeast:: Runner :: new ( language. clone ( ) , rules) ,
333+ } ;
334+ let ast = runner
335+ . run_from_tree ( & tree)
336+ . unwrap_or_else ( |e| panic ! ( "Desugaring failed for {path_str}: {e}" ) ) ;
337+ traverse_yeast ( & ast, & mut visitor) ;
338+ }
246339
247340 parser. reset ( ) ;
248341}
@@ -329,11 +422,11 @@ impl<'a> Visitor<'a> {
329422 ) ;
330423 }
331424
332- fn record_parse_error_for_node (
425+ fn record_parse_error_for_node < N : AstNode > (
333426 & mut self ,
334427 message : & str ,
335428 args : & [ diagnostics:: MessageArg ] ,
336- node : Node ,
429+ node : & N ,
337430 status_page : bool ,
338431 ) {
339432 let loc = location_for ( self , self . file_label , node) ;
@@ -357,7 +450,7 @@ impl<'a> Visitor<'a> {
357450 self . record_parse_error ( loc_label, & mesg) ;
358451 }
359452
360- fn enter_node ( & mut self , node : Node ) -> bool {
453+ fn enter_node < N : AstNode > ( & mut self , node : & N ) -> bool {
361454 if node. is_missing ( ) {
362455 self . record_parse_error_for_node (
363456 "A parse error occurred (expected {} symbol). Check the syntax of the file. If the file is invalid, correct the error or {} the file from analysis." ,
@@ -383,7 +476,7 @@ impl<'a> Visitor<'a> {
383476 true
384477 }
385478
386- fn leave_node ( & mut self , field_name : Option < & ' static str > , node : Node ) {
479+ fn leave_node < N : AstNode > ( & mut self , field_name : Option < & ' static str > , node : & N ) {
387480 if node. is_error ( ) || node. is_missing ( ) {
388481 return ;
389482 }
@@ -434,7 +527,7 @@ impl<'a> Visitor<'a> {
434527 fields,
435528 name : table_name,
436529 } => {
437- if let Some ( args) = self . complex_node ( & node, fields, & child_nodes, id) {
530+ if let Some ( args) = self . complex_node ( node, fields, & child_nodes, id) {
438531 self . trap_writer . add_tuple (
439532 & self . ast_node_location_table_name ,
440533 vec ! [ trap:: Arg :: Label ( id) , trap:: Arg :: Label ( loc_label) ] ,
@@ -495,9 +588,9 @@ impl<'a> Visitor<'a> {
495588 }
496589 }
497590
498- fn complex_node (
591+ fn complex_node < N : AstNode > (
499592 & mut self ,
500- node : & Node ,
593+ node : & N ,
501594 fields : & [ Field ] ,
502595 child_nodes : & [ ChildNode ] ,
503596 parent_id : trap:: Label ,
@@ -529,7 +622,7 @@ impl<'a> Visitor<'a> {
529622 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , child_node. type_name) ) ,
530623 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , field. type_info) ) ,
531624 ] ,
532- * node,
625+ node,
533626 false ,
534627 ) ;
535628 }
@@ -541,7 +634,7 @@ impl<'a> Visitor<'a> {
541634 diagnostics:: MessageArg :: Code ( child_node. field_name . unwrap_or ( "child" ) ) ,
542635 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , child_node. type_name) ) ,
543636 ] ,
544- * node,
637+ node,
545638 false ,
546639 ) ;
547640 }
@@ -566,7 +659,7 @@ impl<'a> Visitor<'a> {
566659 node. kind( ) ,
567660 column_name
568661 ) ;
569- self . record_parse_error_for_node ( & error_message, & [ ] , * node, false ) ;
662+ self . record_parse_error_for_node ( & error_message, & [ ] , node, false ) ;
570663 }
571664 }
572665 Storage :: Table {
@@ -582,7 +675,7 @@ impl<'a> Visitor<'a> {
582675 diagnostics:: MessageArg :: Code ( node. kind ( ) ) ,
583676 diagnostics:: MessageArg :: Code ( table_name) ,
584677 ] ,
585- * node,
678+ node,
586679 false ,
587680 ) ;
588681 break ;
@@ -639,15 +732,21 @@ impl<'a> Visitor<'a> {
639732}
640733
641734// Emit a slice of a source file as an Arg.
642- fn sliced_source_arg ( source : & [ u8 ] , n : Node ) -> trap:: Arg {
643- let range = n. byte_range ( ) ;
644- trap:: Arg :: String ( String :: from_utf8_lossy ( & source[ range. start ..range. end ] ) . into_owned ( ) )
735+ fn sliced_source_arg < N : AstNode > ( source : & [ u8 ] , n : & N ) -> trap:: Arg {
736+ trap:: Arg :: String ( n. opt_string_content ( ) . unwrap_or_else ( || {
737+ let range = n. byte_range ( ) ;
738+ String :: from_utf8_lossy ( & source[ range. start ..range. end ] ) . into_owned ( )
739+ } ) )
645740}
646741
647742// Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated.
648743// The first is the location and label definition, and the second is the
649744// 'Located' entry.
650- fn location_for ( visitor : & mut Visitor , file_label : trap:: Label , n : Node ) -> trap:: Location {
745+ fn location_for < N : AstNode > (
746+ visitor : & mut Visitor ,
747+ file_label : trap:: Label ,
748+ n : & N ,
749+ ) -> trap:: Location {
651750 // Tree-sitter row, column values are 0-based while CodeQL starts
652751 // counting at 1. In addition Tree-sitter's row and column for the
653752 // end position are exclusive while CodeQL's end positions are inclusive.
@@ -715,6 +814,28 @@ fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap
715814
716815fn traverse ( tree : & Tree , visitor : & mut Visitor ) {
717816 let cursor = & mut tree. walk ( ) ;
817+ visitor. enter_node ( & cursor. node ( ) ) ;
818+ let mut recurse = true ;
819+ loop {
820+ if recurse && cursor. goto_first_child ( ) {
821+ recurse = visitor. enter_node ( & cursor. node ( ) ) ;
822+ } else {
823+ visitor. leave_node ( cursor. field_name ( ) , & cursor. node ( ) ) ;
824+
825+ if cursor. goto_next_sibling ( ) {
826+ recurse = visitor. enter_node ( & cursor. node ( ) ) ;
827+ } else if cursor. goto_parent ( ) {
828+ recurse = false ;
829+ } else {
830+ break ;
831+ }
832+ }
833+ }
834+ }
835+
836+ fn traverse_yeast ( tree : & yeast:: Ast , visitor : & mut Visitor ) {
837+ use yeast:: Cursor ;
838+ let mut cursor = tree. walk ( ) ;
718839 visitor. enter_node ( cursor. node ( ) ) ;
719840 let mut recurse = true ;
720841 loop {
0 commit comments