Skip to content

Commit 2b8b462

Browse files
committed
Refactor planning and testing for updates and SQL
Simplify physical planner by reusing DmlInputAnalysis and centralizing projection lookup. Streamline assignment extraction with iterators. Reduce duplication in SQL planning setup by introducing a shared helper and improve context provider to reuse stored schemas for efficiency. Enhance test scaffolding with shared update schema and new assertion utilities.
1 parent 0ce5ffe commit 2b8b462

2 files changed

Lines changed: 144 additions & 144 deletions

File tree

datafusion/core/src/physical_planner.rs

Lines changed: 121 additions & 132 deletions
Original file line numberDiff line numberDiff line change
@@ -790,9 +790,11 @@ impl DefaultPhysicalPlanner {
790790
{
791791
// For UPDATE, the assignments are encoded in the projection of input
792792
// We pass the filters and let the provider handle the projection
793-
let filters = extract_dml_filters(input, table_name)?;
793+
let filters =
794+
extract_dml_filters_with_analysis(input, table_name, &analysis)?;
794795
// Extract assignments from the projection in input plan
795-
let assignments = extract_update_assignments(input, table_name)?;
796+
let assignments =
797+
extract_update_assignments_with_analysis(input, &analysis)?;
796798
provider
797799
.table_provider
798800
.update(session_state, assignments, filters)
@@ -2107,6 +2109,15 @@ fn extract_dml_filters(
21072109
target: &TableReference,
21082110
) -> Result<Vec<Expr>> {
21092111
let analysis = analyze_dml_input(input, target)?;
2112+
extract_dml_filters_with_analysis(input, target, &analysis)
2113+
}
2114+
2115+
#[allow(clippy::allow_attributes, clippy::mutable_key_type)] // Expr contains Arc with interior mutability but is intentionally used as hash key
2116+
fn extract_dml_filters_with_analysis(
2117+
input: &Arc<LogicalPlan>,
2118+
target: &TableReference,
2119+
analysis: &DmlInputAnalysis,
2120+
) -> Result<Vec<Expr>> {
21102121
let mut filters = Vec::new();
21112122

21122123
input.apply(|node| {
@@ -2204,30 +2215,28 @@ fn analyze_dml_input(
22042215
has_joined_input: false,
22052216
target_refs: vec![target.clone()],
22062217
};
2207-
analyze_target_branch(input, &mut analysis)?;
2218+
analyze_target_branch(input, &mut analysis);
22082219
Ok(analysis)
22092220
}
22102221

2211-
fn analyze_target_branch(
2212-
input: &Arc<LogicalPlan>,
2213-
analysis: &mut DmlInputAnalysis,
2214-
) -> Result<()> {
2215-
match input.as_ref() {
2216-
LogicalPlan::Projection(projection) => {
2217-
analyze_target_branch(&projection.input, analysis)
2218-
}
2219-
LogicalPlan::Filter(filter) => analyze_target_branch(&filter.input, analysis),
2220-
LogicalPlan::SubqueryAlias(alias) => {
2221-
analysis
2222-
.target_refs
2223-
.push(TableReference::bare(alias.alias.to_string()));
2224-
analyze_target_branch(&alias.input, analysis)
2225-
}
2226-
LogicalPlan::Join(join) => {
2227-
analysis.has_joined_input = true;
2228-
analyze_target_branch(&join.left, analysis)
2222+
fn analyze_target_branch(input: &Arc<LogicalPlan>, analysis: &mut DmlInputAnalysis) {
2223+
let mut current = input;
2224+
loop {
2225+
match current.as_ref() {
2226+
LogicalPlan::Projection(projection) => current = &projection.input,
2227+
LogicalPlan::Filter(filter) => current = &filter.input,
2228+
LogicalPlan::SubqueryAlias(alias) => {
2229+
analysis
2230+
.target_refs
2231+
.push(TableReference::bare(alias.alias.to_string()));
2232+
current = &alias.input;
2233+
}
2234+
LogicalPlan::Join(join) => {
2235+
analysis.has_joined_input = true;
2236+
current = &join.left;
2237+
}
2238+
_ => return,
22292239
}
2230-
_ => Ok(()),
22312240
}
22322241
}
22332242

@@ -2280,6 +2289,7 @@ fn strip_column_qualifiers(expr: Expr) -> Result<Expr> {
22802289
/// from the projection. Column qualifiers are stripped only for single-table
22812290
/// updates so provider-facing expressions remain resolvable.
22822291
///
2292+
#[cfg(test)]
22832293
fn extract_update_assignments(
22842294
input: &Arc<LogicalPlan>,
22852295
target: &TableReference,
@@ -2291,67 +2301,65 @@ fn extract_update_assignments(
22912301
//
22922302
// Each projected expression has an alias matching the column name
22932303
let analysis = analyze_dml_input(input, target)?;
2294-
let mut assignments = Vec::new();
2295-
let strip_qualifiers = !analysis.has_joined_input;
2304+
extract_update_assignments_with_analysis(input, &analysis)
2305+
}
22962306

2297-
// Find the top-level projection
2307+
fn extract_update_assignments_with_analysis(
2308+
input: &Arc<LogicalPlan>,
2309+
analysis: &DmlInputAnalysis,
2310+
) -> Result<Vec<(String, Expr)>> {
2311+
find_update_projection(input)?
2312+
.map(|projection| {
2313+
append_update_assignments(
2314+
projection,
2315+
&analysis.target_refs,
2316+
!analysis.has_joined_input,
2317+
)
2318+
})
2319+
.transpose()
2320+
.map(|assignments| assignments.unwrap_or_default())
2321+
}
2322+
2323+
fn find_update_projection(input: &Arc<LogicalPlan>) -> Result<Option<&Projection>> {
22982324
if let LogicalPlan::Projection(projection) = input.as_ref() {
2299-
append_update_assignments(
2300-
&mut assignments,
2301-
projection,
2302-
&analysis.target_refs,
2303-
strip_qualifiers,
2304-
)?;
2305-
} else {
2306-
// Try to find projection deeper in the plan
2307-
input.apply(|node| {
2308-
if let LogicalPlan::Projection(projection) = node {
2309-
append_update_assignments(
2310-
&mut assignments,
2311-
projection,
2312-
&analysis.target_refs,
2313-
strip_qualifiers,
2314-
)?;
2315-
return Ok(TreeNodeRecursion::Stop);
2316-
}
2317-
Ok(TreeNodeRecursion::Continue)
2318-
})?;
2325+
return Ok(Some(projection));
23192326
}
23202327

2321-
Ok(assignments)
2328+
let mut found_projection = None;
2329+
input.apply(|node| {
2330+
if let LogicalPlan::Projection(projection) = node {
2331+
found_projection = Some(projection);
2332+
return Ok(TreeNodeRecursion::Stop);
2333+
}
2334+
Ok(TreeNodeRecursion::Continue)
2335+
})?;
2336+
Ok(found_projection)
23222337
}
23232338

23242339
fn append_update_assignments(
2325-
assignments: &mut Vec<(String, Expr)>,
23262340
projection: &Projection,
23272341
target_refs: &[TableReference],
23282342
strip_qualifiers: bool,
2329-
) -> Result<()> {
2330-
for expr in &projection.expr {
2331-
if let Expr::Alias(alias) = expr {
2332-
// The alias name is the column name being updated
2333-
// The inner expression is the new value
2334-
let column_name = alias.name.clone();
2335-
// Only include if it's not just a column reference to itself
2336-
// (those are columns that aren't being updated)
2337-
if !is_identity_assignment(&alias.expr, &column_name, target_refs) {
2338-
let assignment_expr = normalize_update_assignment_expr(
2339-
(*alias.expr).clone(),
2340-
strip_qualifiers,
2341-
)?;
2342-
assignments.push((column_name, assignment_expr));
2343+
) -> Result<Vec<(String, Expr)>> {
2344+
projection
2345+
.expr
2346+
.iter()
2347+
.filter_map(|expr| match expr {
2348+
Expr::Alias(alias)
2349+
if !is_identity_assignment(&alias.expr, &alias.name, target_refs) =>
2350+
{
2351+
Some(
2352+
if strip_qualifiers {
2353+
strip_column_qualifiers((*alias.expr).clone())
2354+
} else {
2355+
Ok((*alias.expr).clone())
2356+
}
2357+
.map(|assignment_expr| (alias.name.clone(), assignment_expr)),
2358+
)
23432359
}
2344-
}
2345-
}
2346-
Ok(())
2347-
}
2348-
2349-
fn normalize_update_assignment_expr(expr: Expr, strip_qualifiers: bool) -> Result<Expr> {
2350-
if strip_qualifiers {
2351-
strip_column_qualifiers(expr)
2352-
} else {
2353-
Ok(expr)
2354-
}
2360+
_ => None,
2361+
})
2362+
.collect()
23552363
}
23562364

23572365
/// Check if an assignment is an identity assignment (column = column)
@@ -3247,20 +3255,19 @@ mod tests {
32473255
Arc::new(MemTable::try_new(schema, vec![vec![]]).unwrap())
32483256
}
32493257

3250-
async fn update_assignments_for_sql(sql: &str) -> Result<Vec<(String, Expr)>> {
3251-
let ctx = SessionContext::new();
3252-
let t1_schema = Arc::new(Schema::new(vec![
3253-
Field::new("id", DataType::Int32, false),
3254-
Field::new("a", DataType::Int32, false),
3255-
Field::new("b", DataType::Int32, false),
3256-
]));
3257-
let t2_schema = Arc::new(Schema::new(vec![
3258+
fn test_update_schema() -> SchemaRef {
3259+
Arc::new(Schema::new(vec![
32583260
Field::new("id", DataType::Int32, false),
32593261
Field::new("a", DataType::Int32, false),
32603262
Field::new("b", DataType::Int32, false),
3261-
]));
3262-
ctx.register_table("t1", make_test_mem_table(t1_schema))?;
3263-
ctx.register_table("t2", make_test_mem_table(t2_schema))?;
3263+
]))
3264+
}
3265+
3266+
async fn update_assignments_for_sql(sql: &str) -> Result<Vec<(String, Expr)>> {
3267+
let ctx = SessionContext::new();
3268+
let schema = test_update_schema();
3269+
ctx.register_table("t1", make_test_mem_table(Arc::clone(&schema)))?;
3270+
ctx.register_table("t2", make_test_mem_table(schema))?;
32643271

32653272
let df = ctx.sql(sql).await?;
32663273
let (table_name, input) = match df.logical_plan() {
@@ -3276,6 +3283,16 @@ mod tests {
32763283
extract_update_assignments(&input, &table_name)
32773284
}
32783285

3286+
async fn assert_update_assignment(sql: &str, column: &str, expected: &str) {
3287+
let assignments: HashMap<_, _> = update_assignments_for_sql(sql)
3288+
.await
3289+
.unwrap()
3290+
.into_iter()
3291+
.map(|(name, expr)| (name, expr.to_string()))
3292+
.collect();
3293+
assert_eq!(assignments.get(column).map(String::as_str), Some(expected));
3294+
}
3295+
32793296
#[tokio::test]
32803297
async fn test_all_operators() -> Result<()> {
32813298
let logical_plan = test_csv_scan()
@@ -4861,81 +4878,53 @@ digraph {
48614878

48624879
#[tokio::test]
48634880
async fn test_extract_update_assignments_preserves_joined_source_qualifiers() {
4864-
let assignments = update_assignments_for_sql(
4881+
assert_update_assignment(
48654882
"UPDATE t1 SET b = t2.b FROM t2 WHERE t1.id = t2.id",
4883+
"b",
4884+
"t2.b",
48664885
)
4867-
.await
4868-
.unwrap();
4869-
4870-
let assignments: HashMap<_, _> = assignments.into_iter().collect();
4871-
assert_eq!(
4872-
assignments.get("b").map(ToString::to_string).as_deref(),
4873-
Some("t2.b")
4874-
);
4886+
.await;
48754887
}
48764888

48774889
#[tokio::test]
48784890
async fn test_extract_update_assignments_preserves_alias_qualified_sources() {
4879-
let assignments = update_assignments_for_sql(
4891+
assert_update_assignment(
48804892
"UPDATE t1 AS target SET b = source.b FROM t2 AS source \
48814893
WHERE target.id = source.id",
4894+
"b",
4895+
"source.b",
48824896
)
4883-
.await
4884-
.unwrap();
4885-
4886-
let assignments: HashMap<_, _> = assignments.into_iter().collect();
4887-
assert_eq!(
4888-
assignments.get("b").map(ToString::to_string).as_deref(),
4889-
Some("source.b")
4890-
);
4897+
.await;
48914898
}
48924899

48934900
#[tokio::test]
48944901
async fn test_extract_update_assignments_distinguishes_same_name_join_columns() {
4895-
let assignments = update_assignments_for_sql(
4902+
let assignments: HashMap<_, _> = update_assignments_for_sql(
48964903
"UPDATE t1 SET a = t2.a, b = t1.a FROM t2 WHERE t1.id = t2.id",
48974904
)
48984905
.await
4899-
.unwrap();
4900-
4901-
let assignments: HashMap<_, _> = assignments.into_iter().collect();
4902-
assert_eq!(
4903-
assignments.get("a").map(ToString::to_string).as_deref(),
4904-
Some("t2.a")
4905-
);
4906-
assert_eq!(
4907-
assignments.get("b").map(ToString::to_string).as_deref(),
4908-
Some("t1.a")
4909-
);
4906+
.unwrap()
4907+
.into_iter()
4908+
.map(|(name, expr)| (name, expr.to_string()))
4909+
.collect();
4910+
assert_eq!(assignments.get("a").map(String::as_str), Some("t2.a"));
4911+
assert_eq!(assignments.get("b").map(String::as_str), Some("t1.a"));
49104912
}
49114913

49124914
#[tokio::test]
49134915
async fn test_extract_update_assignments_preserves_self_join_source_alias() {
4914-
let assignments = update_assignments_for_sql(
4916+
assert_update_assignment(
49154917
"UPDATE t1 AS target SET a = src.a FROM t1 AS src \
49164918
WHERE target.id = src.id",
4919+
"a",
4920+
"src.a",
49174921
)
4918-
.await
4919-
.unwrap();
4920-
4921-
let assignments: HashMap<_, _> = assignments.into_iter().collect();
4922-
assert_eq!(
4923-
assignments.get("a").map(ToString::to_string).as_deref(),
4924-
Some("src.a")
4925-
);
4922+
.await;
49264923
}
49274924

49284925
#[tokio::test]
49294926
async fn test_extract_update_assignments_strips_single_table_target_qualifiers() {
4930-
let assignments =
4931-
update_assignments_for_sql("UPDATE t1 SET b = t1.a WHERE t1.id = 1")
4932-
.await
4933-
.unwrap();
4934-
4935-
let assignments: HashMap<_, _> = assignments.into_iter().collect();
4936-
assert_eq!(
4937-
assignments.get("b").map(ToString::to_string).as_deref(),
4938-
Some("a")
4939-
);
4927+
assert_update_assignment("UPDATE t1 SET b = t1.a WHERE t1.id = 1", "b", "a")
4928+
.await;
49404929
}
49414930
}

0 commit comments

Comments
 (0)