Skip to content

Commit 1c40730

Browse files
jnthntatumcopybara-github
authored andcommitted
Fix type mapping for return types of python-backed extension functions.
- support setting an expected return type - handle mapping dyn as an argument kind PiperOrigin-RevId: 858753278
1 parent 3419a99 commit 1c40730

8 files changed

Lines changed: 140 additions & 17 deletions

File tree

BUILD

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,6 @@ pybind_extension(
4949
":status_macros",
5050
"@com_google_absl//absl/base",
5151
"@com_google_absl//absl/base:no_destructor",
52-
"@com_google_absl//absl/base:nullability",
5352
"@com_google_absl//absl/container:flat_hash_map",
5453
"@com_google_absl//absl/functional:function_ref",
5554
"@com_google_absl//absl/log:absl_check",

custom_ext/custom_ext_test.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,7 @@ def _create_activation(self, impl) -> cel.Activation:
5757
[cel.Type.STRING],
5858
False,
5959
impl,
60+
return_type=cel.Type.STRING,
6061
)
6162
],
6263
)
@@ -93,6 +94,7 @@ def test_error_no_matching_overload(self, ext):
9394
[cel.Type.STRING, cel.Type.INT],
9495
False,
9596
lambda _: "¡Hola Mundo!",
97+
return_type=cel.Type.STRING,
9698
)
9799
],
98100
)
@@ -189,6 +191,42 @@ def _lost_in_translation_return_none(arg1: str) -> str: # pylint: disable=unuse
189191
def _lost_in_translation_raising_error(text: str) -> str: # pylint: disable=unused-argument
190192
raise LookupError("Lost in translation")
191193

194+
TEST_EXPRESSIONS = [
195+
("getOrDefaultReceiver", "{'a': 1, 'b': 2}.getOrDefault('c', 3) == 3"),
196+
("getOrDefault", "getOrDefault({'a': 'z', 'b': 'y'}, 'a', 'x') == 'z'"),
197+
("lerp_int", "lerp(1, 2, 0.5) == 1.5"),
198+
("lerp_uint", "lerp(1u, 2u, 0.5) == 1.5"),
199+
]
200+
201+
202+
class PythonTypeMappingsTest(parameterized.TestCase):
203+
204+
def setUp(self):
205+
super().setUp()
206+
self.descriptor_pool = descriptor_pool.Default()
207+
self.env = cel.NewEnv(
208+
self.descriptor_pool,
209+
variables={},
210+
extensions=[sample_cel_ext.SampleCelExtension()],
211+
)
212+
213+
def _compile_expr(self, expression: str) -> cel.Expression:
214+
return self.env.compile(expression)
215+
216+
@parameterized.named_parameters(TEST_EXPRESSIONS)
217+
def test_expression(self, expr):
218+
compiled_expr = self.env.compile(expr)
219+
act = self.env.Activation()
220+
res = compiled_expr.eval(act)
221+
self.assertEqual(res.value(), True)
222+
223+
def test_lerp_error_out_of_bounds(self):
224+
compiled_expr = self.env.compile("lerp(1, 2, 1.5)")
225+
act = self.env.Activation()
226+
res = compiled_expr.eval(act)
227+
self.assertEqual(res.type(), cel.Type.ERROR)
228+
self.assertIn("t must be between 0.0 and 1.0", res.value())
229+
192230

193231
if __name__ == "__main__":
194232
absltest.main()

custom_ext/sample_cel_ext.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,9 @@
1313
# limitations under the License.
1414

1515
"""A sample CEL extension implemented entirely in Python."""
16+
17+
from typing import Any
18+
1619
import py_cel as cel
1720

1821

