Skip to content

Commit 063d0ac

Browse files
authored
feat(tesseract): auto-parenthesize compound SqlCall substitutions (#10724)
1 parent af31ac9 commit 063d0ac

21 files changed

Lines changed: 1601 additions & 131 deletions

rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_call.rs

Lines changed: 35 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ use crate::planner::sql_evaluator::sql_nodes::SqlNodesFactory;
66
use crate::planner::sql_evaluator::{CubeNameSymbol, CubeTableSymbol};
77
use crate::planner::sql_templates::PlanSqlTemplates;
88
use crate::planner::VisitorContext;
9+
use crate::utils::sql_expression_scanner::analyze_template_arg_contexts;
910
use cubenativeutils::CubeError;
1011
use itertools::Itertools;
1112
use std::collections::HashMap;
@@ -117,6 +118,10 @@ pub struct SqlCall {
117118
filter_params: Vec<SqlCallFilterParamsItem>,
118119
filter_groups: Vec<SqlCallFilterGroupItem>,
119120
security_context: SecutityContextProps,
121+
/// Per `{arg:N}` index: whether the surrounding context in the template
122+
/// would make a compound substitution unsafe (requiring parentheses).
123+
/// Computed once at construction from the template.
124+
arg_paren_contexts: HashMap<usize, bool>,
120125
}
121126

122127
impl SqlCall {
@@ -127,12 +132,26 @@ impl SqlCall {
127132
filter_groups: Vec<SqlCallFilterGroupItem>,
128133
security_context: SecutityContextProps,
129134
) -> Self {
135+
let arg_paren_contexts = match &template {
136+
SqlTemplate::String(s) => analyze_template_arg_contexts(s),
137+
SqlTemplate::StringVec(strings) => {
138+
let mut merged: HashMap<usize, bool> = HashMap::new();
139+
for s in strings {
140+
for (idx, needs_safe) in analyze_template_arg_contexts(s) {
141+
let entry = merged.entry(idx).or_insert(false);
142+
*entry = *entry || needs_safe;
143+
}
144+
}
145+
merged
146+
}
147+
};
130148
Self {
131149
template,
132150
deps,
133151
filter_params,
134152
filter_groups,
135153
security_context,
154+
arg_paren_contexts,
136155
}
137156
}
138157

@@ -254,10 +273,22 @@ impl SqlCall {
254273
let deps = self
255274
.deps
256275
.iter()
257-
.map(|dep| match dep {
258-
SqlDependency::Symbol(m) => visitor.apply(m, node_processor.clone(), templates),
259-
SqlDependency::CubeRef(cr) => {
260-
visitor.evaluate_cube_ref(cr, node_processor.clone(), templates)
276+
.enumerate()
277+
.map(|(i, dep)| {
278+
// Each arg's `arg_needs_paren_safe` flag is set by this call's
279+
// template context, overriding whatever the caller's visitor
280+
// carried. The caller's flag only governs wrapping of this
281+
// whole SqlCall's output, handled by an enclosing Parenthesize
282+
// node up the processor chain.
283+
let needs_safe = *self.arg_paren_contexts.get(&i).unwrap_or(&false);
284+
let arg_visitor = visitor.with_arg_needs_paren_safe(needs_safe);
285+
match dep {
286+
SqlDependency::Symbol(m) => {
287+
arg_visitor.apply(m, node_processor.clone(), templates)
288+
}
289+
SqlDependency::CubeRef(cr) => {
290+
arg_visitor.evaluate_cube_ref(cr, node_processor.clone(), templates)
291+
}
261292
}
262293
})
263294
.collect::<Result<Vec<_>, _>>()?;

rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/calendar_time_shift.rs

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -36,13 +36,6 @@ impl SqlNode for CalendarTimeShiftSqlNode {
3636
node_processor: Rc<dyn SqlNode>,
3737
templates: &PlanSqlTemplates,
3838
) -> Result<String, CubeError> {
39-
let input = self.input.to_sql(
40-
visitor,
41-
node,
42-
query_tools.clone(),
43-
node_processor.clone(),
44-
templates,
45-
)?;
4639
let res = match node.as_ref() {
4740
MemberSymbol::Dimension(ev) => {
4841
if !ev.is_reference() {
@@ -55,20 +48,52 @@ impl SqlNode for CalendarTimeShiftSqlNode {
5548
templates,
5649
)?
5750
} else if let Some(interval) = &shift.interval {
51+
let inner_visitor = visitor.with_arg_needs_paren_safe(false);
52+
let input = self.input.to_sql(
53+
&inner_visitor,
54+
node,
55+
query_tools.clone(),
56+
node_processor.clone(),
57+
templates,
58+
)?;
5859
let res = templates
5960
.add_timestamp_interval(input, interval.inverse().to_sql())?;
6061
format!("({})", res)
6162
} else {
62-
input
63+
self.input.to_sql(
64+
visitor,
65+
node,
66+
query_tools.clone(),
67+
node_processor.clone(),
68+
templates,
69+
)?
6370
}
6471
} else {
65-
input
72+
self.input.to_sql(
73+
visitor,
74+
node,
75+
query_tools.clone(),
76+
node_processor.clone(),
77+
templates,
78+
)?
6679
}
6780
} else {
68-
input
81+
self.input.to_sql(
82+
visitor,
83+
node,
84+
query_tools.clone(),
85+
node_processor.clone(),
86+
templates,
87+
)?
6988
}
7089
}
71-
_ => input,
90+
_ => self.input.to_sql(
91+
visitor,
92+
node,
93+
query_tools.clone(),
94+
node_processor.clone(),
95+
templates,
96+
)?,
7297
};
7398
Ok(res)
7499
}

rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/case.rs

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,18 +31,20 @@ impl CaseSqlNode {
3131
node_processor: Rc<dyn SqlNode>,
3232
templates: &PlanSqlTemplates,
3333
) -> Result<String, CubeError> {
34+
// All sub-SQLs end up inside `CASE … END` — a safe wrap.
35+
let inner_visitor = visitor.with_arg_needs_paren_safe(false);
3436
let mut when_then = Vec::new();
3537
for itm in case.items.iter() {
3638
let when = itm.sql.eval(
37-
visitor,
39+
&inner_visitor,
3840
node_processor.clone(),
3941
query_tools.clone(),
4042
templates,
4143
)?;
4244
let then = match &itm.label {
4345
CaseLabel::String(s) => templates.quote_string(&s)?,
4446
CaseLabel::Sql(sql) => sql.eval(
45-
visitor,
47+
&inner_visitor,
4648
node_processor.clone(),
4749
query_tools.clone(),
4850
templates,
@@ -53,7 +55,7 @@ impl CaseSqlNode {
5355
let else_label = match &case.else_label {
5456
CaseLabel::String(s) => templates.quote_string(&s)?,
5557
CaseLabel::Sql(sql) => sql.eval(
56-
visitor,
58+
&inner_visitor,
5759
node_processor.clone(),
5860
query_tools.clone(),
5961
templates,
@@ -69,6 +71,9 @@ impl CaseSqlNode {
6971
node_processor: Rc<dyn SqlNode>,
7072
templates: &PlanSqlTemplates,
7173
) -> Result<String, CubeError> {
74+
// Degenerate shortcuts return the inner SQL as-is — propagate the outer
75+
// visitor so an enclosing ParenthesizeSqlNode still sees the compound
76+
// flag.
7277
if case.items.len() == 1 && case.else_sql.is_none() {
7378
return case.items[0].sql.eval(
7479
visitor,
@@ -85,22 +90,23 @@ impl CaseSqlNode {
8590
templates,
8691
);
8792
}
93+
let inner_visitor = visitor.with_arg_needs_paren_safe(false);
8894
let expr = match &case.switch {
8995
CaseSwitchItem::Sql(sql_call) => sql_call.eval(
90-
visitor,
96+
&inner_visitor,
9197
node_processor.clone(),
9298
query_tools.clone(),
9399
templates,
94100
)?,
95101
CaseSwitchItem::Member(member_symbol) => {
96-
visitor.apply(&member_symbol, node_processor.clone(), templates)?
102+
inner_visitor.apply(&member_symbol, node_processor.clone(), templates)?
97103
}
98104
};
99105
let mut when_then = Vec::new();
100106
for itm in case.items.iter() {
101107
let when = templates.quote_string(&itm.value)?;
102108
let then = itm.sql.eval(
103-
visitor,
109+
&inner_visitor,
104110
node_processor.clone(),
105111
query_tools.clone(),
106112
templates,
@@ -109,7 +115,7 @@ impl CaseSqlNode {
109115
}
110116
let else_label = if let Some(else_sql) = &case.else_sql {
111117
Some(else_sql.eval(
112-
visitor,
118+
&inner_visitor,
113119
node_processor.clone(),
114120
query_tools.clone(),
115121
templates,

rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/factory.rs

Lines changed: 10 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
use super::{
22
AutoPrefixSqlNode, CaseSqlNode, EvaluateSqlNode, FinalMeasureSqlNode,
33
FinalPreAggregationMeasureSqlNode, GeoDimensionSqlNode, MaskedSqlNode, MeasureFilterSqlNode,
4-
MultiStageRankNode, MultiStageWindowNode, RenderReferencesSqlNode, RenderReferencesType,
5-
RollingWindowNode, RootSqlNode, SqlNode, TimeDimensionNode, TimeShiftSqlNode,
6-
UngroupedMeasureSqlNode, UngroupedQueryFinalMeasureSqlNode,
4+
MultiStageRankNode, MultiStageWindowNode, ParenthesizeSqlNode, RenderReferencesSqlNode,
5+
RenderReferencesType, RollingWindowNode, RootSqlNode, SqlNode, TimeDimensionNode,
6+
TimeShiftSqlNode, UngroupedMeasureSqlNode, UngroupedQueryFinalMeasureSqlNode,
77
};
88
use crate::planner::planners::multi_stage::TimeShiftState;
99
use crate::planner::sql_evaluator::cube_ref_evaluator::CubeRefEvaluator;
@@ -156,8 +156,10 @@ impl SqlNodesFactory {
156156
evaluate_sql_processor.clone(),
157157
self.cube_name_references.clone(),
158158
);
159+
let parenthesize_processor: Rc<dyn SqlNode> =
160+
ParenthesizeSqlNode::new(auto_prefix_processor.clone());
159161

160-
let measure_filter_processor = MeasureFilterSqlNode::new(auto_prefix_processor.clone());
162+
let measure_filter_processor = MeasureFilterSqlNode::new(parenthesize_processor.clone());
161163
let measure_processor = CaseSqlNode::new(measure_filter_processor.clone());
162164

163165
let measure_processor = self.add_ungrouped_measure_reference_if_needed(measure_processor);
@@ -182,10 +184,11 @@ impl SqlNodesFactory {
182184
} else {
183185
evaluate_sql_processor.clone()
184186
};
187+
let default_processor: Rc<dyn SqlNode> = ParenthesizeSqlNode::new(default_processor);
185188

186189
let root_node = RootSqlNode::new(
187190
self.dimension_processor(evaluate_sql_processor.clone()),
188-
self.time_dimension_processor(evaluate_sql_processor.clone()),
191+
self.time_dimension_processor(ParenthesizeSqlNode::new(evaluate_sql_processor.clone())),
189192
measure_processor.clone(),
190193
default_processor,
191194
);
@@ -261,6 +264,8 @@ impl SqlNodesFactory {
261264
let input: Rc<dyn SqlNode> =
262265
AutoPrefixSqlNode::new(input, self.cube_name_references.clone());
263266

267+
let input: Rc<dyn SqlNode> = ParenthesizeSqlNode::new(input);
268+
264269
let input: Rc<dyn SqlNode> =
265270
TimeDimensionNode::new(self.dimensions_with_ignored_timezone.clone(), input);
266271

rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/final_measure.rs

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
use super::SqlNode;
22
use crate::planner::query_tools::QueryTools;
3-
use crate::planner::sql_evaluator::symbols::{AggregateWrap, MeasureSymbol};
3+
use crate::planner::sql_evaluator::symbols::AggregateWrap;
44
use crate::planner::sql_evaluator::MemberSymbol;
55
use crate::planner::sql_evaluator::SqlEvaluatorVisitor;
66
use crate::planner::sql_templates::PlanSqlTemplates;
@@ -32,16 +32,13 @@ impl FinalMeasureSqlNode {
3232
&self.input
3333
}
3434

35-
fn wrap_aggregate(
35+
fn apply_wrap(
3636
&self,
37-
ev: &MeasureSymbol,
37+
wrap: AggregateWrap,
3838
input: String,
3939
templates: &PlanSqlTemplates,
4040
) -> Result<String, CubeError> {
41-
let is_multiplied = self
42-
.rendered_as_multiplied_measures
43-
.contains(&ev.full_name());
44-
match ev.kind().aggregate_wrap(is_multiplied) {
41+
match wrap {
4542
AggregateWrap::PassThrough => Ok(input),
4643
AggregateWrap::Function(name) => Ok(format!("{}({})", name, input)),
4744
AggregateWrap::CountDistinct => templates.count_distinct(&input),
@@ -67,14 +64,22 @@ impl SqlNode for FinalMeasureSqlNode {
6764
) -> Result<String, CubeError> {
6865
let res = match node.as_ref() {
6966
MemberSymbol::Measure(ev) => {
67+
let is_multiplied = self
68+
.rendered_as_multiplied_measures
69+
.contains(&ev.full_name());
70+
let wrap = ev.kind().aggregate_wrap(is_multiplied);
71+
let child_visitor = match wrap {
72+
AggregateWrap::PassThrough => visitor.clone(),
73+
_ => visitor.with_arg_needs_paren_safe(false),
74+
};
7075
let input = self.input.to_sql(
71-
visitor,
76+
&child_visitor,
7277
node,
7378
query_tools.clone(),
7479
node_processor.clone(),
7580
templates,
7681
)?;
77-
self.wrap_aggregate(ev, input, templates)?
82+
self.apply_wrap(wrap, input, templates)?
7883
}
7984
_ => {
8085
return Err(CubeError::internal(format!(

rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/geo_dimension.rs

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -34,14 +34,15 @@ impl SqlNode for GeoDimensionSqlNode {
3434
let res = match node.as_ref() {
3535
MemberSymbol::Dimension(ev) => {
3636
if let DimensionKind::Geo(geo) = ev.kind() {
37+
let inner_visitor = visitor.with_arg_needs_paren_safe(false);
3738
let latitude_str = geo.latitude().eval(
38-
visitor,
39+
&inner_visitor,
3940
node_processor.clone(),
4041
query_tools.clone(),
4142
templates,
4243
)?;
4344
let longitude_str = geo.longitude().eval(
44-
visitor,
45+
&inner_visitor,
4546
node_processor.clone(),
4647
query_tools.clone(),
4748
templates,

rust/cubesqlplanner/cubesqlplanner/src/planner/sql_evaluator/sql_nodes/measure_filter.rs

Lines changed: 17 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,24 +30,25 @@ impl SqlNode for MeasureFilterSqlNode {
3030
node_processor: Rc<dyn SqlNode>,
3131
templates: &PlanSqlTemplates,
3232
) -> Result<String, CubeError> {
33-
let input = self.input.to_sql(
34-
visitor,
35-
node,
36-
query_tools.clone(),
37-
node_processor.clone(),
38-
templates,
39-
)?;
4033
let res = match node.as_ref() {
4134
MemberSymbol::Measure(ev) => {
4235
let measure_filters = ev.measure_filters();
4336
if !measure_filters.is_empty() {
37+
let inner_visitor = visitor.with_arg_needs_paren_safe(false);
38+
let input = self.input.to_sql(
39+
&inner_visitor,
40+
node,
41+
query_tools.clone(),
42+
node_processor.clone(),
43+
templates,
44+
)?;
4445
let filters = measure_filters
4546
.iter()
4647
.map(|filter| -> Result<String, CubeError> {
4748
Ok(format!(
4849
"({})",
4950
filter.eval(
50-
&visitor,
51+
&inner_visitor,
5152
node_processor.clone(),
5253
query_tools.clone(),
5354
templates
@@ -63,7 +64,14 @@ impl SqlNode for MeasureFilterSqlNode {
6364
};
6465
format!("CASE WHEN {} THEN {} END", filters, result)
6566
} else {
66-
input
67+
// Passthrough — propagate visitor unchanged.
68+
self.input.to_sql(
69+
visitor,
70+
node,
71+
query_tools.clone(),
72+
node_processor.clone(),
73+
templates,
74+
)?
6775
}
6876
}
6977
_ => {

0 commit comments

Comments
 (0)