Skip to content

Commit 257fd24

Browse files
maskri17copybara-github
authored andcommitted
Adding transformMap and transformMapEntry macros
PiperOrigin-RevId: 799752266
1 parent e73283b commit 257fd24

7 files changed

Lines changed: 403 additions & 19 deletions

File tree

checker/src/test/resources/standardEnvDump.baseline

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Source: 'redundant expression so the env is constructed and can be printed'
44

55

66
Standard environment:
7+
declare cel.@mapInsert {
8+
function cel_@mapInsert_map_map (map(K, V), map(K, V)) -> map(K, V)
9+
function cel_@mapInsert_map_key_value (map(K, V), K, V) -> map(K, V)
10+
}
711
declare !_ {
812
function logical_not (bool) -> bool
913
}

extensions/src/main/java/dev/cel/extensions/BUILD.bazel

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,19 @@ java_library(
299299
name = "comprehensions",
300300
srcs = ["CelComprehensionsExtensions.java"],
301301
deps = [
302+
"//checker:checker_builder",
302303
"//common:compiler_common",
304+
"//common:options",
303305
"//common/ast",
306+
"//common/types",
304307
"//compiler:compiler_builder",
308+
"//extensions:extension_library",
305309
"//parser:macro",
306310
"//parser:operator",
307311
"//parser:parser_builder",
312+
"//runtime",
313+
"//runtime:function_binding",
314+
"//runtime:runtime_equality",
308315
"@maven//:com_google_guava_guava",
309316
],
310317
)

extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java

Lines changed: 263 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,21 +18,137 @@
1818
import static com.google.common.base.Preconditions.checkNotNull;
1919

2020
import com.google.common.collect.ImmutableList;
21+
import com.google.common.collect.ImmutableMap;
2122
import com.google.common.collect.ImmutableSet;
23+
import dev.cel.checker.CelCheckerBuilder;
24+
import dev.cel.common.CelFunctionDecl;
2225
import dev.cel.common.CelIssue;
26+
import dev.cel.common.CelOptions;
27+
import dev.cel.common.CelOverloadDecl;
2328
import dev.cel.common.ast.CelExpr;
29+
import dev.cel.common.types.MapType;
30+
import dev.cel.common.types.TypeParamType;
2431
import dev.cel.compiler.CelCompilerLibrary;
2532
import dev.cel.parser.CelMacro;
2633
import dev.cel.parser.CelMacroExprFactory;
2734
import dev.cel.parser.CelParserBuilder;
2835
import dev.cel.parser.Operator;
36+
import dev.cel.runtime.CelFunctionBinding;
37+
import dev.cel.runtime.CelInternalRuntimeLibrary;
38+
import dev.cel.runtime.CelRuntimeBuilder;
39+
import dev.cel.runtime.RuntimeEquality;
40+
import java.util.Map;
2941
import java.util.Optional;
3042

