Skip to content

Commit 902c044

Browse files
dmitriplotnikovcopybara-github
authored andcommitted
Allow adding overrides for existing functions
PiperOrigin-RevId: 863746992
1 parent 1e4a924 commit 902c044

2 files changed

Lines changed: 49 additions & 12 deletions

File tree

custom_ext/custom_ext_test.py

Lines changed: 48 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -14,8 +14,6 @@
1414

1515
"""Test for CEL extensions."""
1616

17-
from typing import Callable
18-
1917
from google.protobuf import descriptor_pool
2018

2119
from absl.testing import absltest
@@ -30,20 +28,19 @@ class CustomExtTest(parameterized.TestCase):
3028
# Execute the same test for both C++ and Python implementations, which
3129
# are expected to produce identical results.
3230
EXT_IMPLEMENTATIONS = [
33-
("py_ext", sample_cel_ext.SampleCelExtension),
34-
("cc_ext", sample_cel_ext_cc.SampleCelExtension),
31+
("py_ext", sample_cel_ext.SampleCelExtension()),
32+
("cc_ext", sample_cel_ext_cc.SampleCelExtension()),
3533
]
3634

37-
# TODO(b/462745713): reuse the same extension instance for all tests.
3835
def _compile_expr(
39-
self, ext: Callable[[], cel.CelExtension], expression: str
36+
self, ext: cel.CelExtension, expression: str
4037
) -> cel.Expression:
4138
"""Creates a CEL expression for the given extension and compiles the expression."""
4239
self.descriptor_pool = descriptor_pool.Default()
4340
self.env = cel.NewEnv(
4441
self.descriptor_pool,
4542
variables={},
46-
extensions=[ext()],
43+
extensions=[ext],
4744
)
4845
return self.env.compile(expression)
4946

@@ -161,7 +158,7 @@ def test_error_propagation(self, ext):
161158
def test_bad_extension_type(self):
162159
with self.assertRaises(Exception) as e:
163160
self._compile_expr(
164-
lambda: "Not a CelExtension", "'Hello, world!'.translate('en', 'es')"
161+
"Not a CelExtension", "'Hello, world!'.translate('en', 'es')"
165162
)
166163
assert "Failed to cast str either as a Python CelExtension instance" in str(
167164
e.exception
@@ -170,9 +167,7 @@ def test_bad_extension_type(self):
170167

171168
def test_none_extension(self):
172169
with self.assertRaises(Exception) as e:
173-
self._compile_expr(
174-
lambda: None, "'Hello, world!'.translate('en', 'es')"
175-
)
170+
self._compile_expr(None, "'Hello, world!'.translate('en', 'es')")
176171
assert "Provided extension is None" in str(e.exception)
177172

178173

@@ -191,6 +186,7 @@ def _lost_in_translation_return_none(arg1: str) -> str: # pylint: disable=unuse
191186
def _lost_in_translation_raising_error(text: str) -> str: # pylint: disable=unused-argument
192187
raise LookupError("Lost in translation")
193188

189+
194190
TEST_EXPRESSIONS = [
195191
("getOrDefaultReceiver", "{'a': 1, 'b': 2}.getOrDefault('c', 3) == 3"),
196192
("getOrDefault", "getOrDefault({'a': 'z', 'b': 'y'}, 'a', 'x') == 'z'"),
@@ -225,5 +221,46 @@ def test_lerp_error_out_of_bounds(self):
225221
self.assertIn("t must be between 0.0 and 1.0", res.value())
226222

227223

224+
class OverloadExistingFunctionTest(absltest.TestCase):
225+
226+
def test_overload_existing_function(self):
227+
env = cel.NewEnv(
228+
variables={"var_map": cel.Type.Map(cel.Type.STRING, cel.Type.DYN)},
229+
extensions=[
230+
cel.CelExtension(
231+
"custom_map_functions",
232+
functions=[
233+
cel.FunctionDecl(
234+
"contains",
235+
[
236+
cel.Overload(
237+
"contains_key_value",
238+
return_type=cel.Type.BOOL,
239+
parameters=[
240+
cel.Type.Map(cel.Type.STRING, cel.Type.DYN),
241+
cel.Type.STRING,
242+
cel.Type.DYN,
243+
],
244+
is_member=True,
245+
impl=contains_key_value,
246+
)
247+
],
248+
)
249+
],
250+
)
251+
],
252+
)
253+
expr = env.compile("var_map.contains('foo', 'bar')")
254+
255+
res = expr.eval(data={"var_map": {"foo": "bar"}})
256+
self.assertTrue(res.value())
257+
res = expr.eval(data={"var_map": {"foo": "baz"}})
258+
self.assertFalse(res.value())
259+
260+
261+
def contains_key_value(cel_map, key, value):
262+
return key in cel_map and cel_map[key] == value
263+
264+
228265
if __name__ == "__main__":
229266
absltest.main()

py_cel_python_extension.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ absl::Status PyCelPythonExtension::ConfigureCompiler(
8888
function_decl.AddOverload(std::move(overload_decl)));
8989
}
9090
PY_CEL_RETURN_IF_ERROR(
91-
compiler_builder.GetCheckerBuilder().AddFunction(function_decl));
91+
compiler_builder.GetCheckerBuilder().MergeFunction(function_decl));
9292
}
9393

9494
return absl::OkStatus();

0 commit comments

Comments
 (0)