@@ -51,6 +54,61 @@ def __init__(self):
5154
)
5255
],
5356
),
57+
# Tests for type adaptation.
58+
cel.FunctionDecl(
59+
"lerp",
60+
[
61+
cel.Overload(
62+
"lerp_int_int_double",
63+
return_type=cel.Type.DOUBLE,
64+
parameters=[
65+
cel.Type.INT,
66+
cel.Type.INT,
67+
cel.Type.DOUBLE,
68+
],
69+
is_member=False,
70+
impl=self.lerp,
71+
),
72+
cel.Overload(
73+
"lerp_uint_uint_double",
74+
return_type=cel.Type.DOUBLE,
75+
parameters=[
76+
cel.Type.UINT,
77+
cel.Type.UINT,
78+
cel.Type.DOUBLE,
79+
],
80+
is_member=False,
81+
impl=self.lerp,
82+
),
83+
],
84+
),
85+
cel.FunctionDecl(
86+
"getOrDefault",
87+
[
88+
cel.Overload(
89+
"map_get_or_default_string_dyn",
90+
return_type=cel.Type.DYN,
91+
parameters=[
92+
cel.Type.Map(cel.Type.STRING, cel.Type.DYN),
93+
cel.Type.STRING,
94+
cel.Type.DYN,
95+
],
96+
is_member=True,
97+
impl=self.map_get_or_default,
98+
),
99+
cel.Overload(
100+
"get_or_default_map_string_dyn",
101+
return_type=cel.Type.DYN,
102+
parameters=[
103+
cel.Type.Map(cel.Type.STRING, cel.Type.DYN),
104+
cel.Type.STRING,
105+
cel.Type.DYN,
106+
],
107+
is_member=False,
108+
impl=self.map_get_or_default,
109+
),
110+
],
111+
),
54112
],
55113
)
56114

@@ -60,3 +118,15 @@ def translate(self, text: str, from_lang: str, to_lang: str) -> str:
60118
if text != "Hello, world!":
61119
raise ValueError("Come on, this is just 'Hello, world!'")
62120
return "¡Hola Mundo!"
121+
122+
def lerp(self, a: int, b: int, t: float) -> float:
123+
"""Linearly interpolate between a and b using t."""
124+
if t < 0.0 or t > 1.0:
125+
raise ValueError("t must be between 0.0 and 1.0")
126+
return a + (b - a) * t
127+
128+
def map_get_or_default(
129+
self, m: dict[str, Any], key: str, default: Any
130+
) -> Any:
131+
"""Get the value for the key from the map, or return the default value."""
132+
return m.get(key, default)

py_cel_activation.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,8 @@ PyCelActivation::PyCelActivation(
7474
kFunctionDescriptorOptions);
7575
activation_.InsertFunction(
7676
func_descriptor, std::make_unique<PyCelFunctionAdapter>(
77-
function->function_name(), function->impl()));
77+
function->function_name(), function->return_type(),
78+
function->impl()));
7879
}
7980
};
8081