3143
/** Internal implementation of CEL comprehensions extensions. */
32-
public final class CelComprehensionsExtensions implements CelCompilerLibrary {
44+
final class CelComprehensionsExtensions
45+
implements CelCompilerLibrary, CelInternalRuntimeLibrary, CelExtensionLibrary.FeatureSet {
3346

34-
private static final String TRANSFORM_LIST = "transformList";
47+
private static final String MAP_INSERT_FUNCTION = "cel.@mapInsert";
48+
private static final String MAP_INSERT_OVERLOAD_MAP_MAP = "cel_@mapInsert_map_map";
49+
private static final String MAP_INSERT_OVERLOAD_KEY_VALUE = "cel_@mapInsert_map_key_value";
50+
private static final TypeParamType TYPE_PARAM_K = TypeParamType.create("K");
51+
private static final TypeParamType TYPE_PARAM_V = TypeParamType.create("V");
52+
private static final MapType MAP_KV_TYPE = MapType.create(TYPE_PARAM_K, TYPE_PARAM_V);
3553

54+
enum Function {
55+
MAP_INSERT(
56+
CelFunctionDecl.newFunctionDeclaration(
57+
MAP_INSERT_FUNCTION,
58+
CelOverloadDecl.newGlobalOverload(
59+
MAP_INSERT_OVERLOAD_MAP_MAP,
60+
"Returns a map that's the result of merging given two maps.",
61+
MAP_KV_TYPE,
62+
MAP_KV_TYPE,
63+
MAP_KV_TYPE),
64+
CelOverloadDecl.newGlobalOverload(
65+
MAP_INSERT_OVERLOAD_KEY_VALUE,
66+
"Adds the given key-value pair to the map.",
67+
MAP_KV_TYPE,
68+
MAP_KV_TYPE,
69+
TYPE_PARAM_K,
70+
TYPE_PARAM_V)));
71+
72+
private final CelFunctionDecl functionDecl;
73+
74+
String getFunction() {
75+
return functionDecl.name();
76+
}
77+
78+
Function(CelFunctionDecl functionDecl) {
79+
this.functionDecl = functionDecl;
80+
}
81+
}
82+
83+
private static final CelExtensionLibrary<CelComprehensionsExtensions> LIBRARY =
84+
new CelExtensionLibrary<CelComprehensionsExtensions>() {
85+
private final CelComprehensionsExtensions version0 = new CelComprehensionsExtensions();
86+
87+
@Override
88+
public String name() {
89+
return "cel.lib.ext.comprev2";
90+
}
91+
92+
@Override
93+
public ImmutableSet<CelComprehensionsExtensions> versions() {
94+
return ImmutableSet.of(version0);
95+
}
96+
};
97+
98+
static CelExtensionLibrary<CelComprehensionsExtensions> library() {
99+
return LIBRARY;
100+
}
101+
102+
private final ImmutableSet<Function> functions;
103+
104+
CelComprehensionsExtensions() {
105+
this.functions = ImmutableSet.copyOf(Function.values());
106+
}
107+
108+
@Override
109+
public void setCheckerOptions(CelCheckerBuilder checkerBuilder) {
110+
functions.forEach(function -> checkerBuilder.addFunctionDeclarations(function.functionDecl));
111+
}
112+
113+
@Override
114+
public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
115+
throw new UnsupportedOperationException("Unsupported");
116+
}
117+
118+
@Override
119+
public void setRuntimeOptions(
120+
CelRuntimeBuilder runtimeBuilder, RuntimeEquality runtimeEquality, CelOptions celOptions) {
121+
for (Function function : functions) {
122+
for (CelOverloadDecl overload : function.functionDecl.overloads()) {
123+
switch (overload.overloadId()) {
124+
case MAP_INSERT_OVERLOAD_MAP_MAP:
125+
runtimeBuilder.addFunctionBindings(
126+
CelFunctionBinding.from(
127+
MAP_INSERT_OVERLOAD_MAP_MAP,
128+
Map.class,
129+
Map.class,
130+
(map1, map2) -> mapInsertMap(map1, map2, runtimeEquality)));
131+
break;
132+
case MAP_INSERT_OVERLOAD_KEY_VALUE:
133+
runtimeBuilder.addFunctionBindings(
134+
CelFunctionBinding.from(
135+
MAP_INSERT_OVERLOAD_KEY_VALUE,
136+
ImmutableList.of(Map.class, Object.class, Object.class),
137+
args -> mapInsertKeyValue(args, runtimeEquality)));
138+
break;
139+
default:
140+
// Nothing to add.
141+
}
142+
}
143+
}
144+
}
145+
146+
@Override
147+
public int version() {
148+
return 0;
149+
}
150+
151+
@Override
36152
public ImmutableSet<CelMacro> macros() {
37153
return ImmutableSet.of(
38154
CelMacro.newReceiverMacro(
@@ -44,16 +160,62 @@ public ImmutableSet<CelMacro> macros() {
44160
3,
45161
CelComprehensionsExtensions::expandExistsOneMacro),
46162
CelMacro.newReceiverMacro(
47-
TRANSFORM_LIST, 3, CelComprehensionsExtensions::transformListMacro),
163+
"transformList", 3, CelComprehensionsExtensions::transformListMacro),
164+
CelMacro.newReceiverMacro(
165+
"transformList", 4, CelComprehensionsExtensions::transformListMacro),
166+
CelMacro.newReceiverMacro(
167+
"transformMap", 3, CelComprehensionsExtensions::transformMapMacro),
168+
CelMacro.newReceiverMacro(
169+
"transformMap", 4, CelComprehensionsExtensions::transformMapMacro),
170+
CelMacro.newReceiverMacro(
171+
"transformMapEntry", 3, CelComprehensionsExtensions::transformMapEntryMacro),
48172
CelMacro.newReceiverMacro(
49-
TRANSFORM_LIST, 4, CelComprehensionsExtensions::transformListMacro));
173+
"transformMapEntry", 4, CelComprehensionsExtensions::transformMapEntryMacro));
50174
}
51175

52176
@Override
53177
public void setParserOptions(CelParserBuilder parserBuilder) {
54178
parserBuilder.addMacros(macros());
55179
}
56180

