Skip to content

Commit 58ec38a

Browse files
Enhance approx_percentile with strict weight validation, improved error handling and IT coverage (#17368)
1 parent fd3726f commit 58ec38a

File tree

4 files changed

+135
-9
lines changed

4 files changed

+135
-9
lines changed

integration-test/src/test/java/org/apache/iotdb/relational/it/query/recent/IoTDBTableAggregationIT.java

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4352,6 +4352,18 @@ public void approxPercentileTest() {
43524352
"2024-09-24T06:15:55.000Z,shanghai,55,null,",
43534353
},
43544354
DATABASE_NAME);
4355+
4356+
tableResultSetEqualTest(
4357+
"select approx_percentile(s1,null,0.5) from table1",
4358+
new String[] {"_col0"},
4359+
new String[] {"null,"},
4360+
DATABASE_NAME);
4361+
4362+
tableResultSetEqualTest(
4363+
"select 1 as g, approx_percentile(s1,null,0.5) from table1 group by 1",
4364+
new String[] {"g", "_col1"},
4365+
new String[] {"1,null,"},
4366+
DATABASE_NAME);
43554367
}
43564368

43574369
@Test
@@ -4432,6 +4444,18 @@ public void exceptionTest() {
44324444
"select approx_percentile(s5,0.5) from table1",
44334445
"701: Aggregation functions [approx_percentile] should have value column as numeric type [INT32, INT64, FLOAT, DOUBLE, TIMESTAMP]",
44344446
DATABASE_NAME);
4447+
tableAssertTestFail(
4448+
"select approx_percentile(s1,-1,0.5) from table1",
4449+
"701: weight must be >= 1, was -1",
4450+
DATABASE_NAME);
4451+
tableAssertTestFail(
4452+
"select approx_percentile(s1,s2,0.5) from table1",
4453+
"701: Aggregation functions [approx_percentile] do not support weight as INT64 type",
4454+
DATABASE_NAME);
4455+
tableAssertTestFail(
4456+
"select 1 as g, approx_percentile(s1,s2,0.5) from table1 group by 1",
4457+
"701: Aggregation functions [approx_percentile] do not support weight as INT64 type",
4458+
DATABASE_NAME);
44354459
}
44364460

44374461
// ==================================================================

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/ApproxPercentileWithWeightAccumulator.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,8 @@
1414

1515
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation;
1616

17+
import org.apache.iotdb.db.exception.sql.SemanticException;
18+
1719
import org.apache.tsfile.block.column.Column;
1820
import org.apache.tsfile.enums.TSDataType;
1921

