Skip to content

Commit 37de13d

Browse files
maskri17copybara-github
authored andcommitted
Adding transformMap and transformMapEntry macros
PiperOrigin-RevId: 799752266
1 parent e73283b commit 37de13d

7 files changed

Lines changed: 372 additions & 19 deletions

File tree

checker/src/test/resources/standardEnvDump.baseline

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,10 @@ Source: 'redundant expression so the env is constructed and can be printed'
44

55

66
Standard environment:
7+
declare cel.@mapInsert {
8+
function cel_@mapInsert_map_map (map(K, V), map(K, V)) -> map(K, V)
9+
function cel_@mapInsert_map_key_value (map(K, V), K, V) -> map(K, V)
10+
}
711
declare !_ {
812
function logical_not (bool) -> bool
913
}

extensions/src/main/java/dev/cel/extensions/BUILD.bazel

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -299,12 +299,17 @@ java_library(
299299
name = "comprehensions",
300300
srcs = ["CelComprehensionsExtensions.java"],
301301
deps = [
302+
"//checker:checker_builder",
302303
"//common:compiler_common",
303304
"//common/ast",
305+
"//common/types",
304306
"//compiler:compiler_builder",
307+
"//extensions:extension_library",
305308
"//parser:macro",
306309
"//parser:operator",
307310
"//parser:parser_builder",
311+
"//runtime",
312+
"//runtime:function_binding",
308313
"@maven//:com_google_guava_guava",
309314
],
310315
)

extensions/src/main/java/dev/cel/extensions/CelComprehensionsExtensions.java

Lines changed: 239 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,20 +19,118 @@
1919

2020
import com.google.common.collect.ImmutableList;
2121
import com.google.common.collect.ImmutableSet;
22+
import com.google.common.collect.Maps;
23+
import dev.cel.checker.CelCheckerBuilder;
24+
import dev.cel.common.CelFunctionDecl;
2225
import dev.cel.common.CelIssue;
26+
import dev.cel.common.CelOverloadDecl;
2327
import dev.cel.common.ast.CelExpr;
28+
import dev.cel.common.types.MapType;
29+
import dev.cel.common.types.TypeParamType;
2430
import dev.cel.compiler.CelCompilerLibrary;
2531
import dev.cel.parser.CelMacro;
2632
import dev.cel.parser.CelMacroExprFactory;
2733
import dev.cel.parser.CelParserBuilder;
2834
import dev.cel.parser.Operator;
35+
import dev.cel.runtime.CelFunctionBinding;
36+
import dev.cel.runtime.CelRuntimeBuilder;
37+
import dev.cel.runtime.CelRuntimeLibrary;
38+
import java.util.Map;
2939
import java.util.Optional;
3040

