|
61 | 61 | import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SimpleCaseExpression; |
62 | 62 | import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.StringLiteral; |
63 | 63 | import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.SymbolReference; |
| 64 | +import org.apache.iotdb.db.queryengine.plan.relational.sql.ast.WhenClause; |
| 65 | +import org.apache.iotdb.db.queryengine.plan.relational.type.TypeCoercionUtils; |
64 | 66 |
|
65 | 67 | import com.google.common.collect.ImmutableList; |
66 | 68 | import com.google.common.collect.ImmutableMap; |
| 69 | +import com.google.common.collect.LinkedHashMultimap; |
67 | 70 | import org.apache.tsfile.read.common.type.BlobType; |
68 | 71 | import org.apache.tsfile.read.common.type.DateType; |
69 | 72 | import org.apache.tsfile.read.common.type.RowType; |
|
72 | 75 | import org.apache.tsfile.read.common.type.Type; |
73 | 76 |
|
74 | 77 | import java.util.ArrayList; |
| 78 | +import java.util.Collection; |
| 79 | +import java.util.Iterator; |
75 | 80 | import java.util.LinkedHashMap; |
76 | 81 | import java.util.LinkedHashSet; |
77 | 82 | import java.util.List; |
78 | 83 | import java.util.Map; |
| 84 | +import java.util.Set; |
| 85 | +import java.util.function.Function; |
79 | 86 | import java.util.stream.Collectors; |
80 | 87 |
|
81 | 88 | import static com.google.common.base.Preconditions.checkArgument; |
@@ -230,37 +237,31 @@ protected Type visitIfExpression(IfExpression node, Context context) { |
230 | 237 |
|
231 | 238 | @Override |
232 | 239 | protected Type visitSearchedCaseExpression(SearchedCaseExpression node, Context context) { |
233 | | - LinkedHashSet<Type> resultTypes = |
234 | | - node.getWhenClauses().stream() |
235 | | - .map( |
236 | | - clause -> { |
237 | | - Type operandType = process(clause.getOperand(), context); |
238 | | - if (!operandType.equals(BOOLEAN)) { |
239 | | - throw new SemanticException( |
240 | | - String.format("When clause operand must be boolean: %s", operandType)); |
241 | | - } |
242 | | - return setExpressionType(clause, process(clause.getResult(), context)); |
243 | | - }) |
244 | | - .collect(Collectors.toCollection(LinkedHashSet::new)); |
| 240 | + for (WhenClause whenClause : node.getWhenClauses()) { |
| 241 | + coerceType( |
| 242 | + context, |
| 243 | + whenClause.getOperand(), |
| 244 | + BOOLEAN, |
| 245 | + (actualType) -> String.format("When clause operand must be boolean: %s", actualType)); |
| 246 | + } |
245 | 247 |
|
246 | | - if (resultTypes.size() != 1) { |
247 | | - throw new SemanticException( |
248 | | - String.format("All result types must be the same: %s", resultTypes)); |
| 248 | + List<Expression> expressions = new ArrayList<>(); |
| 249 | + for (WhenClause whenClause : node.getWhenClauses()) { |
| 250 | + expressions.add(whenClause.getResult()); |
249 | 251 | } |
250 | | - Type resultType = resultTypes.iterator().next(); |
251 | | - node.getDefaultValue() |
252 | | - .ifPresent( |
253 | | - defaultValue -> { |
254 | | - Type defaultType = process(defaultValue, context); |
255 | | - if (!defaultType.equals(resultType)) { |
256 | | - throw new SemanticException( |
257 | | - String.format( |
258 | | - "Default result type must be the same as WHEN result types: %s vs %s", |
259 | | - defaultType, resultType)); |
260 | | - } |
261 | | - }); |
| 252 | + node.getDefaultValue().ifPresent(expressions::add); |
262 | 253 |
|
263 | | - return setExpressionType(node, resultType); |
| 254 | + Type type = |
| 255 | + coerceToSingleType( |
| 256 | + context, expressions, "All result types and default result type must be the same"); |
| 257 | + setExpressionType(node, type); |
| 258 | + |
| 259 | + for (WhenClause whenClause : node.getWhenClauses()) { |
| 260 | + Type whenClauseType = process(whenClause.getResult(), context); |
| 261 | + setExpressionType(whenClause, whenClauseType); |
| 262 | + } |
| 263 | + |
| 264 | + return type; |
264 | 265 | } |
265 | 266 |
|
266 | 267 | @Override |
@@ -485,6 +486,66 @@ protected Type visitNode(Node node, Context context) { |
485 | 486 | throw new UnsupportedOperationException( |
486 | 487 | "Not a valid IR expression: " + node.getClass().getName()); |
487 | 488 | } |
| 489 | + |
| 490 | + // Only allow INT32 -> INT64 coercion to suppress some related bugs for now |
| 491 | + private void coerceType( |
| 492 | + Context context, Expression expression, Type expectedType, Function<Type, String> message) { |
| 493 | + Type actualType = process(expression, context); |
| 494 | + coerceType(expression, expectedType, actualType, message); |
| 495 | + } |
| 496 | + |
| 497 | + private Type coerceToSingleType( |
| 498 | + Context context, List<Expression> expressions, String description) { |
| 499 | + LinkedHashMultimap<Type, NodeRef<Expression>> typeExpressions = LinkedHashMultimap.create(); |
| 500 | + |
| 501 | + for (Expression expression : expressions) { |
| 502 | + Type type = process(expression, context); |
| 503 | + typeExpressions.put(type, NodeRef.of(expression)); |
| 504 | + } |
| 505 | + Set<Type> types = typeExpressions.keySet(); |
| 506 | + Iterator<Type> iterator = types.iterator(); |
| 507 | + Type superType = iterator.next(); |
| 508 | + if (types.size() == 1) { |
| 509 | + return superType; |
| 510 | + } |
| 511 | + while (iterator.hasNext()) { |
| 512 | + Type current = iterator.next(); |
| 513 | + if (TypeCoercionUtils.canCoerceTo(current, superType)) { |
| 514 | + continue; |
| 515 | + } |
| 516 | + if (TypeCoercionUtils.canCoerceTo(superType, current)) { |
| 517 | + superType = current; |
| 518 | + } |
| 519 | + throw new SemanticException(String.format(description + ": %s vs %s", superType, current)); |
| 520 | + } |
| 521 | + for (Type type : types) { |
| 522 | + Set<NodeRef<Expression>> nodeRefs = typeExpressions.get(type); |
| 523 | + if (type.equals(superType)) { |
| 524 | + continue; |
| 525 | + } |
| 526 | + if (!TypeCoercionUtils.canCoerceTo(type, superType)) { |
| 527 | + throw new SemanticException("Cannot coerce type " + type + " to " + superType); |
| 528 | + } |
| 529 | + addOrReplaceExpressionType(nodeRefs, superType); |
| 530 | + } |
| 531 | + return superType; |
| 532 | + } |
| 533 | + |
| 534 | + private void coerceType( |
| 535 | + Expression expression, |
| 536 | + Type actualType, |
| 537 | + Type expectedType, |
| 538 | + Function<Type, String> errorMsg) { |
| 539 | + if (!TypeCoercionUtils.canCoerceTo(actualType, expectedType)) { |
| 540 | + throw new SemanticException(errorMsg.apply(actualType)); |
| 541 | + } |
| 542 | + setExpressionType(expression, actualType); |
| 543 | + } |
| 544 | + |
| 545 | + private void addOrReplaceExpressionType( |
| 546 | + Collection<NodeRef<Expression>> expressions, Type superType) { |
| 547 | + expressions.forEach(expression -> expressionTypes.put(expression, superType)); |
| 548 | + } |
488 | 549 | } |
489 | 550 |
|
490 | 551 | private static class Context { |
|
0 commit comments