@@ -32,6 +34,12 @@ public void addIntInput(Column[] arguments, AggregationMask mask) {
3234

3335
if (mask.isSelectAll()) {
3436
for (int i = 0; i < valueColumn.getPositionCount(); i++) {
37+
if (weightColumn.isNull(i)) {
38+
continue;
39+
}
40+
if (weightColumn.getInt(i) < 1) {
41+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
42+
}
3543
if (!valueColumn.isNull(i)) {
3644
tDigest.add(valueColumn.getInt(i), weightColumn.getInt(i));
3745
}
@@ -41,6 +49,12 @@ public void addIntInput(Column[] arguments, AggregationMask mask) {
4149
int position;
4250
for (int i = 0; i < positionCount; i++) {
4351
position = selectedPositions[i];
52+
if (weightColumn.isNull(position)) {
53+
continue;
54+
}
55+
if (weightColumn.getInt(position) < 1) {
56+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
57+
}
4458
if (!valueColumn.isNull(position)) {
4559
tDigest.add(valueColumn.getInt(position), weightColumn.getInt(position));
4660
}
@@ -57,6 +71,12 @@ public void addLongInput(Column[] arguments, AggregationMask mask) {
5771

5872
if (mask.isSelectAll()) {
5973
for (int i = 0; i < valueColumn.getPositionCount(); i++) {
74+
if (weightColumn.isNull(i)) {
75+
continue;
76+
}
77+
if (weightColumn.getInt(i) < 1) {
78+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
79+
}
6080
if (!valueColumn.isNull(i)) {
6181
tDigest.add(toDoubleExact(valueColumn.getLong(i)), weightColumn.getInt(i));
6282
}
@@ -66,6 +86,12 @@ public void addLongInput(Column[] arguments, AggregationMask mask) {
6686
int position;
6787
for (int i = 0; i < positionCount; i++) {
6888
position = selectedPositions[i];
89+
if (weightColumn.isNull(position)) {
90+
continue;
91+
}
92+
if (weightColumn.getInt(position) < 1) {
93+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
94+
}
6995
if (!valueColumn.isNull(position)) {
7096
tDigest.add(toDoubleExact(valueColumn.getLong(position)), weightColumn.getInt(position));
7197
}
@@ -82,6 +108,12 @@ public void addFloatInput(Column[] arguments, AggregationMask mask) {
82108

83109
if (mask.isSelectAll()) {
84110
for (int i = 0; i < valueColumn.getPositionCount(); i++) {
111+
if (weightColumn.isNull(i)) {
112+
continue;
113+
}
114+
if (weightColumn.getInt(i) < 1) {
115+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
116+
}
85117
if (!valueColumn.isNull(i)) {
86118
tDigest.add(valueColumn.getFloat(i), weightColumn.getInt(i));
87119
}
@@ -91,6 +123,12 @@ public void addFloatInput(Column[] arguments, AggregationMask mask) {
91123
int position;
92124
for (int i = 0; i < positionCount; i++) {
93125
position = selectedPositions[i];
126+
if (weightColumn.isNull(position)) {
127+
continue;
128+
}
129+
if (weightColumn.getInt(position) < 1) {
130+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
131+
}
94132
if (!valueColumn.isNull(position)) {
95133
tDigest.add(valueColumn.getFloat(position), weightColumn.getInt(position));
96134
}
@@ -107,6 +145,12 @@ public void addDoubleInput(Column[] arguments, AggregationMask mask) {
107145

108146
if (mask.isSelectAll()) {
109147
for (int i = 0; i < valueColumn.getPositionCount(); i++) {
148+
if (weightColumn.isNull(i)) {
149+
continue;
150+
}
151+
if (weightColumn.getInt(i) < 1) {
152+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
153+
}
110154
if (!valueColumn.isNull(i)) {
111155
tDigest.add(valueColumn.getDouble(i), weightColumn.getInt(i));
112156
}
@@ -116,6 +160,12 @@ public void addDoubleInput(Column[] arguments, AggregationMask mask) {
116160
int position;
117161
for (int i = 0; i < positionCount; i++) {
118162
position = selectedPositions[i];
163+
if (weightColumn.isNull(position)) {
164+
continue;
165+
}
166+
if (weightColumn.getInt(position) < 1) {
167+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
168+
}
119169
if (!valueColumn.isNull(position)) {
120170
tDigest.add(valueColumn.getDouble(position), weightColumn.getInt(position));
121171
}

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/execution/operator/source/relational/aggregation/grouped/GroupedApproxPercentileWithWeightAccumulator.java

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414

1515
package org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.grouped;
1616

17+
import org.apache.iotdb.db.exception.sql.SemanticException;
1718
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.AggregationMask;
1819
import org.apache.iotdb.db.queryengine.execution.operator.source.relational.aggregation.approximate.TDigest;
1920

@@ -36,6 +37,12 @@ public void addIntInput(int[] groupIds, Column[] arguments, AggregationMask mask
3637

3738
if (mask.isSelectAll()) {
3839
for (int i = 0; i < positionCount; i++) {
40+
if (weightColumn.isNull(i)) {
41+
continue;
42+
}
43+
if (weightColumn.getInt(i) < 1) {
44+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
45+
}
3946
int groupId = groupIds[i];
4047
TDigest tDigest = array.get(groupId);
4148
if (!valueColumn.isNull(i)) {
@@ -48,6 +55,12 @@ public void addIntInput(int[] groupIds, Column[] arguments, AggregationMask mask
4855
int groupId;
4956
for (int i = 0; i < positionCount; i++) {
5057
position = selectedPositions[i];
58+
if (weightColumn.isNull(position)) {
59+
continue;
60+
}
61+
if (weightColumn.getInt(position) < 1) {
62+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
63+
}
5164
groupId = groupIds[position];
5265
TDigest tDigest = array.get(groupId);
5366
if (!valueColumn.isNull(position)) {
@@ -66,6 +79,12 @@ public void addLongInput(int[] groupIds, Column[] arguments, AggregationMask mas
6679

6780
if (mask.isSelectAll()) {
6881
for (int i = 0; i < positionCount; i++) {
82+
if (weightColumn.isNull(i)) {
83+
continue;
84+
}
85+
if (weightColumn.getInt(i) < 1) {
86+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
87+
}
6988
int groupId = groupIds[i];
7089
TDigest tDigest = array.get(groupId);
7190
if (!valueColumn.isNull(i)) {
@@ -78,6 +97,12 @@ public void addLongInput(int[] groupIds, Column[] arguments, AggregationMask mas
7897
int groupId;
7998
for (int i = 0; i < positionCount; i++) {
8099
position = selectedPositions[i];
100+
if (weightColumn.isNull(position)) {
101+
continue;
102+
}
103+
if (weightColumn.getInt(position) < 1) {
104+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
105+
}
81106
groupId = groupIds[position];
82107
TDigest tDigest = array.get(groupId);
83108
if (!valueColumn.isNull(position)) {
@@ -96,6 +121,12 @@ public void addFloatInput(int[] groupIds, Column[] arguments, AggregationMask ma
96121

97122
if (mask.isSelectAll()) {
98123
for (int i = 0; i < positionCount; i++) {
124+
if (weightColumn.isNull(i)) {
125+
continue;
126+
}
127+
if (weightColumn.getInt(i) < 1) {
128+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
129+
}
99130
int groupId = groupIds[i];
100131
TDigest tDigest = array.get(groupId);
101132
if (!valueColumn.isNull(i)) {
@@ -108,6 +139,12 @@ public void addFloatInput(int[] groupIds, Column[] arguments, AggregationMask ma
108139
int groupId;
109140
for (int i = 0; i < positionCount; i++) {
110141
position = selectedPositions[i];
142+
if (weightColumn.isNull(position)) {
143+
continue;
144+
}
145+
if (weightColumn.getInt(position) < 1) {
146+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
147+
}
111148
groupId = groupIds[position];
112149
TDigest tDigest = array.get(groupId);
113150
if (!valueColumn.isNull(position)) {
@@ -126,6 +163,12 @@ public void addDoubleInput(int[] groupIds, Column[] arguments, AggregationMask m
126163

127164
if (mask.isSelectAll()) {
128165
for (int i = 0; i < positionCount; i++) {
166+
if (weightColumn.isNull(i)) {
167+
continue;
168+
}
169+
if (weightColumn.getInt(i) < 1) {
170+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(i));
171+
}
129172
int groupId = groupIds[i];
130173
TDigest tDigest = array.get(groupId);
131174
if (!valueColumn.isNull(i)) {
@@ -138,6 +181,12 @@ public void addDoubleInput(int[] groupIds, Column[] arguments, AggregationMask m
138181
int groupId;
139182
for (int i = 0; i < positionCount; i++) {
140183
position = selectedPositions[i];
184+
if (weightColumn.isNull(position)) {
185+
continue;
186+
}
187+
if (weightColumn.getInt(position) < 1) {
188+
throw new SemanticException("weight must be >= 1, was " + weightColumn.getInt(position));
189+
}
141190
groupId = groupIds[position];
142191
TDigest tDigest = array.get(groupId);
143192
if (!valueColumn.isNull(position)) {

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/metadata/TableMetadataImpl.java

Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1194,19 +1194,22 @@ && isIntegerNumber(argumentTypes.get(2)))) {
11941194
functionName));
11951195
}
11961196

1197-
// Validate percentage and weight parameters
1198-
boolean hasInvalidTypes =
1199-
(argumentSize == 2 && !isDecimalType(argumentTypes.get(1)))
1200-
|| (argumentSize == 3
1201-
&& (!isIntegerNumber(argumentTypes.get(1))
1202-
|| !isDecimalType(argumentTypes.get(2))));
1203-
1204-
if (hasInvalidTypes) {
1197+
Type percentageType = argumentTypes.get(argumentSize - 1);
1198+
if (!isDecimalType(percentageType)) {
12051199
throw new SemanticException(
12061200
String.format(
1207-
"Aggregation functions [%s] should have weight as integer type and percentage as decimal type",
1201+
"Aggregation functions [%s] should have percentage as decimal type",
12081202
functionName));
12091203
}
1204+
if (argumentSize == 3) {
1205+
Type weightType = argumentTypes.get(1);
1206+
if (!INT32.equals(weightType) && !isUnknownType(weightType)) {
1207+
throw new SemanticException(
1208+
String.format(
1209+
"Aggregation functions [%s] do not support weight as %s type",
1210+
functionName, weightType.getDisplayName()));
1211+
}
1212+
}
12101213

12111214
break;
12121215
case SqlConstant.COUNT:

0 commit comments

Comments
 (0)