3141
/** Internal implementation of CEL comprehensions extensions. */
32-
public final class CelComprehensionsExtensions implements CelCompilerLibrary {
42+
public final class CelComprehensionsExtensions
43+
implements CelCompilerLibrary, CelRuntimeLibrary, CelExtensionLibrary.FeatureSet {
3344

34-
private static final String TRANSFORM_LIST = "transformList";
45+
private static final String MAP_INSERT_FUNCTION = "cel.@mapInsert";
46+
private static final String MAP_INSERT_OVERLOAD_MAP_MAP = "cel_@mapInsert_map_map";
47+
private static final String MAP_INSERT_OVERLOAD_KEY_VALUE = "cel_@mapInsert_map_key_value";
48+
private static final TypeParamType TYPE_PARAM_K = TypeParamType.create("K");
49+
private static final TypeParamType TYPE_PARAM_V = TypeParamType.create("V");
50+
private static final MapType MAP_KV_TYPE = MapType.create(TYPE_PARAM_K, TYPE_PARAM_V);
3551

52+
enum Function {
53+
MAP_INSERT(
54+
CelFunctionDecl.newFunctionDeclaration(
55+
MAP_INSERT_FUNCTION,
56+
CelOverloadDecl.newGlobalOverload(
57+
MAP_INSERT_OVERLOAD_MAP_MAP,
58+
"Returns a map that's the result of merging given two maps.",
59+
MAP_KV_TYPE,
60+
MAP_KV_TYPE,
61+
MAP_KV_TYPE),
62+
CelOverloadDecl.newGlobalOverload(
63+
MAP_INSERT_OVERLOAD_KEY_VALUE,
64+
"Adds the given key-value pair to the map.",
65+
MAP_KV_TYPE,
66+
MAP_KV_TYPE,
67+
TYPE_PARAM_K,
68+
TYPE_PARAM_V)),
69+
ImmutableSet.of(
70+
CelFunctionBinding.from(
71+
MAP_INSERT_OVERLOAD_MAP_MAP,
72+
Map.class,
73+
Map.class,
74+
CelComprehensionsExtensions::mapInsertMap),
75+
CelFunctionBinding.from(
76+
MAP_INSERT_OVERLOAD_KEY_VALUE,
77+
ImmutableList.of(Map.class, Object.class, Object.class),
78+
CelComprehensionsExtensions::mapInsertKeyValue)));
79+
80+
private final CelFunctionDecl functionDecl;
81+
private final ImmutableSet<CelFunctionBinding> functionBindings;
82+
83+
String getFunction() {
84+
return functionDecl.name();
85+
}
86+
87+
Function(CelFunctionDecl functionDecl, ImmutableSet<CelFunctionBinding> functionBindings) {
88+
this.functionDecl = functionDecl;
89+
this.functionBindings = functionBindings;
90+
}
91+
}
92+
93+
private static final CelExtensionLibrary<CelComprehensionsExtensions> LIBRARY =
94+
new CelExtensionLibrary<CelComprehensionsExtensions>() {
95+
private final CelComprehensionsExtensions version0 = new CelComprehensionsExtensions();
96+
97+
@Override
98+
public String name() {
99+
return "cel.lib.ext.comprev2";
100+
}
101+
102+
@Override
103+
public ImmutableSet<CelComprehensionsExtensions> versions() {
104+
return ImmutableSet.of(version0);
105+
}
106+
};
107+
108+
static CelExtensionLibrary<CelComprehensionsExtensions> library() {
109+
return LIBRARY;
110+
}
111+
112+
private final ImmutableSet<Function> functions;
113+
114+
CelComprehensionsExtensions() {
115+
this.functions = ImmutableSet.copyOf(Function.values());
116+
}
117+
118+
@Override
119+
public void setCheckerOptions(CelCheckerBuilder checkerBuilder) {
120+
functions.forEach(function -> checkerBuilder.addFunctionDeclarations(function.functionDecl));
121+
}
122+
123+
@Override
124+
public void setRuntimeOptions(CelRuntimeBuilder runtimeBuilder) {
125+
functions.forEach(function -> runtimeBuilder.addFunctionBindings(function.functionBindings));
126+
}
127+
128+
@Override
129+
public int version() {
130+
return 0;
131+
}
132+
133+
@Override
36134
public ImmutableSet<CelMacro> macros() {
37135
return ImmutableSet.of(
38136
CelMacro.newReceiverMacro(
@@ -44,16 +142,56 @@ public ImmutableSet<CelMacro> macros() {
44142
3,
45143
CelComprehensionsExtensions::expandExistsOneMacro),
46144
CelMacro.newReceiverMacro(
47-
TRANSFORM_LIST, 3, CelComprehensionsExtensions::transformListMacro),
145+
"transformList", 3, CelComprehensionsExtensions::transformListMacro),
146+
CelMacro.newReceiverMacro(
147+
"transformList", 4, CelComprehensionsExtensions::transformListMacro),
148+
CelMacro.newReceiverMacro(
149+
"transformMap", 3, CelComprehensionsExtensions::transformMapMacro),
150+
CelMacro.newReceiverMacro(
151+
"transformMap", 4, CelComprehensionsExtensions::transformMapMacro),
48152
CelMacro.newReceiverMacro(
49-
TRANSFORM_LIST, 4, CelComprehensionsExtensions::transformListMacro));
153+
"transformMapEntry", 3, CelComprehensionsExtensions::transformMapEntryMacro),
154+
CelMacro.newReceiverMacro(
155+
"transformMapEntry", 4, CelComprehensionsExtensions::transformMapEntryMacro));
50156
}
51157

52158
@Override
53159
public void setParserOptions(CelParserBuilder parserBuilder) {
54160
parserBuilder.addMacros(macros());
55161
}
56162

