|
16 | 16 | import static com.google.common.base.Preconditions.checkNotNull; |
17 | 17 | import static com.google.common.collect.ImmutableList.toImmutableList; |
18 | 18 | import static com.google.common.collect.MoreCollectors.onlyElement; |
| 19 | +import static dev.cel.checker.CelStandardDeclarations.StandardFunction.DURATION; |
| 20 | +import static dev.cel.checker.CelStandardDeclarations.StandardFunction.TIMESTAMP; |
19 | 21 |
|
20 | 22 | import com.google.auto.value.AutoValue; |
21 | 23 | import com.google.common.collect.ImmutableList; |
|
45 | 47 | import dev.cel.optimizer.CelOptimizationException; |
46 | 48 | import dev.cel.parser.Operator; |
47 | 49 | import dev.cel.runtime.CelEvaluationException; |
| 50 | +import java.math.BigDecimal; |
| 51 | +import java.time.Duration; |
| 52 | +import java.time.Instant; |
48 | 53 | import java.util.ArrayList; |
49 | 54 | import java.util.Arrays; |
50 | 55 | import java.util.Collection; |
@@ -142,6 +147,14 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) { |
142 | 147 | return false; |
143 | 148 | } |
144 | 149 |
|
| 150 | + // Timestamps/durations in CEL are calls, but they are effectively treated as literals. |
| 151 | + // Expressions like timestamp(123) cannot be folded directly, but arithmetics involving timestamps |
| 152 | + // can be folded. |
| 153 | + // Ex: timestamp(123) - timestamp(100) = duration("23s") |
| 154 | + if (isExprTimestampOrDuration(navigableExpr)) { |
| 155 | + return false; |
| 156 | + } |
| 157 | + |
145 | 158 | CelMutableCall mutableCall = navigableExpr.expr().call(); |
146 | 159 | String functionName = mutableCall.function(); |
147 | 160 |
|
@@ -197,14 +210,19 @@ private boolean canFold(CelNavigableMutableExpr navigableExpr) { |
197 | 210 | private boolean containsFoldableFunctionOnly(CelNavigableMutableExpr navigableExpr) { |
198 | 211 | return navigableExpr |
199 | 212 | .allNodes() |
| 213 | + .filter(node -> node.getKind().equals(Kind.CALL)) |
| 214 | + .map(node -> node.expr().call()) |
200 | 215 | .allMatch( |
201 | | - node -> { |
202 | | - if (node.getKind().equals(Kind.CALL)) { |
203 | | - return foldableFunctions.contains(node.expr().call().function()); |
204 | | - } |
| 216 | + call -> foldableFunctions.contains(call.function())); |
| 217 | + } |
| 218 | + |
| 219 | + private static boolean isExprTimestampOrDuration(CelNavigableMutableExpr navigableExpr) { |
| 220 | + if (!navigableExpr.getKind().equals(Kind.CALL)) { |
| 221 | + return true; |
| 222 | + } |
205 | 223 |
|
206 | | - return true; |
207 | | - }); |
| 224 | + CelMutableCall call = navigableExpr.expr().call(); |
| 225 | + return call.function().equals(TIMESTAMP.functionName()) || call.function().equals(DURATION.functionName()); |
208 | 226 | } |
209 | 227 |
|
210 | 228 | private static boolean canFoldInOperator(CelNavigableMutableExpr navigableExpr) { |
@@ -318,12 +336,32 @@ private Optional<CelMutableExpr> maybeAdaptEvaluatedResult(Object result) { |
318 | 336 | } |
319 | 337 |
|
320 | 338 | return Optional.of(CelMutableExpr.ofMap(CelMutableMap.create(mapEntries))); |
| 339 | + } else if (result instanceof Duration) { |
| 340 | + String durationStrArg = formatDurationArgument((Duration) result); |
| 341 | + CelMutableCall durationCall = CelMutableCall.create(DURATION.functionName(), CelMutableExpr.ofConstant(CelConstant.ofValue(durationStrArg))); |
| 342 | + return Optional.of(CelMutableExpr.ofCall(durationCall)); |
| 343 | + } else if (result instanceof Instant) { |
| 344 | + String timestampStrArg = ((Instant) result).toString(); |
| 345 | + CelMutableCall timestampCall = CelMutableCall.create(TIMESTAMP.functionName(), CelMutableExpr.ofConstant(CelConstant.ofValue(timestampStrArg))); |
| 346 | + return Optional.of(CelMutableExpr.ofCall(timestampCall)); |
321 | 347 | } |
322 | 348 |
|
323 | 349 | // Evaluated result cannot be folded (e.g: unknowns) |
324 | 350 | return Optional.empty(); |
325 | 351 | } |
326 | 352 |
|
| 353 | + private static String formatDurationArgument(Duration duration) { |
| 354 | + if (duration.isZero()) { |
| 355 | + return "0"; |
| 356 | + } |
| 357 | + |
| 358 | + BigDecimal seconds = BigDecimal.valueOf(duration.getSeconds()); |
| 359 | + BigDecimal nanos = BigDecimal.valueOf(duration.getNano(), 9); |
| 360 | + BigDecimal totalSeconds = seconds.add(nanos); |
| 361 | + |
| 362 | + return totalSeconds.stripTrailingZeros().toPlainString() + "s"; |
| 363 | + } |
| 364 | + |
327 | 365 | private Optional<CelMutableAst> maybeRewriteOptional( |
328 | 366 | Optional<?> optResult, CelMutableAst mutableAst, CelMutableExpr expr) { |
329 | 367 | if (!optResult.isPresent()) { |
@@ -352,6 +390,8 @@ private Optional<CelMutableAst> maybeRewriteOptional( |
352 | 390 | return Optional.empty(); |
353 | 391 | } |
354 | 392 |
|
| 393 | + |
| 394 | + |
355 | 395 | /** Inspects the non-strict calls to determine whether a branch can be removed. */ |
356 | 396 | private Optional<CelMutableAst> maybePruneBranches( |
357 | 397 | CelMutableAst mutableAst, CelMutableExpr expr) { |
|
0 commit comments