@@ -790,9 +790,11 @@ impl DefaultPhysicalPlanner {
790790 {
791791 // For UPDATE, the assignments are encoded in the projection of input
792792 // We pass the filters and let the provider handle the projection
793- let filters = extract_dml_filters ( input, table_name) ?;
793+ let filters =
794+ extract_dml_filters_with_analysis ( input, table_name, & analysis) ?;
794795 // Extract assignments from the projection in input plan
795- let assignments = extract_update_assignments ( input, table_name) ?;
796+ let assignments =
797+ extract_update_assignments_with_analysis ( input, & analysis) ?;
796798 provider
797799 . table_provider
798800 . update ( session_state, assignments, filters)
@@ -2107,6 +2109,15 @@ fn extract_dml_filters(
21072109 target : & TableReference ,
21082110) -> Result < Vec < Expr > > {
21092111 let analysis = analyze_dml_input ( input, target) ?;
2112+ extract_dml_filters_with_analysis ( input, target, & analysis)
2113+ }
2114+
2115+ #[ allow( clippy:: allow_attributes, clippy:: mutable_key_type) ] // Expr contains Arc with interior mutability but is intentionally used as hash key
2116+ fn extract_dml_filters_with_analysis (
2117+ input : & Arc < LogicalPlan > ,
2118+ target : & TableReference ,
2119+ analysis : & DmlInputAnalysis ,
2120+ ) -> Result < Vec < Expr > > {
21102121 let mut filters = Vec :: new ( ) ;
21112122
21122123 input. apply ( |node| {
@@ -2204,30 +2215,28 @@ fn analyze_dml_input(
22042215 has_joined_input : false ,
22052216 target_refs : vec ! [ target. clone( ) ] ,
22062217 } ;
2207- analyze_target_branch ( input, & mut analysis) ? ;
2218+ analyze_target_branch ( input, & mut analysis) ;
22082219 Ok ( analysis)
22092220}
22102221
2211- fn analyze_target_branch (
2212- input : & Arc < LogicalPlan > ,
2213- analysis : & mut DmlInputAnalysis ,
2214- ) -> Result < ( ) > {
2215- match input. as_ref ( ) {
2216- LogicalPlan :: Projection ( projection) => {
2217- analyze_target_branch ( & projection. input , analysis)
2218- }
2219- LogicalPlan :: Filter ( filter) => analyze_target_branch ( & filter. input , analysis) ,
2220- LogicalPlan :: SubqueryAlias ( alias) => {
2221- analysis
2222- . target_refs
2223- . push ( TableReference :: bare ( alias. alias . to_string ( ) ) ) ;
2224- analyze_target_branch ( & alias. input , analysis)
2225- }
2226- LogicalPlan :: Join ( join) => {
2227- analysis. has_joined_input = true ;
2228- analyze_target_branch ( & join. left , analysis)
2222+ fn analyze_target_branch ( input : & Arc < LogicalPlan > , analysis : & mut DmlInputAnalysis ) {
2223+ let mut current = input;
2224+ loop {
2225+ match current. as_ref ( ) {
2226+ LogicalPlan :: Projection ( projection) => current = & projection. input ,
2227+ LogicalPlan :: Filter ( filter) => current = & filter. input ,
2228+ LogicalPlan :: SubqueryAlias ( alias) => {
2229+ analysis
2230+ . target_refs
2231+ . push ( TableReference :: bare ( alias. alias . to_string ( ) ) ) ;
2232+ current = & alias. input ;
2233+ }
2234+ LogicalPlan :: Join ( join) => {
2235+ analysis. has_joined_input = true ;
2236+ current = & join. left ;
2237+ }
2238+ _ => return ,
22292239 }
2230- _ => Ok ( ( ) ) ,
22312240 }
22322241}
22332242
@@ -2280,6 +2289,7 @@ fn strip_column_qualifiers(expr: Expr) -> Result<Expr> {
22802289/// from the projection. Column qualifiers are stripped only for single-table
22812290/// updates so provider-facing expressions remain resolvable.
22822291///
2292+ #[ cfg( test) ]
22832293fn extract_update_assignments (
22842294 input : & Arc < LogicalPlan > ,
22852295 target : & TableReference ,
@@ -2291,67 +2301,65 @@ fn extract_update_assignments(
22912301 //
22922302 // Each projected expression has an alias matching the column name
22932303 let analysis = analyze_dml_input ( input, target) ?;
2294- let mut assignments = Vec :: new ( ) ;
2295- let strip_qualifiers = !analysis . has_joined_input ;
2304+ extract_update_assignments_with_analysis ( input , & analysis )
2305+ }
22962306
2297- // Find the top-level projection
2307+ fn extract_update_assignments_with_analysis (
2308+ input : & Arc < LogicalPlan > ,
2309+ analysis : & DmlInputAnalysis ,
2310+ ) -> Result < Vec < ( String , Expr ) > > {
2311+ find_update_projection ( input) ?
2312+ . map ( |projection| {
2313+ append_update_assignments (
2314+ projection,
2315+ & analysis. target_refs ,
2316+ !analysis. has_joined_input ,
2317+ )
2318+ } )
2319+ . transpose ( )
2320+ . map ( |assignments| assignments. unwrap_or_default ( ) )
2321+ }
2322+
2323+ fn find_update_projection ( input : & Arc < LogicalPlan > ) -> Result < Option < & Projection > > {
22982324 if let LogicalPlan :: Projection ( projection) = input. as_ref ( ) {
2299- append_update_assignments (
2300- & mut assignments,
2301- projection,
2302- & analysis. target_refs ,
2303- strip_qualifiers,
2304- ) ?;
2305- } else {
2306- // Try to find projection deeper in the plan
2307- input. apply ( |node| {
2308- if let LogicalPlan :: Projection ( projection) = node {
2309- append_update_assignments (
2310- & mut assignments,
2311- projection,
2312- & analysis. target_refs ,
2313- strip_qualifiers,
2314- ) ?;
2315- return Ok ( TreeNodeRecursion :: Stop ) ;
2316- }
2317- Ok ( TreeNodeRecursion :: Continue )
2318- } ) ?;
2325+ return Ok ( Some ( projection) ) ;
23192326 }
23202327
2321- Ok ( assignments)
2328+ let mut found_projection = None ;
2329+ input. apply ( |node| {
2330+ if let LogicalPlan :: Projection ( projection) = node {
2331+ found_projection = Some ( projection) ;
2332+ return Ok ( TreeNodeRecursion :: Stop ) ;
2333+ }
2334+ Ok ( TreeNodeRecursion :: Continue )
2335+ } ) ?;
2336+ Ok ( found_projection)
23222337}
23232338
23242339fn append_update_assignments (
2325- assignments : & mut Vec < ( String , Expr ) > ,
23262340 projection : & Projection ,
23272341 target_refs : & [ TableReference ] ,
23282342 strip_qualifiers : bool ,
2329- ) -> Result < ( ) > {
2330- for expr in & projection. expr {
2331- if let Expr :: Alias ( alias) = expr {
2332- // The alias name is the column name being updated
2333- // The inner expression is the new value
2334- let column_name = alias. name . clone ( ) ;
2335- // Only include if it's not just a column reference to itself
2336- // (those are columns that aren't being updated)
2337- if !is_identity_assignment ( & alias. expr , & column_name, target_refs) {
2338- let assignment_expr = normalize_update_assignment_expr (
2339- ( * alias. expr ) . clone ( ) ,
2340- strip_qualifiers,
2341- ) ?;
2342- assignments. push ( ( column_name, assignment_expr) ) ;
2343+ ) -> Result < Vec < ( String , Expr ) > > {
2344+ projection
2345+ . expr
2346+ . iter ( )
2347+ . filter_map ( |expr| match expr {
2348+ Expr :: Alias ( alias)
2349+ if !is_identity_assignment ( & alias. expr , & alias. name , target_refs) =>
2350+ {
2351+ Some (
2352+ if strip_qualifiers {
2353+ strip_column_qualifiers ( ( * alias. expr ) . clone ( ) )
2354+ } else {
2355+ Ok ( ( * alias. expr ) . clone ( ) )
2356+ }
2357+ . map ( |assignment_expr| ( alias. name . clone ( ) , assignment_expr) ) ,
2358+ )
23432359 }
2344- }
2345- }
2346- Ok ( ( ) )
2347- }
2348-
2349- fn normalize_update_assignment_expr ( expr : Expr , strip_qualifiers : bool ) -> Result < Expr > {
2350- if strip_qualifiers {
2351- strip_column_qualifiers ( expr)
2352- } else {
2353- Ok ( expr)
2354- }
2360+ _ => None ,
2361+ } )
2362+ . collect ( )
23552363}
23562364
23572365/// Check if an assignment is an identity assignment (column = column)
@@ -3247,20 +3255,19 @@ mod tests {
32473255 Arc :: new ( MemTable :: try_new ( schema, vec ! [ vec![ ] ] ) . unwrap ( ) )
32483256 }
32493257
3250- async fn update_assignments_for_sql ( sql : & str ) -> Result < Vec < ( String , Expr ) > > {
3251- let ctx = SessionContext :: new ( ) ;
3252- let t1_schema = Arc :: new ( Schema :: new ( vec ! [
3253- Field :: new( "id" , DataType :: Int32 , false ) ,
3254- Field :: new( "a" , DataType :: Int32 , false ) ,
3255- Field :: new( "b" , DataType :: Int32 , false ) ,
3256- ] ) ) ;
3257- let t2_schema = Arc :: new ( Schema :: new ( vec ! [
3258+ fn test_update_schema ( ) -> SchemaRef {
3259+ Arc :: new ( Schema :: new ( vec ! [
32583260 Field :: new( "id" , DataType :: Int32 , false ) ,
32593261 Field :: new( "a" , DataType :: Int32 , false ) ,
32603262 Field :: new( "b" , DataType :: Int32 , false ) ,
3261- ] ) ) ;
3262- ctx. register_table ( "t1" , make_test_mem_table ( t1_schema) ) ?;
3263- ctx. register_table ( "t2" , make_test_mem_table ( t2_schema) ) ?;
3263+ ] ) )
3264+ }
3265+
3266+ async fn update_assignments_for_sql ( sql : & str ) -> Result < Vec < ( String , Expr ) > > {
3267+ let ctx = SessionContext :: new ( ) ;
3268+ let schema = test_update_schema ( ) ;
3269+ ctx. register_table ( "t1" , make_test_mem_table ( Arc :: clone ( & schema) ) ) ?;
3270+ ctx. register_table ( "t2" , make_test_mem_table ( schema) ) ?;
32643271
32653272 let df = ctx. sql ( sql) . await ?;
32663273 let ( table_name, input) = match df. logical_plan ( ) {
@@ -3276,6 +3283,16 @@ mod tests {
32763283 extract_update_assignments ( & input, & table_name)
32773284 }
32783285
3286+ async fn assert_update_assignment ( sql : & str , column : & str , expected : & str ) {
3287+ let assignments: HashMap < _ , _ > = update_assignments_for_sql ( sql)
3288+ . await
3289+ . unwrap ( )
3290+ . into_iter ( )
3291+ . map ( |( name, expr) | ( name, expr. to_string ( ) ) )
3292+ . collect ( ) ;
3293+ assert_eq ! ( assignments. get( column) . map( String :: as_str) , Some ( expected) ) ;
3294+ }
3295+
32793296 #[ tokio:: test]
32803297 async fn test_all_operators ( ) -> Result < ( ) > {
32813298 let logical_plan = test_csv_scan ( )
@@ -4861,81 +4878,53 @@ digraph {
48614878
48624879 #[ tokio:: test]
48634880 async fn test_extract_update_assignments_preserves_joined_source_qualifiers ( ) {
4864- let assignments = update_assignments_for_sql (
4881+ assert_update_assignment (
48654882 "UPDATE t1 SET b = t2.b FROM t2 WHERE t1.id = t2.id" ,
4883+ "b" ,
4884+ "t2.b" ,
48664885 )
4867- . await
4868- . unwrap ( ) ;
4869-
4870- let assignments: HashMap < _ , _ > = assignments. into_iter ( ) . collect ( ) ;
4871- assert_eq ! (
4872- assignments. get( "b" ) . map( ToString :: to_string) . as_deref( ) ,
4873- Some ( "t2.b" )
4874- ) ;
4886+ . await ;
48754887 }
48764888
48774889 #[ tokio:: test]
48784890 async fn test_extract_update_assignments_preserves_alias_qualified_sources ( ) {
4879- let assignments = update_assignments_for_sql (
4891+ assert_update_assignment (
48804892 "UPDATE t1 AS target SET b = source.b FROM t2 AS source \
48814893 WHERE target.id = source.id",
4894+ "b" ,
4895+ "source.b" ,
48824896 )
4883- . await
4884- . unwrap ( ) ;
4885-
4886- let assignments: HashMap < _ , _ > = assignments. into_iter ( ) . collect ( ) ;
4887- assert_eq ! (
4888- assignments. get( "b" ) . map( ToString :: to_string) . as_deref( ) ,
4889- Some ( "source.b" )
4890- ) ;
4897+ . await ;
48914898 }
48924899
48934900 #[ tokio:: test]
48944901 async fn test_extract_update_assignments_distinguishes_same_name_join_columns ( ) {
4895- let assignments = update_assignments_for_sql (
4902+ let assignments: HashMap < _ , _ > = update_assignments_for_sql (
48964903 "UPDATE t1 SET a = t2.a, b = t1.a FROM t2 WHERE t1.id = t2.id" ,
48974904 )
48984905 . await
4899- . unwrap ( ) ;
4900-
4901- let assignments: HashMap < _ , _ > = assignments. into_iter ( ) . collect ( ) ;
4902- assert_eq ! (
4903- assignments. get( "a" ) . map( ToString :: to_string) . as_deref( ) ,
4904- Some ( "t2.a" )
4905- ) ;
4906- assert_eq ! (
4907- assignments. get( "b" ) . map( ToString :: to_string) . as_deref( ) ,
4908- Some ( "t1.a" )
4909- ) ;
4906+ . unwrap ( )
4907+ . into_iter ( )
4908+ . map ( |( name, expr) | ( name, expr. to_string ( ) ) )
4909+ . collect ( ) ;
4910+ assert_eq ! ( assignments. get( "a" ) . map( String :: as_str) , Some ( "t2.a" ) ) ;
4911+ assert_eq ! ( assignments. get( "b" ) . map( String :: as_str) , Some ( "t1.a" ) ) ;
49104912 }
49114913
49124914 #[ tokio:: test]
49134915 async fn test_extract_update_assignments_preserves_self_join_source_alias ( ) {
4914- let assignments = update_assignments_for_sql (
4916+ assert_update_assignment (
49154917 "UPDATE t1 AS target SET a = src.a FROM t1 AS src \
49164918 WHERE target.id = src.id",
4919+ "a" ,
4920+ "src.a" ,
49174921 )
4918- . await
4919- . unwrap ( ) ;
4920-
4921- let assignments: HashMap < _ , _ > = assignments. into_iter ( ) . collect ( ) ;
4922- assert_eq ! (
4923- assignments. get( "a" ) . map( ToString :: to_string) . as_deref( ) ,
4924- Some ( "src.a" )
4925- ) ;
4922+ . await ;
49264923 }
49274924
49284925 #[ tokio:: test]
49294926 async fn test_extract_update_assignments_strips_single_table_target_qualifiers ( ) {
4930- let assignments =
4931- update_assignments_for_sql ( "UPDATE t1 SET b = t1.a WHERE t1.id = 1" )
4932- . await
4933- . unwrap ( ) ;
4934-
4935- let assignments: HashMap < _ , _ > = assignments. into_iter ( ) . collect ( ) ;
4936- assert_eq ! (
4937- assignments. get( "b" ) . map( ToString :: to_string) . as_deref( ) ,
4938- Some ( "a" )
4939- ) ;
4927+ assert_update_assignment ( "UPDATE t1 SET b = t1.a WHERE t1.id = 1" , "b" , "a" )
4928+ . await ;
49404929 }
49414930}
0 commit comments