Skip to content

Commit ed29a53

Browse files
committed
Handle mixed aggregate/non-aggregate expressions in GROUP BY ALL and add IT
- Extract non-aggregate column references from mixed expressions (e.g. `s1 + avg(s2)` -> `GROUP BY s1`) instead of skipping them
- Add extractNonAggregateColumnReferences in ExpressionTreeUtils
- Add unit tests for mixed expression scenarios
- Add integration tests for GROUP BY ALL
1 parent 08911fb commit ed29a53

4 files changed

Lines changed: 385 additions & 0 deletions

File tree

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.relational.it.query.recent;
21+
22+
import org.apache.iotdb.it.env.EnvFactory;
23+
import org.apache.iotdb.it.framework.IoTDBTestRunner;
24+
import org.apache.iotdb.itbase.category.TableClusterIT;
25+
import org.apache.iotdb.itbase.category.TableLocalStandaloneIT;
26+
27+
import org.junit.AfterClass;
28+
import org.junit.BeforeClass;
29+
import org.junit.Test;
30+
import org.junit.experimental.categories.Category;
31+
import org.junit.runner.RunWith;
32+
33+
import static org.apache.iotdb.db.it.utils.TestUtils.prepareTableData;
34+
import static org.apache.iotdb.db.it.utils.TestUtils.tableResultSetEqualTest;
35+
36+
@RunWith(IoTDBTestRunner.class)
37+
@Category({TableLocalStandaloneIT.class, TableClusterIT.class})
38+
public class IoTDBGroupByAllTableIT {
39+
private static final String DATABASE_NAME = "test";
40+
private static final String[] createSqls =
41+
new String[] {
42+
"CREATE DATABASE " + DATABASE_NAME,
43+
"USE " + DATABASE_NAME,
44+
"CREATE TABLE t1(device_id STRING TAG, s1 INT32 FIELD, s2 INT64 FIELD, s3 DOUBLE FIELD)",
45+
"INSERT INTO t1(time, device_id, s1, s2, s3) VALUES (1, 'a', 10, 100, 1.0)",
46+
"INSERT INTO t1(time, device_id, s1, s2, s3) VALUES (2, 'a', 20, 200, 2.0)",
47+
"INSERT INTO t1(time, device_id, s1, s2, s3) VALUES (3, 'a', 30, 300, 3.0)",
48+
"INSERT INTO t1(time, device_id, s1, s2, s3) VALUES (4, 'b', 40, 400, 4.0)",
49+
"INSERT INTO t1(time, device_id, s1, s2, s3) VALUES (5, 'b', 50, 500, 5.0)",
50+
"INSERT INTO t1(time, device_id, s1, s2, s3) VALUES (6, 'c', 60, 600, 6.0)",
51+
"FLUSH",
52+
};
53+
54+
@BeforeClass
55+
public static void setUp() throws Exception {
56+
EnvFactory.getEnv().initClusterEnvironment();
57+
prepareTableData(createSqls);
58+
}
59+
60+
@AfterClass
61+
public static void tearDown() throws Exception {
62+
EnvFactory.getEnv().cleanClusterEnvironment();
63+
}
64+
65+
@Test
66+
public void groupByAllSingleColumnTest() {
67+
// GROUP BY ALL should infer device_id as the grouping key
68+
String[] expectedHeader = new String[] {"device_id", "_col1"};
69+
String[] retArray =
70+
new String[] {
71+
"a,3,", "b,2,", "c,1,",
72+
};
73+
tableResultSetEqualTest(
74+
"SELECT device_id, count(s1) FROM t1 GROUP BY ALL ORDER BY device_id",
75+
expectedHeader,
76+
retArray,
77+
DATABASE_NAME);
78+
}
79+
80+
@Test
81+
public void groupByAllMultipleColumnsTest() {
82+
// GROUP BY ALL should infer device_id and s1 as grouping keys
83+
String[] expectedHeader = new String[] {"device_id", "s1", "_col2"};
84+
String[] retArray =
85+
new String[] {
86+
"a,10,100,", "a,20,200,", "a,30,300,", "b,40,400,", "b,50,500,", "c,60,600,",
87+
};
88+
tableResultSetEqualTest(
89+
"SELECT device_id, s1, sum(s2) FROM t1 GROUP BY ALL ORDER BY device_id, s1",
90+
expectedHeader,
91+
retArray,
92+
DATABASE_NAME);
93+
}
94+
95+
@Test
96+
public void groupByAllEquivalenceTest() {
97+
// GROUP BY ALL and explicit GROUP BY should produce the same results
98+
String[] expectedHeader = new String[] {"device_id", "_col1"};
99+
String[] retArray =
100+
new String[] {
101+
"a,60,", "b,90,", "c,60,",
102+
};
103+
tableResultSetEqualTest(
104+
"SELECT device_id, sum(s1) FROM t1 GROUP BY ALL ORDER BY device_id",
105+
expectedHeader,
106+
retArray,
107+
DATABASE_NAME);
108+
tableResultSetEqualTest(
109+
"SELECT device_id, sum(s1) FROM t1 GROUP BY device_id ORDER BY device_id",
110+
expectedHeader,
111+
retArray,
112+
DATABASE_NAME);
113+
}
114+
115+
@Test
116+
public void groupByAllGlobalAggregationTest() {
117+
// All SELECT items are aggregates, GROUP BY ALL => global aggregation
118+
String[] expectedHeader = new String[] {"_col0", "_col1"};
119+
String[] retArray = new String[] {"6,2100,"};
120+
tableResultSetEqualTest(
121+
"SELECT count(s1), sum(s2) FROM t1 GROUP BY ALL",
122+
expectedHeader,
123+
retArray,
124+
DATABASE_NAME);
125+
}
126+
127+
@Test
128+
public void groupByAllWithExpressionTest() {
129+
// GROUP BY ALL with a computed expression (s1 + 1)
130+
String[] expectedHeader = new String[] {"_col0", "_col1"};
131+
String[] retArray =
132+
new String[] {
133+
"11,100,", "21,200,", "31,300,", "41,400,", "51,500,", "61,600,",
134+
};
135+
tableResultSetEqualTest(
136+
"SELECT s1 + 1, sum(s2) FROM t1 GROUP BY ALL ORDER BY s1 + 1",
137+
expectedHeader,
138+
retArray,
139+
DATABASE_NAME);
140+
}
141+
142+
@Test
143+
public void groupByAllMixedExpressionTest() {
144+
// SELECT s1 + avg(s2) FROM t1 GROUP BY ALL
145+
// should be equivalent to GROUP BY s1
146+
String[] expectedHeader = new String[] {"_col0"};
147+
String[] retArray =
148+
new String[] {
149+
"110.0,", "220.0,", "330.0,", "440.0,", "550.0,", "660.0,",
150+
};
151+
tableResultSetEqualTest(
152+
"SELECT s1 + avg(s2) FROM t1 GROUP BY ALL ORDER BY s1 + avg(s2)",
153+
expectedHeader,
154+
retArray,
155+
DATABASE_NAME);
156+
}
157+
158+
@Test
159+
public void groupByAllMixedExpressionEquivalenceTest() {
160+
// GROUP BY ALL with s1 + avg(s2) should equal GROUP BY s1
161+
String[] expectedHeader = new String[] {"_col0"};
162+
String[] retArray =
163+
new String[] {
164+
"110.0,", "220.0,", "330.0,", "440.0,", "550.0,", "660.0,",
165+
};
166+
tableResultSetEqualTest(
167+
"SELECT s1 + avg(s2) FROM t1 GROUP BY ALL ORDER BY s1 + avg(s2)",
168+
expectedHeader,
169+
retArray,
170+
DATABASE_NAME);
171+
tableResultSetEqualTest(
172+
"SELECT s1 + avg(s2) FROM t1 GROUP BY s1 ORDER BY s1 + avg(s2)",
173+
expectedHeader,
174+
retArray,
175+
DATABASE_NAME);
176+
}
177+
178+
@Test
179+
public void groupByAllMixedWithPureColumnTest() {
180+
// SELECT device_id, s1 + avg(s2) FROM t1 GROUP BY ALL
181+
// should be equivalent to GROUP BY device_id, s1
182+
String[] expectedHeader = new String[] {"device_id", "_col1"};
183+
String[] retArray =
184+
new String[] {
185+
"a,110.0,", "a,220.0,", "a,330.0,", "b,440.0,", "b,550.0,", "c,660.0,",
186+
};
187+
tableResultSetEqualTest(
188+
"SELECT device_id, s1 + avg(s2) FROM t1 GROUP BY ALL ORDER BY device_id, s1 + avg(s2)",
189+
expectedHeader,
190+
retArray,
191+
DATABASE_NAME);
192+
tableResultSetEqualTest(
193+
"SELECT device_id, s1 + avg(s2) FROM t1 GROUP BY device_id, s1 ORDER BY device_id, s1 + avg(s2)",
194+
expectedHeader,
195+
retArray,
196+
DATABASE_NAME);
197+
}
198+
199+
@Test
200+
public void groupByAllMixedMultipleSubExpressionsTest() {
201+
// SELECT s1 + s3 + avg(s2) FROM t1 GROUP BY ALL
202+
// should be equivalent to GROUP BY s1, s3
203+
String[] expectedHeader = new String[] {"_col0"};
204+
String[] retArray =
205+
new String[] {
206+
"111.0,", "222.0,", "333.0,", "444.0,", "555.0,", "666.0,",
207+
};
208+
tableResultSetEqualTest(
209+
"SELECT s1 + s3 + avg(s2) FROM t1 GROUP BY ALL ORDER BY s1 + s3 + avg(s2)",
210+
expectedHeader,
211+
retArray,
212+
DATABASE_NAME);
213+
tableResultSetEqualTest(
214+
"SELECT s1 + s3 + avg(s2) FROM t1 GROUP BY s1, s3 ORDER BY s1 + s3 + avg(s2)",
215+
expectedHeader,
216+
retArray,
217+
DATABASE_NAME);
218+
}
219+
220+
@Test
221+
public void groupByAllWithHavingTest() {
222+
// GROUP BY ALL with HAVING clause
223+
String[] expectedHeader = new String[] {"device_id", "_col1"};
224+
String[] retArray =
225+
new String[] {
226+
"a,3,",
227+
};
228+
tableResultSetEqualTest(
229+
"SELECT device_id, count(s1) FROM t1 GROUP BY ALL HAVING count(s1) >= 3 ORDER BY device_id",
230+
expectedHeader,
231+
retArray,
232+
DATABASE_NAME);
233+
}
234+
235+
@Test
236+
public void groupByAllWithWhereTest() {
237+
// GROUP BY ALL with WHERE clause
238+
String[] expectedHeader = new String[] {"device_id", "_col1"};
239+
String[] retArray =
240+
new String[] {
241+
"a,2,", "b,2,",
242+
};
243+
tableResultSetEqualTest(
244+
"SELECT device_id, count(s1) FROM t1 WHERE s1 >= 20 GROUP BY ALL ORDER BY device_id",
245+
expectedHeader,
246+
retArray,
247+
DATABASE_NAME);
248+
}
249+
250+
@Test
251+
public void groupByAllQuantifierBackwardCompatibilityTest() {
252+
// GROUP BY ALL s1 (ALL as set-quantifier, not GROUP BY ALL feature)
253+
String[] expectedHeader = new String[] {"s1", "_col1"};
254+
String[] retArray =
255+
new String[] {
256+
"10,100,", "20,200,", "30,300,", "40,400,", "50,500,", "60,600,",
257+
};
258+
tableResultSetEqualTest(
259+
"SELECT s1, sum(s2) FROM t1 GROUP BY ALL s1 ORDER BY s1",
260+
expectedHeader,
261+
retArray,
262+
DATABASE_NAME);
263+
}
264+
265+
@Test
266+
public void groupByAllMultipleAggregatesTest() {
267+
// Multiple aggregation functions with GROUP BY ALL
268+
String[] expectedHeader = new String[] {"device_id", "_col1", "_col2", "_col3"};
269+
String[] retArray =
270+
new String[] {
271+
"a,3,60,200.0,", "b,2,90,450.0,", "c,1,60,600.0,",
272+
};
273+
tableResultSetEqualTest(
274+
"SELECT device_id, count(s1), sum(s1), avg(s2) FROM t1 GROUP BY ALL ORDER BY device_id",
275+
expectedHeader,
276+
retArray,
277+
DATABASE_NAME);
278+
}
279+
}

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/ExpressionTreeUtils.java

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -105,6 +105,37 @@ public static QualifiedName asQualifiedName(Expression expression) {
105105
return name;
106106
}
107107

