@@ -18,6 +18,82 @@ use tree_sitter::{Language, Node, Parser, Range, Tree};
1818
1919pub mod simple;
2020
21+ /// Trait abstracting over tree-sitter and yeast node types for extraction.
22+ trait AstNode {
23+ fn kind ( & self ) -> & str ;
24+ fn is_named ( & self ) -> bool ;
25+ fn is_missing ( & self ) -> bool ;
26+ fn is_error ( & self ) -> bool ;
27+ fn is_extra ( & self ) -> bool ;
28+ fn start_position ( & self ) -> tree_sitter:: Point ;
29+ fn end_position ( & self ) -> tree_sitter:: Point ;
30+ fn byte_range ( & self ) -> std:: ops:: Range < usize > ;
31+ fn end_byte ( & self ) -> usize {
32+ self . byte_range ( ) . end
33+ }
34+ /// For yeast nodes with synthetic content, return it. Otherwise None.
35+ fn opt_string_content ( & self ) -> Option < String > {
36+ None
37+ }
38+ }
39+
40+ impl < ' a > AstNode for Node < ' a > {
41+ fn kind ( & self ) -> & str {
42+ Node :: kind ( self )
43+ }
44+ fn is_named ( & self ) -> bool {
45+ Node :: is_named ( self )
46+ }
47+ fn is_missing ( & self ) -> bool {
48+ Node :: is_missing ( self )
49+ }
50+ fn is_error ( & self ) -> bool {
51+ Node :: is_error ( self )
52+ }
53+ fn is_extra ( & self ) -> bool {
54+ Node :: is_extra ( self )
55+ }
56+ fn start_position ( & self ) -> tree_sitter:: Point {
57+ Node :: start_position ( self )
58+ }
59+ fn end_position ( & self ) -> tree_sitter:: Point {
60+ Node :: end_position ( self )
61+ }
62+ fn byte_range ( & self ) -> std:: ops:: Range < usize > {
63+ Node :: byte_range ( self )
64+ }
65+ }
66+
67+ impl AstNode for yeast:: Node {
68+ fn kind ( & self ) -> & str {
69+ yeast:: Node :: kind ( self )
70+ }
71+ fn is_named ( & self ) -> bool {
72+ yeast:: Node :: is_named ( self )
73+ }
74+ fn is_missing ( & self ) -> bool {
75+ yeast:: Node :: is_missing ( self )
76+ }
77+ fn is_error ( & self ) -> bool {
78+ yeast:: Node :: is_error ( self )
79+ }
80+ fn is_extra ( & self ) -> bool {
81+ yeast:: Node :: is_extra ( self )
82+ }
83+ fn start_position ( & self ) -> tree_sitter:: Point {
84+ yeast:: Node :: start_position ( self )
85+ }
86+ fn end_position ( & self ) -> tree_sitter:: Point {
87+ yeast:: Node :: end_position ( self )
88+ }
89+ fn byte_range ( & self ) -> std:: ops:: Range < usize > {
90+ yeast:: Node :: byte_range ( self )
91+ }
92+ fn opt_string_content ( & self ) -> Option < String > {
93+ yeast:: Node :: opt_string_content ( self )
94+ }
95+ }
96+
2197/// Sets the tracing level based on the environment variables
2298/// `RUST_LOG` and `CODEQL_VERBOSITY` (prioritized in that order),
2399/// falling back to `warn` if neither is set.
@@ -204,6 +280,10 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr
204280}
205281
206282/// Extracts the source file at `path`, which is assumed to be canonicalized.
283+ /// When `desugar` is `Some`, the parsed tree is first transformed through
284+ /// yeast before TRAP extraction, using the rules and (optional) output
285+ /// schema from the [`yeast::DesugaringConfig`].
286+ #[ allow( clippy:: too_many_arguments) ]
207287pub fn extract (
208288 language : & Language ,
209289 language_prefix : & str ,
@@ -214,6 +294,7 @@ pub fn extract(
214294 path : & Path ,
215295 source : & [ u8 ] ,
216296 ranges : & [ Range ] ,
297+ desugar : Option < & yeast:: DesugaringConfig > ,
217298) {
218299 let path_str = file_paths:: normalize_and_transform_path ( path, transformer) ;
219300 let span = tracing:: span!(
@@ -236,13 +317,22 @@ pub fn extract(
236317 source,
237318 diagnostics_writer,
238319 trap_writer,
239- // TODO: should we handle path strings that are not valid UTF8 better?
240320 & path_str,
241321 file_label,
242322 language_prefix,
243323 schema,
244324 ) ;
245- traverse ( & tree, & mut visitor) ;
325+
326+ if let Some ( config) = desugar {
327+ let runner = yeast:: Runner :: from_config ( language. clone ( ) , config)
328+ . unwrap_or_else ( |e| panic ! ( "Failed to build desugaring runner for {path_str}: {e}" ) ) ;
329+ let ast = runner
330+ . run_from_tree ( & tree)
331+ . unwrap_or_else ( |e| panic ! ( "Desugaring failed for {path_str}: {e}" ) ) ;
332+ traverse_yeast ( & ast, & mut visitor) ;
333+ } else {
334+ traverse ( & tree, & mut visitor) ;
335+ }
246336
247337 parser. reset ( ) ;
248338}
@@ -329,11 +419,11 @@ impl<'a> Visitor<'a> {
329419 ) ;
330420 }
331421
332- fn record_parse_error_for_node (
422+ fn record_parse_error_for_node < N : AstNode > (
333423 & mut self ,
334424 message : & str ,
335425 args : & [ diagnostics:: MessageArg ] ,
336- node : Node ,
426+ node : & N ,
337427 status_page : bool ,
338428 ) {
339429 let loc = location_for ( self , self . file_label , node) ;
@@ -357,7 +447,7 @@ impl<'a> Visitor<'a> {
357447 self . record_parse_error ( loc_label, & mesg) ;
358448 }
359449
360- fn enter_node ( & mut self , node : Node ) -> bool {
450+ fn enter_node < N : AstNode > ( & mut self , node : & N ) -> bool {
361451 if node. is_missing ( ) {
362452 self . record_parse_error_for_node (
363453 "A parse error occurred (expected {} symbol). Check the syntax of the file. If the file is invalid, correct the error or {} the file from analysis." ,
@@ -383,7 +473,7 @@ impl<'a> Visitor<'a> {
383473 true
384474 }
385475
386- fn leave_node ( & mut self , field_name : Option < & ' static str > , node : Node ) {
476+ fn leave_node < N : AstNode > ( & mut self , field_name : Option < & ' static str > , node : & N ) {
387477 if node. is_error ( ) || node. is_missing ( ) {
388478 return ;
389479 }
@@ -434,7 +524,7 @@ impl<'a> Visitor<'a> {
434524 fields,
435525 name : table_name,
436526 } => {
437- if let Some ( args) = self . complex_node ( & node, fields, & child_nodes, id) {
527+ if let Some ( args) = self . complex_node ( node, fields, & child_nodes, id) {
438528 self . trap_writer . add_tuple (
439529 & self . ast_node_location_table_name ,
440530 vec ! [ trap:: Arg :: Label ( id) , trap:: Arg :: Label ( loc_label) ] ,
@@ -495,9 +585,9 @@ impl<'a> Visitor<'a> {
495585 }
496586 }
497587
498- fn complex_node (
588+ fn complex_node < N : AstNode > (
499589 & mut self ,
500- node : & Node ,
590+ node : & N ,
501591 fields : & [ Field ] ,
502592 child_nodes : & [ ChildNode ] ,
503593 parent_id : trap:: Label ,
@@ -529,7 +619,7 @@ impl<'a> Visitor<'a> {
529619 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , child_node. type_name) ) ,
530620 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , field. type_info) ) ,
531621 ] ,
532- * node,
622+ node,
533623 false ,
534624 ) ;
535625 }
@@ -541,7 +631,7 @@ impl<'a> Visitor<'a> {
541631 diagnostics:: MessageArg :: Code ( child_node. field_name . unwrap_or ( "child" ) ) ,
542632 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , child_node. type_name) ) ,
543633 ] ,
544- * node,
634+ node,
545635 false ,
546636 ) ;
547637 }
@@ -566,7 +656,7 @@ impl<'a> Visitor<'a> {
566656 node. kind( ) ,
567657 column_name
568658 ) ;
569- self . record_parse_error_for_node ( & error_message, & [ ] , * node, false ) ;
659+ self . record_parse_error_for_node ( & error_message, & [ ] , node, false ) ;
570660 }
571661 }
572662 Storage :: Table {
@@ -582,7 +672,7 @@ impl<'a> Visitor<'a> {
582672 diagnostics:: MessageArg :: Code ( node. kind ( ) ) ,
583673 diagnostics:: MessageArg :: Code ( table_name) ,
584674 ] ,
585- * node,
675+ node,
586676 false ,
587677 ) ;
588678 break ;
@@ -639,15 +729,21 @@ impl<'a> Visitor<'a> {
639729}
640730
641731// Emit a slice of a source file as an Arg.
642- fn sliced_source_arg ( source : & [ u8 ] , n : Node ) -> trap:: Arg {
643- let range = n. byte_range ( ) ;
644- trap:: Arg :: String ( String :: from_utf8_lossy ( & source[ range. start ..range. end ] ) . into_owned ( ) )
732+ fn sliced_source_arg < N : AstNode > ( source : & [ u8 ] , n : & N ) -> trap:: Arg {
733+ trap:: Arg :: String ( n. opt_string_content ( ) . unwrap_or_else ( || {
734+ let range = n. byte_range ( ) ;
735+ String :: from_utf8_lossy ( & source[ range. start ..range. end ] ) . into_owned ( )
736+ } ) )
645737}
646738
647739// Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated.
648740// The first is the location and label definition, and the second is the
649741// 'Located' entry.
650- fn location_for ( visitor : & mut Visitor , file_label : trap:: Label , n : Node ) -> trap:: Location {
742+ fn location_for < N : AstNode > (
743+ visitor : & mut Visitor ,
744+ file_label : trap:: Label ,
745+ n : & N ,
746+ ) -> trap:: Location {
651747 // Tree-sitter row, column values are 0-based while CodeQL starts
652748 // counting at 1. In addition Tree-sitter's row and column for the
653749 // end position are exclusive while CodeQL's end positions are inclusive.
@@ -715,6 +811,28 @@ fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap
715811
716812fn traverse ( tree : & Tree , visitor : & mut Visitor ) {
717813 let cursor = & mut tree. walk ( ) ;
814+ visitor. enter_node ( & cursor. node ( ) ) ;
815+ let mut recurse = true ;
816+ loop {
817+ if recurse && cursor. goto_first_child ( ) {
818+ recurse = visitor. enter_node ( & cursor. node ( ) ) ;
819+ } else {
820+ visitor. leave_node ( cursor. field_name ( ) , & cursor. node ( ) ) ;
821+
822+ if cursor. goto_next_sibling ( ) {
823+ recurse = visitor. enter_node ( & cursor. node ( ) ) ;
824+ } else if cursor. goto_parent ( ) {
825+ recurse = false ;
826+ } else {
827+ break ;
828+ }
829+ }
830+ }
831+ }
832+
833+ fn traverse_yeast ( tree : & yeast:: Ast , visitor : & mut Visitor ) {
834+ use yeast:: Cursor ;
835+ let mut cursor = tree. walk ( ) ;
718836 visitor. enter_node ( cursor. node ( ) ) ;
719837 let mut recurse = true ;
720838 loop {
0 commit comments