163+
// TODO: Implement a more efficient map insertion based on mutability once mutable
164+
// maps are supported in Java stack.
165+
private static Map<Object, Object> mapInsertMap(Map<?, ?> targetMap, Map<?, ?> mapToMerge) {
166+
Map<Object, Object> result = Maps.newHashMapWithExpectedSize(targetMap.size());
167+
result.putAll(targetMap);
168+
if (mapToMerge.isEmpty()) {
169+
return result;
170+
}
171+
172+
for (Map.Entry<?, ?> entry : mapToMerge.entrySet()) {
173+
result = mapInsertKeyValue(new Object[] {result, entry.getKey(), entry.getValue()});
174+
}
175+
176+
return result;
177+
}
178+
179+
private static Map<Object, Object> mapInsertKeyValue(Object[] args) {
180+
Map<?, ?> map = (Map<?, ?>) args[0];
181+
Object key = args[1];
182+
Object value = args[2];
183+
184+
if (map.containsKey(key)) {
185+
throw new IllegalArgumentException(
186+
String.format("insert failed: key '%s' already exists", key));
187+
}
188+
189+
Map<Object, Object> result = Maps.newHashMapWithExpectedSize(map.size() + 1);
190+
result.putAll(map);
191+
result.put(key, value);
192+
return result;
193+
}
194+
57195
private static Optional<CelExpr> expandAllMacro(
58196
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
59197
checkNotNull(exprFactory);
@@ -220,6 +358,103 @@ private static Optional<CelExpr> transformListMacro(
220358
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
221359
}
222360

361+
private static Optional<CelExpr> transformMapMacro(
362+
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
363+
checkNotNull(exprFactory);
364+
checkNotNull(target);
365+
checkArgument(arguments.size() == 3 || arguments.size() == 4);
366+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
367+
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
368+
return Optional.of(arg0);
369+
}
370+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
371+
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
372+
return Optional.of(arg1);
373+
}
374+
CelExpr transform;
375+
CelExpr filter = null;
376+
if (arguments.size() == 4) {
377+
filter = checkNotNull(arguments.get(2));
378+
transform = checkNotNull(arguments.get(3));
379+
} else {
380+
transform = checkNotNull(arguments.get(2));
381+
}
382+
CelExpr accuInit = exprFactory.newMap();
383+
CelExpr condition = exprFactory.newBoolLiteral(true);
384+
CelExpr step =
385+
exprFactory.newGlobalCall(
386+
MAP_INSERT_FUNCTION,
387+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
388+
arg0,
389+
transform);
390+
if (filter != null) {
391+
step =
392+
exprFactory.newGlobalCall(
393+
Operator.CONDITIONAL.getFunction(),
394+
filter,
395+
step,
396+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
397+
}
398+
return Optional.of(
399+
exprFactory.fold(
400+
arg0.ident().name(),
401+
arg1.ident().name(),
402+
target,
403+
exprFactory.getAccumulatorVarName(),
404+
accuInit,
405+
condition,
406+
step,
407+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
408+
}
409+
410+
private static Optional<CelExpr> transformMapEntryMacro(
411+
CelMacroExprFactory exprFactory, CelExpr target, ImmutableList<CelExpr> arguments) {
412+
checkNotNull(exprFactory);
413+
checkNotNull(target);
414+
checkArgument(arguments.size() == 3 || arguments.size() == 4);
415+
CelExpr arg0 = validatedIterationVariable(exprFactory, arguments.get(0));
416+
if (arg0.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
417+
return Optional.of(arg0);
418+
}
419+
CelExpr arg1 = validatedIterationVariable(exprFactory, arguments.get(1));
420+
if (arg1.exprKind().getKind() != CelExpr.ExprKind.Kind.IDENT) {
421+
return Optional.of(arg1);
422+
}
423+
CelExpr transform;
424+
CelExpr filter = null;
425+
if (arguments.size() == 4) {
426+
filter = checkNotNull(arguments.get(2));
427+
transform = checkNotNull(arguments.get(3));
428+
} else {
429+
transform = checkNotNull(arguments.get(2));
430+
}
431+
CelExpr accuInit = exprFactory.newMap();
432+
CelExpr condition = exprFactory.newBoolLiteral(true);
433+
CelExpr step =
434+
exprFactory.newGlobalCall(
435+
MAP_INSERT_FUNCTION,
436+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()),
437+
transform);
438+
if (filter != null) {
439+
step =
440+
exprFactory.newGlobalCall(
441+
Operator.CONDITIONAL.getFunction(),
442+
filter,
443+
step,
444+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName()));
445+
}
446+
return Optional.of(
447+
exprFactory.fold(
448+
arg0.ident().name(),
449+
arg1.ident().name(),
450+
target,
451+
exprFactory.getAccumulatorVarName(),
452+
accuInit,
453+
condition,
454+
step,
455+
exprFactory.newIdentifier(exprFactory.getAccumulatorVarName())));
456+
}
457+
223458
private static CelExpr validatedIterationVariable(
224459
CelMacroExprFactory exprFactory, CelExpr argument) {
225460

extensions/src/main/java/dev/cel/extensions/CelExtensions.java

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -363,6 +363,8 @@ public static CelExtensionLibrary<? extends CelExtensionLibrary.FeatureSet> getE
363363
return CelSetsExtensions.library(options);
364364
case "strings":
365365
return CelStringExtensions.library();
366+
case "comprehensions":
367+
return CelComprehensionsExtensions.library();
366368
// TODO: add support for remaining standard extensions
367369
default:
368370
throw new IllegalArgumentException("Unknown standard extension '" + name + "'");

0 commit comments

Comments
 (0)