1414
1515"""Test for CEL extensions."""
1616
17- from typing import Callable
18-
1917from google .protobuf import descriptor_pool
2018
2119from absl .testing import absltest
@@ -30,20 +28,19 @@ class CustomExtTest(parameterized.TestCase):
3028 # Execute the same test for both C++ and Python implementations, which
3129 # are expected to produce identical results.
3230 EXT_IMPLEMENTATIONS = [
33- ("py_ext" , sample_cel_ext .SampleCelExtension ),
34- ("cc_ext" , sample_cel_ext_cc .SampleCelExtension ),
31+ ("py_ext" , sample_cel_ext .SampleCelExtension () ),
32+ ("cc_ext" , sample_cel_ext_cc .SampleCelExtension () ),
3533 ]
3634
37- # TODO(b/462745713): reuse the same extension instance for all tests.
3835 def _compile_expr (
39- self , ext : Callable [[], cel .CelExtension ] , expression : str
36+ self , ext : cel .CelExtension , expression : str
4037 ) -> cel .Expression :
4138 """Creates a CEL expression for the given extension and compiles the expression."""
4239 self .descriptor_pool = descriptor_pool .Default ()
4340 self .env = cel .NewEnv (
4441 self .descriptor_pool ,
4542 variables = {},
46- extensions = [ext () ],
43+ extensions = [ext ],
4744 )
4845 return self .env .compile (expression )
4946
@@ -161,7 +158,7 @@ def test_error_propagation(self, ext):
161158 def test_bad_extension_type (self ):
162159 with self .assertRaises (Exception ) as e :
163160 self ._compile_expr (
164- lambda : "Not a CelExtension" , "'Hello, world!'.translate('en', 'es')"
161+ "Not a CelExtension" , "'Hello, world!'.translate('en', 'es')"
165162 )
166163 assert "Failed to cast str either as a Python CelExtension instance" in str (
167164 e .exception
@@ -170,9 +167,7 @@ def test_bad_extension_type(self):
170167
171168 def test_none_extension (self ):
172169 with self .assertRaises (Exception ) as e :
173- self ._compile_expr (
174- lambda : None , "'Hello, world!'.translate('en', 'es')"
175- )
170+ self ._compile_expr (None , "'Hello, world!'.translate('en', 'es')" )
176171 assert "Provided extension is None" in str (e .exception )
177172
178173
@@ -191,6 +186,7 @@ def _lost_in_translation_return_none(arg1: str) -> str: # pylint: disable=unuse
191186def _lost_in_translation_raising_error (text : str ) -> str : # pylint: disable=unused-argument
192187 raise LookupError ("Lost in translation" )
193188
189+
194190TEST_EXPRESSIONS = [
195191 ("getOrDefaultReceiver" , "{'a': 1, 'b': 2}.getOrDefault('c', 3) == 3" ),
196192 ("getOrDefault" , "getOrDefault({'a': 'z', 'b': 'y'}, 'a', 'x') == 'z'" ),
@@ -225,5 +221,46 @@ def test_lerp_error_out_of_bounds(self):
225221 self .assertIn ("t must be between 0.0 and 1.0" , res .value ())
226222
227223
224+ class OverloadExistingFunctionTest (absltest .TestCase ):
225+
226+ def test_overload_existing_function (self ):
227+ env = cel .NewEnv (
228+ variables = {"var_map" : cel .Type .Map (cel .Type .STRING , cel .Type .DYN )},
229+ extensions = [
230+ cel .CelExtension (
231+ "custom_map_functions" ,
232+ functions = [
233+ cel .FunctionDecl (
234+ "contains" ,
235+ [
236+ cel .Overload (
237+ "contains_key_value" ,
238+ return_type = cel .Type .BOOL ,
239+ parameters = [
240+ cel .Type .Map (cel .Type .STRING , cel .Type .DYN ),
241+ cel .Type .STRING ,
242+ cel .Type .DYN ,
243+ ],
244+ is_member = True ,
245+ impl = contains_key_value ,
246+ )
247+ ],
248+ )
249+ ],
250+ )
251+ ],
252+ )
253+ expr = env .compile ("var_map.contains('foo', 'bar')" )
254+
255+ res = expr .eval (data = {"var_map" : {"foo" : "bar" }})
256+ self .assertTrue (res .value ())
257+ res = expr .eval (data = {"var_map" : {"foo" : "baz" }})
258+ self .assertFalse (res .value ())
259+
260+
261+ def contains_key_value (cel_map , key , value ):
262+ return key in cel_map and cel_map [key ] == value
263+
264+
228265if __name__ == "__main__" :
229266 absltest .main ()
0 commit comments