1818import static com .google .common .base .Preconditions .checkNotNull ;
1919
2020import com .google .common .collect .ImmutableList ;
21+ import com .google .common .collect .ImmutableMap ;
2122import com .google .common .collect .ImmutableSet ;
23+ import dev .cel .checker .CelCheckerBuilder ;
24+ import dev .cel .common .CelFunctionDecl ;
2225import dev .cel .common .CelIssue ;
26+ import dev .cel .common .CelOptions ;
27+ import dev .cel .common .CelOverloadDecl ;
2328import dev .cel .common .ast .CelExpr ;
29+ import dev .cel .common .types .MapType ;
30+ import dev .cel .common .types .TypeParamType ;
2431import dev .cel .compiler .CelCompilerLibrary ;
2532import dev .cel .parser .CelMacro ;
2633import dev .cel .parser .CelMacroExprFactory ;
2734import dev .cel .parser .CelParserBuilder ;
2835import dev .cel .parser .Operator ;
36+ import dev .cel .runtime .CelFunctionBinding ;
37+ import dev .cel .runtime .CelInternalRuntimeLibrary ;
38+ import dev .cel .runtime .CelRuntimeBuilder ;
39+ import dev .cel .runtime .RuntimeEquality ;
40+ import java .util .Map ;
2941import java .util .Optional ;
3042
3143/** Internal implementation of CEL comprehensions extensions. */
32- public final class CelComprehensionsExtensions implements CelCompilerLibrary {
44+ final class CelComprehensionsExtensions
45+ implements CelCompilerLibrary , CelInternalRuntimeLibrary , CelExtensionLibrary .FeatureSet {
3346
34- private static final String TRANSFORM_LIST = "transformList" ;
47+ private static final String MAP_INSERT_FUNCTION = "cel.@mapInsert" ;
48+ private static final String MAP_INSERT_OVERLOAD_MAP_MAP = "cel_@mapInsert_map_map" ;
49+ private static final String MAP_INSERT_OVERLOAD_KEY_VALUE = "cel_@mapInsert_map_key_value" ;
50+ private static final TypeParamType TYPE_PARAM_K = TypeParamType .create ("K" );
51+ private static final TypeParamType TYPE_PARAM_V = TypeParamType .create ("V" );
52+ private static final MapType MAP_KV_TYPE = MapType .create (TYPE_PARAM_K , TYPE_PARAM_V );
3553
54+ enum Function {
55+ MAP_INSERT (
56+ CelFunctionDecl .newFunctionDeclaration (
57+ MAP_INSERT_FUNCTION ,
58+ CelOverloadDecl .newGlobalOverload (
59+ MAP_INSERT_OVERLOAD_MAP_MAP ,
60+ "Returns a map that's the result of merging given two maps." ,
61+ MAP_KV_TYPE ,
62+ MAP_KV_TYPE ,
63+ MAP_KV_TYPE ),
64+ CelOverloadDecl .newGlobalOverload (
65+ MAP_INSERT_OVERLOAD_KEY_VALUE ,
66+ "Adds the given key-value pair to the map." ,
67+ MAP_KV_TYPE ,
68+ MAP_KV_TYPE ,
69+ TYPE_PARAM_K ,
70+ TYPE_PARAM_V )));
71+
72+ private final CelFunctionDecl functionDecl ;
73+
74+ String getFunction () {
75+ return functionDecl .name ();
76+ }
77+
78+ Function (CelFunctionDecl functionDecl ) {
79+ this .functionDecl = functionDecl ;
80+ }
81+ }
82+
83+ private static final CelExtensionLibrary <CelComprehensionsExtensions > LIBRARY =
84+ new CelExtensionLibrary <CelComprehensionsExtensions >() {
85+ private final CelComprehensionsExtensions version0 = new CelComprehensionsExtensions ();
86+
87+ @ Override
88+ public String name () {
89+ return "cel.lib.ext.comprev2" ;
90+ }
91+
92+ @ Override
93+ public ImmutableSet <CelComprehensionsExtensions > versions () {
94+ return ImmutableSet .of (version0 );
95+ }
96+ };
97+
98+ static CelExtensionLibrary <CelComprehensionsExtensions > library () {
99+ return LIBRARY ;
100+ }
101+
102+ private final ImmutableSet <Function > functions ;
103+
104+ CelComprehensionsExtensions () {
105+ this .functions = ImmutableSet .copyOf (Function .values ());
106+ }
107+
108+ @ Override
109+ public void setCheckerOptions (CelCheckerBuilder checkerBuilder ) {
110+ functions .forEach (function -> checkerBuilder .addFunctionDeclarations (function .functionDecl ));
111+ }
112+
113+ @ Override
114+ public void setRuntimeOptions (CelRuntimeBuilder runtimeBuilder ) {
115+ throw new UnsupportedOperationException ("Unsupported" );
116+ }
117+
118+ @ Override
119+ public void setRuntimeOptions (
120+ CelRuntimeBuilder runtimeBuilder , RuntimeEquality runtimeEquality , CelOptions celOptions ) {
121+ for (Function function : functions ) {
122+ for (CelOverloadDecl overload : function .functionDecl .overloads ()) {
123+ switch (overload .overloadId ()) {
124+ case MAP_INSERT_OVERLOAD_MAP_MAP :
125+ runtimeBuilder .addFunctionBindings (
126+ CelFunctionBinding .from (
127+ MAP_INSERT_OVERLOAD_MAP_MAP ,
128+ Map .class ,
129+ Map .class ,
130+ (map1 , map2 ) -> mapInsertMap (map1 , map2 , runtimeEquality )));
131+ break ;
132+ case MAP_INSERT_OVERLOAD_KEY_VALUE :
133+ runtimeBuilder .addFunctionBindings (
134+ CelFunctionBinding .from (
135+ MAP_INSERT_OVERLOAD_KEY_VALUE ,
136+ ImmutableList .of (Map .class , Object .class , Object .class ),
137+ args -> mapInsertKeyValue (args , runtimeEquality )));
138+ break ;
139+ default :
140+ // Nothing to add.
141+ }
142+ }
143+ }
144+ }
145+
146+ @ Override
147+ public int version () {
148+ return 0 ;
149+ }
150+
151+ @ Override
36152 public ImmutableSet <CelMacro > macros () {
37153 return ImmutableSet .of (
38154 CelMacro .newReceiverMacro (
@@ -44,16 +160,62 @@ public ImmutableSet<CelMacro> macros() {
44160 3 ,
45161 CelComprehensionsExtensions ::expandExistsOneMacro ),
46162 CelMacro .newReceiverMacro (
47- TRANSFORM_LIST , 3 , CelComprehensionsExtensions ::transformListMacro ),
163+ "transformList" , 3 , CelComprehensionsExtensions ::transformListMacro ),
164+ CelMacro .newReceiverMacro (
165+ "transformList" , 4 , CelComprehensionsExtensions ::transformListMacro ),
166+ CelMacro .newReceiverMacro (
167+ "transformMap" , 3 , CelComprehensionsExtensions ::transformMapMacro ),
168+ CelMacro .newReceiverMacro (
169+ "transformMap" , 4 , CelComprehensionsExtensions ::transformMapMacro ),
170+ CelMacro .newReceiverMacro (
171+ "transformMapEntry" , 3 , CelComprehensionsExtensions ::transformMapEntryMacro ),
48172 CelMacro .newReceiverMacro (
49- TRANSFORM_LIST , 4 , CelComprehensionsExtensions ::transformListMacro ));
173+ "transformMapEntry" , 4 , CelComprehensionsExtensions ::transformMapEntryMacro ));
50174 }
51175
52176 @ Override
53177 public void setParserOptions (CelParserBuilder parserBuilder ) {
54178 parserBuilder .addMacros (macros ());
55179 }
56180
181+ // TODO: Implement a more efficient map insertion based on mutability once mutable
182+ // maps are supported in Java stack.
183+ private static ImmutableMap <Object , Object > mapInsertMap (
184+ Map <?, ?> targetMap , Map <?, ?> mapToMerge , RuntimeEquality equality ) {
185+ ImmutableMap .Builder <Object , Object > resultBuilder =
186+ ImmutableMap .builderWithExpectedSize (targetMap .size () + mapToMerge .size ());
187+
188+ for (Map .Entry <?, ?> entry : mapToMerge .entrySet ()) {
189+ if (equality .findInMap (targetMap , entry .getKey ()).isPresent ()) {
190+ throw new IllegalArgumentException (
191+ String .format ("insert failed: key '%s' already exists" , entry .getKey ()));
192+ } else {
193+ resultBuilder .put (entry .getKey (), entry .getValue ());
194+ }
195+ }
196+ resultBuilder .putAll (targetMap );
197+
198+ return resultBuilder .build ();
199+ }
200+
201+ private static ImmutableMap <Object , Object > mapInsertKeyValue (
202+ Object [] args , RuntimeEquality equality ) {
203+ Map <?, ?> map = (Map <?, ?>) args [0 ];
204+ Object key = args [1 ];
205+ Object value = args [2 ];
206+
207+ if (equality .findInMap (map , key ).isPresent ()) {
208+ throw new IllegalArgumentException (
209+ String .format ("insert failed: key '%s' already exists" , key ));
210+ }
211+
212+ ImmutableMap .Builder <Object , Object > builder =
213+ ImmutableMap .builderWithExpectedSize (map .size () + 1 );
214+ builder .putAll (map );
215+ builder .put (key , value );
216+ return builder .build ();
217+ }
218+
57219 private static Optional <CelExpr > expandAllMacro (
58220 CelMacroExprFactory exprFactory , CelExpr target , ImmutableList <CelExpr > arguments ) {
59221 checkNotNull (exprFactory );
@@ -220,6 +382,103 @@ private static Optional<CelExpr> transformListMacro(
220382 exprFactory .newIdentifier (exprFactory .getAccumulatorVarName ())));
221383 }
222384
385+ private static Optional <CelExpr > transformMapMacro (
386+ CelMacroExprFactory exprFactory , CelExpr target , ImmutableList <CelExpr > arguments ) {
387+ checkNotNull (exprFactory );
388+ checkNotNull (target );
389+ checkArgument (arguments .size () == 3 || arguments .size () == 4 );
390+ CelExpr arg0 = validatedIterationVariable (exprFactory , arguments .get (0 ));
391+ if (arg0 .exprKind ().getKind () != CelExpr .ExprKind .Kind .IDENT ) {
392+ return Optional .of (arg0 );
393+ }
394+ CelExpr arg1 = validatedIterationVariable (exprFactory , arguments .get (1 ));
395+ if (arg1 .exprKind ().getKind () != CelExpr .ExprKind .Kind .IDENT ) {
396+ return Optional .of (arg1 );
397+ }
398+ CelExpr transform ;
399+ CelExpr filter = null ;
400+ if (arguments .size () == 4 ) {
401+ filter = checkNotNull (arguments .get (2 ));
402+ transform = checkNotNull (arguments .get (3 ));
403+ } else {
404+ transform = checkNotNull (arguments .get (2 ));
405+ }
406+ CelExpr accuInit = exprFactory .newMap ();
407+ CelExpr condition = exprFactory .newBoolLiteral (true );
408+ CelExpr step =
409+ exprFactory .newGlobalCall (
410+ MAP_INSERT_FUNCTION ,
411+ exprFactory .newIdentifier (exprFactory .getAccumulatorVarName ()),
412+ arg0 ,
413+ transform );
414+ if (filter != null ) {
415+ step =
416+ exprFactory .newGlobalCall (
417+ Operator .CONDITIONAL .getFunction (),
418+ filter ,
419+ step ,
420+ exprFactory .newIdentifier (exprFactory .getAccumulatorVarName ()));
421+ }
422+ return Optional .of (
423+ exprFactory .fold (
424+ arg0 .ident ().name (),
425+ arg1 .ident ().name (),
426+ target ,
427+ exprFactory .getAccumulatorVarName (),
428+ accuInit ,
429+ condition ,
430+ step ,
431+ exprFactory .newIdentifier (exprFactory .getAccumulatorVarName ())));
432+ }
433+
434+ private static Optional <CelExpr > transformMapEntryMacro (
435+ CelMacroExprFactory exprFactory , CelExpr target , ImmutableList <CelExpr > arguments ) {
436+ checkNotNull (exprFactory );
437+ checkNotNull (target );
438+ checkArgument (arguments .size () == 3 || arguments .size () == 4 );
439+ CelExpr arg0 = validatedIterationVariable (exprFactory , arguments .get (0 ));
440+ if (arg0 .exprKind ().getKind () != CelExpr .ExprKind .Kind .IDENT ) {
441+ return Optional .of (arg0 );
442+ }
443+ CelExpr arg1 = validatedIterationVariable (exprFactory , arguments .get (1 ));
444+ if (arg1 .exprKind ().getKind () != CelExpr .ExprKind .Kind .IDENT ) {
445+ return Optional .of (arg1 );
446+ }
447+ CelExpr transform ;
448+ CelExpr filter = null ;
449+ if (arguments .size () == 4 ) {
450+ filter = checkNotNull (arguments .get (2 ));
451+ transform = checkNotNull (arguments .get (3 ));
452+ } else {
453+ transform = checkNotNull (arguments .get (2 ));
454+ }
455+ CelExpr accuInit = exprFactory .newMap ();
456+ CelExpr condition = exprFactory .newBoolLiteral (true );
457+ CelExpr step =
458+ exprFactory .newGlobalCall (
459+ MAP_INSERT_FUNCTION ,
460+ exprFactory .newIdentifier (exprFactory .getAccumulatorVarName ()),
461+ transform );
462+ if (filter != null ) {
463+ step =
464+ exprFactory .newGlobalCall (
465+ Operator .CONDITIONAL .getFunction (),
466+ filter ,
467+ step ,
468+ exprFactory .newIdentifier (exprFactory .getAccumulatorVarName ()));
469+ }
470+ return Optional .of (
471+ exprFactory .fold (
472+ arg0 .ident ().name (),
473+ arg1 .ident ().name (),
474+ target ,
475+ exprFactory .getAccumulatorVarName (),
476+ accuInit ,
477+ condition ,
478+ step ,
479+ exprFactory .newIdentifier (exprFactory .getAccumulatorVarName ())));
480+ }
481+
223482 private static CelExpr validatedIterationVariable (
224483 CelMacroExprFactory exprFactory , CelExpr argument ) {
225484
0 commit comments