Skip to content

Commit 3931c05

Browse files
authored
Resolve type mismatch when WHEN result type differs from ELSE (INT32 vs INT64) (#17415)
1 parent dce76a6 commit 3931c05

3 files changed

Lines changed: 146 additions & 29 deletions

File tree

integration-test/src/test/java/org/apache/iotdb/relational/it/db/it/IoTDBCaseWhenThenTableIT.java

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,16 @@ public static void tearDown() throws Exception {
8888
EnvFactory.getEnv().cleanClusterEnvironment();
8989
}
9090

91+
@Test
92+
public void testIfWithCastedDefaultType() {
93+
String[] retArray = new String[] {"0,", "0,", "2,", "3,"};
94+
tableResultSetEqualTest(
95+
"select if(s2 > 1, s2, cast(0 as int64)) from table3 limit 4",
96+
expectedHeader,
97+
retArray,
98+
DATABASE);
99+
}
100+
91101
@Test
92102
public void testKind1Basic() {
93103
String[] retArray = new String[] {"99,", "9999,", "9999,", "999,"};
@@ -161,7 +171,7 @@ public void testKind1OutputTypeRestrict() {
161171
DATABASE);
162172
tableAssertTestFail(
163173
"select case when s1<=0 then 0 when s1>1 then null end from table1",
164-
"701: All result types must be the same:",
174+
"701: All result types and default result type must be the same:",
165175
DATABASE);
166176

167177
// TEXT and other types cannot exist at the same time

iotdb-core/datanode/src/main/java/org/apache/iotdb/db/queryengine/plan/relational/planner/IrTypeAnalyzer.java

Lines changed: 89 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -61,9 +61,12 @@
6161
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SimpleCaseExpression;
6262
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.StringLiteral;
6363
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference;
64+
import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.WhenClause;
65+
import org.apache.iotdb.db.queryengine.plan.relational.type.TypeCoercionUtils;
6466

6567
import com.google.common.collect.ImmutableList;
6668
import com.google.common.collect.ImmutableMap;
69+
import com.google.common.collect.LinkedHashMultimap;
6770
import org.apache.tsfile.read.common.type.BlobType;
6871
import org.apache.tsfile.read.common.type.DateType;
6972
import org.apache.tsfile.read.common.type.RowType;
@@ -72,10 +75,14 @@
7275
import org.apache.tsfile.read.common.type.Type;
7376

7477
import java.util.ArrayList;
78+
import java.util.Collection;
79+
import java.util.Iterator;
7580
import java.util.LinkedHashMap;
7681
import java.util.LinkedHashSet;
7782
import java.util.List;
7883
import java.util.Map;
84+
import java.util.Set;
85+
import java.util.function.Function;
7986
import java.util.stream.Collectors;
8087

8188
import static com.google.common.base.Preconditions.checkArgument;
@@ -230,37 +237,31 @@ protected Type visitIfExpression(IfExpression node, Context context) {
230237

231238
@Override
232239
protected Type visitSearchedCaseExpression(SearchedCaseExpression node, Context context) {
233-
LinkedHashSet<Type> resultTypes =
234-
node.getWhenClauses().stream()
235-
.map(
236-
clause -> {
237-
Type operandType = process(clause.getOperand(), context);
238-
if (!operandType.equals(BOOLEAN)) {
239-
throw new SemanticException(
240-
String.format("When clause operand must be boolean: %s", operandType));
241-
}
242-
return setExpressionType(clause, process(clause.getResult(), context));
243-
})
244-
.collect(Collectors.toCollection(LinkedHashSet::new));
240+
for (WhenClause whenClause : node.getWhenClauses()) {
241+
coerceType(
242+
context,
243+
whenClause.getOperand(),
244+
BOOLEAN,
245+
(actualType) -> String.format("When clause operand must be boolean: %s", actualType));
246+
}
245247

246-
if (resultTypes.size() != 1) {
247-
throw new SemanticException(
248-
String.format("All result types must be the same: %s", resultTypes));
248+
List<Expression> expressions = new ArrayList<>();
249+
for (WhenClause whenClause : node.getWhenClauses()) {
250+
expressions.add(whenClause.getResult());
249251
}
250-
Type resultType = resultTypes.iterator().next();
251-
node.getDefaultValue()
252-
.ifPresent(
253-
defaultValue -> {
254-
Type defaultType = process(defaultValue, context);
255-
if (!defaultType.equals(resultType)) {
256-
throw new SemanticException(
257-
String.format(
258-
"Default result type must be the same as WHEN result types: %s vs %s",
259-
defaultType, resultType));
260-
}
261-
});
252+
node.getDefaultValue().ifPresent(expressions::add);
262253

263-
return setExpressionType(node, resultType);
254+
Type type =
255+
coerceToSingleType(
256+
context, expressions, "All result types and default result type must be the same");
257+
setExpressionType(node, type);
258+
259+
for (WhenClause whenClause : node.getWhenClauses()) {
260+
Type whenClauseType = process(whenClause.getResult(), context);
261+
setExpressionType(whenClause, whenClauseType);
262+
}
263+
264+
return type;
264265
}
265266

266267
@Override
@@ -485,6 +486,66 @@ protected Type visitNode(Node node, Context context) {
485486
throw new UnsupportedOperationException(
486487
"Not a valid IR expression: " + node.getClass().getName());
487488
}
489+
490+
// Only allow INT32 -> INT64 coercion to suppress some related bugs for now
491+
private void coerceType(
492+
Context context, Expression expression, Type expectedType, Function<Type, String> message) {
493+
Type actualType = process(expression, context);
494+
coerceType(expression, expectedType, actualType, message);
495+
}
496+
497+
private Type coerceToSingleType(
498+
Context context, List<Expression> expressions, String description) {
499+
LinkedHashMultimap<Type, NodeRef<Expression>> typeExpressions = LinkedHashMultimap.create();
500+
501+
for (Expression expression : expressions) {
502+
Type type = process(expression, context);
503+
typeExpressions.put(type, NodeRef.of(expression));
504+
}
505+
Set<Type> types = typeExpressions.keySet();
506+
Iterator<Type> iterator = types.iterator();
507+
Type superType = iterator.next();
508+
if (types.size() == 1) {
509+
return superType;
510+
}
511+
while (iterator.hasNext()) {
512+
Type current = iterator.next();
513+
if (TypeCoercionUtils.canCoerceTo(current, superType)) {
514+
continue;
515+
}
516+
if (TypeCoercionUtils.canCoerceTo(superType, current)) {
517+
superType = current;
518+
}
519+
throw new SemanticException(String.format(description + ": %s vs %s", superType, current));
520+
}
521+
for (Type type : types) {
522+
Set<NodeRef<Expression>> nodeRefs = typeExpressions.get(type);
523+
if (type.equals(superType)) {
524+
continue;
525+
}
526+
if (!TypeCoercionUtils.canCoerceTo(type, superType)) {
527+
throw new SemanticException("Cannot coerce type " + type + " to " + superType);
528+
}
529+
addOrReplaceExpressionType(nodeRefs, superType);
530+
}
531+
return superType;
532+
}
533+
534+
private void coerceType(
535+
Expression expression,
536+
Type actualType,
537+
Type expectedType,
538+
Function<Type, String> errorMsg) {
539+
if (!TypeCoercionUtils.canCoerceTo(actualType, expectedType)) {
540+
throw new SemanticException(errorMsg.apply(actualType));
541+
}
542+
setExpressionType(expression, actualType);
543+
}
544+
545+
private void addOrReplaceExpressionType(
546+
Collection<NodeRef<Expression>> expressions, Type superType) {
547+
expressions.forEach(expression -> expressionTypes.put(expression, superType));
548+
}
488549
}
489550

490551
private static class Context {
Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one
3+
* or more contributor license agreements. See the NOTICE file
4+
* distributed with this work for additional information
5+
* regarding copyright ownership. The ASF licenses this file
6+
* to you under the Apache License, Version 2.0 (the
7+
* "License"); you may not use this file except in compliance
8+
* with the License. You may obtain a copy of the License at
9+
*
10+
* http://www.apache.org/licenses/LICENSE-2.0
11+
*
12+
* Unless required by applicable law or agreed to in writing,
13+
* software distributed under the License is distributed on an
14+
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15+
* KIND, either express or implied. See the License for the
16+
* specific language governing permissions and limitations
17+
* under the License.
18+
*/
19+
20+
package org.apache.iotdb.db.queryengine.plan.relational.type;
21+
22+
import com.google.common.collect.ImmutableMap;
23+
import com.google.common.collect.ImmutableSet;
24+
import org.apache.tsfile.read.common.type.IntType;
25+
import org.apache.tsfile.read.common.type.LongType;
26+
import org.apache.tsfile.read.common.type.Type;
27+
28+
import java.util.Collections;
29+
import java.util.Map;
30+
import java.util.Set;
31+
32+
public class TypeCoercionUtils {
33+
34+
private static final Map<Type, Set<Type>> typeCoercionMap;
35+
36+
static {
37+
typeCoercionMap = ImmutableMap.of(IntType.INT32, ImmutableSet.of(LongType.INT64));
38+
}
39+
40+
public static boolean canCoerceTo(Type from, Type to) {
41+
if (from.equals(to)) {
42+
return true;
43+
}
44+
return typeCoercionMap.getOrDefault(from, Collections.emptySet()).contains(to);
45+
}
46+
}

0 commit comments

Comments
 (0)