181+
// TODO: Implement a more efficient map insertion based on mutability once mutable
182+
// maps are supported in Java stack.
183+
private static ImmutableMap<Object, Object> mapInsertMap(
184+
Map<?, ?> targetMap, Map<?, ?> mapToMerge, RuntimeEquality equality) {
185+
ImmutableMap.Builder<Object, Object> resultBuilder =
186+
ImmutableMap.builderWithExpectedSize(targetMap.size() + mapToMerge.size());
187+
188+
for (Map.Entry<?, ?> entry : mapToMerge.entrySet()) {
189+
if (equality.findInMap(targetMap, entry.getKey()).isPresent()) {
190+
throw new IllegalArgumentException(
191+
String.format("insert failed: key '%s' already exists", entry.getKey()));
192+
} else {
193+
resultBuilder.put(entry.getKey(), entry.getValue());
194+
}
195+
}
196+
resultBuilder.putAll(targetMap);
197+
198+
return resultBuilder.build();
199+
}
200+
201+
private static ImmutableMap<Object, Object> mapInsertKeyValue(
202+
Object[] args, RuntimeEquality equality) {
203+
Map<?, ?> map = (Map<?, ?>) args[0];
204+
Object key = args[1];
205+
Object value = args[2];
206+
207+
if (equality.findInMap(map, key).isPresent()) {
208+
throw new IllegalArgumentException(
209+
String.format("insert failed: key '%s' already exists", key));
210+
}
211+
212+
ImmutableMap.Builder<Object, Object> builder =
213+
ImmutableMap.builderWithExpectedSize(map.size() + 1);
214+
builder.putAll(map);
215+
builder.put(key, value);
216+
return builder.build();
217+
}
218+
57219
private static Optional<CelExpr> expandAllMacro(
58220
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
59221
checkNotNull(exprFactory);
@@ -220,6 +382,103 @@ private static Optional<CelExpr> transformListMacro(
220382
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
221383
}
222384

385+
private static Optional<CelExpr> transformMapMacro(
386+
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
387+
checkNotNull(exprFactory);
388+
checkNotNull(target);
389+
checkArgument(arguments.size() == 3 || arguments.size() == 4);
390+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
391+
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
392+
return Optional.of(arg0);
393+
}
394+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
395+
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
396+
return Optional.of(arg1);
397+
}
398+
CelExpr transform;
399+
CelExpr filter = null;
400+
if (arguments.size() == 4) {
401+
filter = checkNotNull(arguments.get(2));
402+
transform = checkNotNull(arguments.get(3));
403+
} else {
404+
transform = checkNotNull(arguments.get(2));
405+
}
406+
CelExpr accuInit = exprFactory.newMap();
407+
CelExpr condition = exprFactory.newBoolLiteral(true);
408+
CelExpr step =
409+
exprFactory.newGlobalCall(
410+
MAP_INSERT_FUNCTION,
411+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
412+
arg0,
413+
transform);
414+
if (filter != null) {
415+
step =
416+
exprFactory.newGlobalCall(
417+
Operator.CONDITIONAL.getFunction(),
418+
filter,
419+
step,
420+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
421+
}
422+
return Optional.of(
423+
exprFactory.fold(
424+
arg0.ident().name(),
425+
arg1.ident().name(),
426+
target,
427+
exprFactory.getAccumulatorVarName(),
428+
accuInit,
429+
condition,
430+
step,
431+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
432+
}
433+
434+
private static Optional<CelExpr> transformMapEntryMacro(
435+
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
436+
checkNotNull(exprFactory);
437+
checkNotNull(target);
438+
checkArgument(arguments.size() == 3 || arguments.size() == 4);
439+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
440+
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
441+
return Optional.of(arg0);
442+
}
443+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
444+
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
445+
return Optional.of(arg1);
446+
}
447+
CelExpr transform;
448+
CelExpr filter = null;
449+
if (arguments.size() == 4) {
450+
filter = checkNotNull(arguments.get(2));
451+
transform = checkNotNull(arguments.get(3));
452+
} else {
453+
transform = checkNotNull(arguments.get(2));
454+
}
455+
CelExpr accuInit = exprFactory.newMap();
456+
CelExpr condition = exprFactory.newBoolLiteral(true);
457+
CelExpr step =
458+
exprFactory.newGlobalCall(
459+
MAP_INSERT_FUNCTION,
460+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
461+
transform);
462+
if (filter != null) {
463+
step =
464+
exprFactory.newGlobalCall(
465+
Operator.CONDITIONAL.getFunction(),
466+
filter,
467+
step,
468+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
469+
}
470+
return Optional.of(
471+
exprFactory.fold(
472+
arg0.ident().name(),
473+
arg1.ident().name(),
474+
target,
475+
exprFactory.getAccumulatorVarName(),
476+
accuInit,
477+
condition,
478+
step,
479+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
480+
}
481+
223482
private static CelExpr validatedIterationVariable(
224483
CelMacroExprFactory exprFactory, CelExpr argument) {
225484

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,8 @@ public static CelExtensionLibrary<? extends CelExtensionLibrary.FeatureSet> getE
363363
return CelSetsExtensions.library(options);
364364
case "strings":
365365
return CelStringExtensions.library();
366+
case "comprehensions":
367+
return CelComprehensionsExtensions.library();
366368
// TODO: add support for remaining standard extensions
367369
default:
368370
throw new IllegalArgumentException("Unknown standard extension '" + name + "'");

0 commit comments

Comments
 (0)