108+
/**
109+
* Extracts column references (Identifiers and DereferenceExpressions) that appear outside of
110+
* aggregate and window function call boundaries. For example, in {@code s1 + avg(s2)}, this
111+
* returns {@code [s1]}, skipping {@code s2} inside the aggregate.
112+
*/
113+
static List<Expression> extractNonAggregateColumnReferences(Expression expression) {
114+
ImmutableList.Builder<Expression> result = ImmutableList.builder();
115+
new DefaultExpressionTraversalVisitor<Void>() {
116+
@Override
117+
protected Void visitFunctionCall(FunctionCall node, Void context) {
118+
if (isAggregation(node) || isWindowFunction(node)) {
119+
return null;
120+
}
121+
return super.visitFunctionCall(node, context);
122+
}
123+
124+
@Override
125+
protected Void visitIdentifier(Identifier node, Void context) {
126+
result.add(node);
127+
return null;
128+
}
129+
130+
@Override
131+
protected Void visitDereferenceExpression(DereferenceExpression node, Void context) {
132+
result.add(node);
133+
return null;
134+
}
135+
}.process(expression, null);
136+
return result.build();
137+
}
138+
108139
static boolean isAggregationFunction(String functionName) {
109140
return TableBuiltinAggregationFunction.getBuiltInAggregateFunctionName()
110141
.contains(functionName.toLowerCase())

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/analyzer/StatementAnalyzer.java

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2738,6 +2738,23 @@ private Analysis.GroupingSetAnalysis analyzeGroupByAll(
27382738
List<FunctionCall> windowFunctions =
27392739
extractWindowFunctions(ImmutableList.of(outputExpression));
27402740
if (!aggregates.isEmpty() || !windowFunctions.isEmpty()) {
2741+
// Extract non-aggregate sub-expressions (column references outside aggregate
2742+
// boundaries).
2743+
// e.g. in `s1 + avg(s2)`, extract `s1` as a grouping key.
2744+
List<Expression> nonAggColumns =
2745+
ExpressionTreeUtils.extractNonAggregateColumnReferences(outputExpression);
2746+
for (Expression colRef : nonAggColumns) {
2747+
analyzeExpression(colRef, scope);
2748+
ResolvedField field =
2749+
analysis.getColumnReferenceFields().get(NodeRef.of(colRef));
2750+
if (field != null) {
2751+
sets.add(ImmutableList.of(ImmutableSet.of(field.getFieldId())));
2752+
} else {
2753+
complexExpressions.add(colRef);
2754+
}
2755+
gapFillGroupingExpressions.add(colRef);
2756+
groupingExpressions.add(colRef);
2757+
}
27412758
continue;
27422759
}
27432760

0 commit comments

Comments
 (0)