@@ -18,6 +18,82 @@ use tree_sitter::{Language, Node, Parser, Range, Tree};
1818
1919pub mod simple;
2020
21+ /// Trait abstracting over tree-sitter and yeast node types for extraction.
22+ trait AstNode {
23+ fn kind ( & self ) -> & str ;
24+ fn is_named ( & self ) -> bool ;
25+ fn is_missing ( & self ) -> bool ;
26+ fn is_error ( & self ) -> bool ;
27+ fn is_extra ( & self ) -> bool ;
28+ fn start_position ( & self ) -> tree_sitter:: Point ;
29+ fn end_position ( & self ) -> tree_sitter:: Point ;
30+ fn byte_range ( & self ) -> std:: ops:: Range < usize > ;
31+ fn end_byte ( & self ) -> usize {
32+ self . byte_range ( ) . end
33+ }
34+ /// For yeast nodes with synthetic content, return it. Otherwise None.
35+ fn opt_string_content ( & self ) -> Option < String > {
36+ None
37+ }
38+ }
39+
40+ impl < ' a > AstNode for Node < ' a > {
41+ fn kind ( & self ) -> & str {
42+ Node :: kind ( self )
43+ }
44+ fn is_named ( & self ) -> bool {
45+ Node :: is_named ( self )
46+ }
47+ fn is_missing ( & self ) -> bool {
48+ Node :: is_missing ( self )
49+ }
50+ fn is_error ( & self ) -> bool {
51+ Node :: is_error ( self )
52+ }
53+ fn is_extra ( & self ) -> bool {
54+ Node :: is_extra ( self )
55+ }
56+ fn start_position ( & self ) -> tree_sitter:: Point {
57+ Node :: start_position ( self )
58+ }
59+ fn end_position ( & self ) -> tree_sitter:: Point {
60+ Node :: end_position ( self )
61+ }
62+ fn byte_range ( & self ) -> std:: ops:: Range < usize > {
63+ Node :: byte_range ( self )
64+ }
65+ }
66+
67+ impl AstNode for yeast:: Node {
68+ fn kind ( & self ) -> & str {
69+ yeast:: Node :: kind ( self )
70+ }
71+ fn is_named ( & self ) -> bool {
72+ yeast:: Node :: is_named ( self )
73+ }
74+ fn is_missing ( & self ) -> bool {
75+ yeast:: Node :: is_missing ( self )
76+ }
77+ fn is_error ( & self ) -> bool {
78+ yeast:: Node :: is_error ( self )
79+ }
80+ fn is_extra ( & self ) -> bool {
81+ yeast:: Node :: is_extra ( self )
82+ }
83+ fn start_position ( & self ) -> tree_sitter:: Point {
84+ yeast:: Node :: start_position ( self )
85+ }
86+ fn end_position ( & self ) -> tree_sitter:: Point {
87+ yeast:: Node :: end_position ( self )
88+ }
89+ fn byte_range ( & self ) -> std:: ops:: Range < usize > {
90+ yeast:: Node :: byte_range ( self )
91+ }
92+ fn opt_string_content ( & self ) -> Option < String > {
93+ yeast:: Node :: opt_string_content ( self )
94+ }
95+ }
96+
2197/// Sets the tracing level based on the environment variables
2298/// `RUST_LOG` and `CODEQL_VERBOSITY` (prioritized in that order),
2399/// falling back to `warn` if neither is set.
@@ -204,6 +280,11 @@ pub fn location_label(writer: &mut trap::Writer, location: trap::Location) -> tr
204280}
205281
206282/// Extracts the source file at `path`, which is assumed to be canonicalized.
283+ /// When `yeast_runner` is `Some`, the parsed tree is first transformed
284+ /// through the supplied yeast `Runner` before TRAP extraction. Building the
285+ /// `Runner` (which parses YAML and constructs the schema) is the caller's
286+ /// responsibility, allowing it to be done once and shared across files.
287+ #[ allow( clippy:: too_many_arguments) ]
207288pub fn extract (
208289 language : & Language ,
209290 language_prefix : & str ,
@@ -214,6 +295,7 @@ pub fn extract(
214295 path : & Path ,
215296 source : & [ u8 ] ,
216297 ranges : & [ Range ] ,
298+ yeast_runner : Option < & yeast:: Runner < ' _ > > ,
217299) {
218300 let path_str = file_paths:: normalize_and_transform_path ( path, transformer) ;
219301 let span = tracing:: span!(
@@ -236,13 +318,20 @@ pub fn extract(
236318 source,
237319 diagnostics_writer,
238320 trap_writer,
239- // TODO: should we handle path strings that are not valid UTF8 better?
240321 & path_str,
241322 file_label,
242323 language_prefix,
243324 schema,
244325 ) ;
245- traverse ( & tree, & mut visitor) ;
326+
327+ if let Some ( yeast_runner) = yeast_runner {
328+ let ast = yeast_runner
329+ . run_from_tree ( & tree)
330+ . unwrap_or_else ( |e| panic ! ( "Desugaring failed for {path_str}: {e}" ) ) ;
331+ traverse_yeast ( & ast, & mut visitor) ;
332+ } else {
333+ traverse ( & tree, & mut visitor) ;
334+ }
246335
247336 parser. reset ( ) ;
248337}
@@ -329,11 +418,11 @@ impl<'a> Visitor<'a> {
329418 ) ;
330419 }
331420
332- fn record_parse_error_for_node (
421+ fn record_parse_error_for_node < N : AstNode > (
333422 & mut self ,
334423 message : & str ,
335424 args : & [ diagnostics:: MessageArg ] ,
336- node : Node ,
425+ node : & N ,
337426 status_page : bool ,
338427 ) {
339428 let loc = location_for ( self , self . file_label , node) ;
@@ -357,7 +446,7 @@ impl<'a> Visitor<'a> {
357446 self . record_parse_error ( loc_label, & mesg) ;
358447 }
359448
360- fn enter_node ( & mut self , node : Node ) -> bool {
449+ fn enter_node < N : AstNode > ( & mut self , node : & N ) -> bool {
361450 if node. is_missing ( ) {
362451 self . record_parse_error_for_node (
363452 "A parse error occurred (expected {} symbol). Check the syntax of the file. If the file is invalid, correct the error or {} the file from analysis." ,
@@ -383,7 +472,7 @@ impl<'a> Visitor<'a> {
383472 true
384473 }
385474
386- fn leave_node ( & mut self , field_name : Option < & ' static str > , node : Node ) {
475+ fn leave_node < N : AstNode > ( & mut self , field_name : Option < & ' static str > , node : & N ) {
387476 if node. is_error ( ) || node. is_missing ( ) {
388477 return ;
389478 }
@@ -434,7 +523,7 @@ impl<'a> Visitor<'a> {
434523 fields,
435524 name : table_name,
436525 } => {
437- if let Some ( args) = self . complex_node ( & node, fields, & child_nodes, id) {
526+ if let Some ( args) = self . complex_node ( node, fields, & child_nodes, id) {
438527 self . trap_writer . add_tuple (
439528 & self . ast_node_location_table_name ,
440529 vec ! [ trap:: Arg :: Label ( id) , trap:: Arg :: Label ( loc_label) ] ,
@@ -495,9 +584,9 @@ impl<'a> Visitor<'a> {
495584 }
496585 }
497586
498- fn complex_node (
587+ fn complex_node < N : AstNode > (
499588 & mut self ,
500- node : & Node ,
589+ node : & N ,
501590 fields : & [ Field ] ,
502591 child_nodes : & [ ChildNode ] ,
503592 parent_id : trap:: Label ,
@@ -529,7 +618,7 @@ impl<'a> Visitor<'a> {
529618 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , child_node. type_name) ) ,
530619 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , field. type_info) ) ,
531620 ] ,
532- * node,
621+ node,
533622 false ,
534623 ) ;
535624 }
@@ -541,7 +630,7 @@ impl<'a> Visitor<'a> {
541630 diagnostics:: MessageArg :: Code ( child_node. field_name . unwrap_or ( "child" ) ) ,
542631 diagnostics:: MessageArg :: Code ( & format ! ( "{:?}" , child_node. type_name) ) ,
543632 ] ,
544- * node,
633+ node,
545634 false ,
546635 ) ;
547636 }
@@ -566,7 +655,7 @@ impl<'a> Visitor<'a> {
566655 node. kind( ) ,
567656 column_name
568657 ) ;
569- self . record_parse_error_for_node ( & error_message, & [ ] , * node, false ) ;
658+ self . record_parse_error_for_node ( & error_message, & [ ] , node, false ) ;
570659 }
571660 }
572661 Storage :: Table {
@@ -582,7 +671,7 @@ impl<'a> Visitor<'a> {
582671 diagnostics:: MessageArg :: Code ( node. kind ( ) ) ,
583672 diagnostics:: MessageArg :: Code ( table_name) ,
584673 ] ,
585- * node,
674+ node,
586675 false ,
587676 ) ;
588677 break ;
@@ -639,15 +728,21 @@ impl<'a> Visitor<'a> {
639728}
640729
641730// Emit a slice of a source file as an Arg.
642- fn sliced_source_arg ( source : & [ u8 ] , n : Node ) -> trap:: Arg {
643- let range = n. byte_range ( ) ;
644- trap:: Arg :: String ( String :: from_utf8_lossy ( & source[ range. start ..range. end ] ) . into_owned ( ) )
731+ fn sliced_source_arg < N : AstNode > ( source : & [ u8 ] , n : & N ) -> trap:: Arg {
732+ trap:: Arg :: String ( n. opt_string_content ( ) . unwrap_or_else ( || {
733+ let range = n. byte_range ( ) ;
734+ String :: from_utf8_lossy ( & source[ range. start ..range. end ] ) . into_owned ( )
735+ } ) )
645736}
646737
647738// Emit a pair of `TrapEntry`s for the provided node, appropriately calibrated.
648739// The first is the location and label definition, and the second is the
649740// 'Located' entry.
650- fn location_for ( visitor : & mut Visitor , file_label : trap:: Label , n : Node ) -> trap:: Location {
741+ fn location_for < N : AstNode > (
742+ visitor : & mut Visitor ,
743+ file_label : trap:: Label ,
744+ n : & N ,
745+ ) -> trap:: Location {
651746 // Tree-sitter row, column values are 0-based while CodeQL starts
652747 // counting at 1. In addition Tree-sitter's row and column for the
653748 // end position are exclusive while CodeQL's end positions are inclusive.
@@ -715,6 +810,28 @@ fn location_for(visitor: &mut Visitor, file_label: trap::Label, n: Node) -> trap
715810
716811fn traverse ( tree : & Tree , visitor : & mut Visitor ) {
717812 let cursor = & mut tree. walk ( ) ;
813+ visitor. enter_node ( & cursor. node ( ) ) ;
814+ let mut recurse = true ;
815+ loop {
816+ if recurse && cursor. goto_first_child ( ) {
817+ recurse = visitor. enter_node ( & cursor. node ( ) ) ;
818+ } else {
819+ visitor. leave_node ( cursor. field_name ( ) , & cursor. node ( ) ) ;
820+
821+ if cursor. goto_next_sibling ( ) {
822+ recurse = visitor. enter_node ( & cursor. node ( ) ) ;
823+ } else if cursor. goto_parent ( ) {
824+ recurse = false ;
825+ } else {
826+ break ;
827+ }
828+ }
829+ }
830+ }
831+
832+ fn traverse_yeast ( tree : & yeast:: Ast , visitor : & mut Visitor ) {
833+ use yeast:: Cursor ;
834+ let mut cursor = tree. walk ( ) ;
718835 visitor. enter_node ( cursor. node ( ) ) ;
719836 let mut recurse = true ;
720837 loop {
0 commit comments