py_cel_function.cc

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,15 +56,18 @@ static std::shared_ptr<PyCelEnvInternal> GetEnvFromContext(
5656

5757
void PyCelFunction::DefinePythonBindings(pybind11::module& m) {
5858
py::class_<PyCelFunction, std::shared_ptr<PyCelFunction>>(m, "Function")
59-
.def(py::init<std::string, std::vector<PyCelType>, bool, PyObject*>(),
59+
.def(py::init<std::string, std::vector<PyCelType>, bool, PyObject*,
60+
PyCelType>(),
6061
py::arg("function_name"), py::arg("parameters"),
61-
py::arg("is_member"), py::arg("impl"));
62+
py::arg("is_member"), py::arg("impl"),
63+
py::arg("return_type") = PyCelType::Dyn());
6264
}
6365

6466
PyCelFunction::PyCelFunction(std::string function_name,
6567
std::vector<PyCelType> parameters, bool is_member,
66-
PyObject* impl)
68+
PyObject* impl, PyCelType return_type)
6769
: function_name_(std::move(function_name)),
70+
return_type_(std::move(return_type)),
6871
parameters_(std::move(parameters)),
6972
is_member_(is_member),
7073
impl_(impl) {
@@ -79,8 +82,11 @@ PyCelFunction::~PyCelFunction() {
7982
};
8083

8184
PyCelFunctionAdapter::PyCelFunctionAdapter(std::string function_name,
85+
PyCelType return_type,
8286
PyObject* py_function)
83-
: function_name_(std::move(function_name)), py_function_(py_function) {
87+
: function_name_(std::move(function_name)),
88+
return_type_(std::move(return_type)),
89+
py_function_(py_function) {
8490
Py_XINCREF(py_function_);
8591
}
8692

@@ -111,7 +117,7 @@ absl::StatusOr<cel::Value> PyCelFunctionAdapter::Invoke(
111117
}
112118

113119
return PyObjectToCelValue(
114-
result, PyCelType::String(),
120+
result, return_type_,
115121
[this]() {
116122
return absl::StrFormat("Python function '%s'",
117123
PyUnicode_AsUTF8(PyObject_Repr(py_function_)));

py_cel_function.h

Lines changed: 6 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,19 +17,14 @@
1717

1818
#include <Python.h> // IWYU pragma: keep - Needed for PyObject
1919

20-
#include <memory>
2120
#include <string>
2221
#include <vector>
2322

24-
#include "absl/base/nullability.h"
2523
#include "absl/status/statusor.h"
2624
#include "absl/types/span.h"
2725
#include "common/value.h"
2826
#include "runtime/function.h"
2927
#include "py_cel_type.h"
30-
#include "google/protobuf/arena.h"
31-
#include "google/protobuf/descriptor.h"
32-
#include "google/protobuf/message.h"
3328
#include <pybind11/pybind11.h>
3429

3530
namespace cel_python {
@@ -43,16 +38,18 @@ class PyCelFunction {
4338
static void DefinePythonBindings(pybind11::module& m);
4439

4540
PyCelFunction(std::string function_name, std::vector<PyCelType> parameters,
46-
bool is_member, PyObject* impl);
41+
bool is_member, PyObject* impl, PyCelType return_type);
4742
~PyCelFunction();
4843

4944
std::string function_name() const { return function_name_; }
5045
const std::vector<PyCelType>& parameters() const { return parameters_; }
5146
bool is_member() const { return is_member_; }
5247
PyObject* impl() const { return impl_; }
48+
const PyCelType& return_type() const { return return_type_; }
5349

5450
private:
5551
std::string function_name_;
52+
PyCelType return_type_;
5653
std::vector<PyCelType> parameters_;
5754
bool is_member_;
5855
PyObject* impl_;
@@ -62,7 +59,8 @@ class PyCelFunction {
6259
// function.
6360
class PyCelFunctionAdapter : public cel::Function {
6461
public:
65-
PyCelFunctionAdapter(std::string function_name, PyObject* py_function);
62+
PyCelFunctionAdapter(std::string function_name, PyCelType return_type,
63+
PyObject* py_function);
6664
~PyCelFunctionAdapter() override;
6765

6866
absl::StatusOr<cel::Value> Invoke(
@@ -71,6 +69,7 @@ class PyCelFunctionAdapter : public cel::Function {
7169

7270
private:
7371
std::string function_name_;
72+
PyCelType return_type_;
7473
PyObject* py_function_;
7574
};
7675

py_cel_python_extension.cc

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -30,7 +30,6 @@
3030
#include "compiler/compiler.h"
3131
#include "runtime/runtime_builder.h"
3232
#include "runtime/runtime_options.h"
33-
#include "py_cel_env_internal.h"
3433
#include "py_cel_extension.h"
3534
#include "py_cel_function.h"
3635
#include "py_cel_function_decl.h"
@@ -102,15 +101,22 @@ absl::Status PyCelPythonExtension::ConfigureRuntime(
102101
std::vector<cel::Kind> types;
103102
types.reserve(overload.parameters().size());
104103
for (const PyCelType& arg : overload.parameters()) {
105-
types.push_back(arg.GetKind());
104+
if (arg.GetKind() == cel::Kind::kDyn) {
105+
// C++ runtime dispatcher historically uses kAny for wildcard type (
106+
// not distinguishing between dyn and any)
107+
types.push_back(cel::Kind::kAny);
108+
} else {
109+
types.push_back(arg.GetKind());
110+
}
106111
}
107112

108113
cel::FunctionDescriptor descriptor(function.name(), overload.is_member(),
109114
types, kFunctionDescriptorOptions);
110115
if (overload.py_function()) {
111116
PY_CEL_RETURN_IF_ERROR(runtime_builder.function_registry().Register(
112117
descriptor, std::make_unique<PyCelFunctionAdapter>(
113-
function.name(), overload.py_function())));
118+
function.name(), overload.return_type(),
119+
overload.py_function())));
114120
} else {
115121
PY_CEL_RETURN_IF_ERROR(
116122
runtime_builder.function_registry().RegisterLazyFunction(

py_cel_type.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,10 @@ class PyCelType {
4141
// Creates a dynamic type.
4242
PyCelType();
4343
PyCelType(const PyCelType& other) = default;
44+
PyCelType& operator=(const PyCelType& other) = default;
45+
PyCelType(PyCelType&& other) = default;
46+
PyCelType& operator=(PyCelType&& other) = default;
47+
4448
// Creates a message type.
4549
explicit PyCelType(const std::string& name);
4650
PyCelType(cel::Kind kind, const std::string& name);

0 commit comments

Comments
 (0)