Skip to content

Commit 707d886

Browse files
dmitriplotnikovcopybara-github
authored andcommitted
Make extensions reusable across different environments
PiperOrigin-RevId: 852967383
1 parent a24efe8 commit 707d886

8 files changed

Lines changed: 37 additions & 41 deletions

BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ pybind_extension(
8181
"@com_google_cel_cpp//parser:parser_interface",
8282
"@com_google_cel_cpp//runtime",
8383
"@com_google_cel_cpp//runtime:activation",
84+
"@com_google_cel_cpp//runtime:embedder_context",
8485
"@com_google_cel_cpp//runtime:function",
8586
"@com_google_cel_cpp//runtime:reference_resolver",
8687
"@com_google_cel_cpp//runtime:runtime_builder",

py_cel_activation.cc

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,9 @@ namespace cel_python {
3737

3838
namespace py = ::pybind11;
3939

40+
static const cel::FunctionDescriptorOptions kFunctionDescriptorOptions = {
41+
.is_strict = true, .is_contextual = true};
42+
4043
void PyCelActivation::DefinePythonBindings(py::module& m) {
4144
py::class_<PyCelActivation, std::shared_ptr<PyCelActivation>>(m,
4245
"Activation");
@@ -67,10 +70,10 @@ PyCelActivation::PyCelActivation(
6770
}
6871
cel::FunctionDescriptor func_descriptor(function->function_name(),
6972
function->is_member(), parameters,
70-
/*is_strict=*/true);
73+
kFunctionDescriptorOptions);
7174
activation_.InsertFunction(
7275
func_descriptor, std::make_unique<PyCelFunctionAdapter>(
73-
env, function->function_name(), function->impl()));
76+
function->function_name(), function->impl()));
7477
}
7578
};
7679

py_cel_env.cc

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -172,10 +172,7 @@ absl::StatusOr<PyCelExtension*> PyCelExtensionHandle::GetExtension(
172172
absl::Status status_py_cel_extension;
173173
try {
174174
pybind11::handle handle = pybind11::handle(py_extension_);
175-
PyCelPythonExtension* py_cel_extension =
176-
handle.cast<PyCelPythonExtension*>();
177-
status_py_cel_extension = py_cel_extension->SetEnv(env);
178-
return py_cel_extension;
175+
return handle.cast<PyCelPythonExtension*>();
179176
} catch (const pybind11::cast_error& e) {
180177
status_py_cel_extension = absl::InvalidArgumentError(e.what());
181178
}

py_cel_expression.cc

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,7 @@
3838
#include "compiler/compiler.h"
3939
#include "extensions/protobuf/runtime_adapter.h"
4040
#include "parser/parser_interface.h"
41+
#include "runtime/embedder_context.h"
4142
#include "runtime/runtime.h"
4243
#include "py_cel_activation.h"
4344
#include "py_cel_arena.h"
@@ -135,10 +136,14 @@ absl::StatusOr<PyCelValue> PyCelExpression::Eval(
135136
}
136137
std::shared_ptr<PyCelArena> arena = activation.GetArena();
137138
std::shared_ptr<PyCelEnv> env = activation.GetEnv();
139+
cel::EmbedderContext embedder_context = cel::EmbedderContext::From(&env);
140+
cel::EvaluateOptions options;
141+
options.message_factory = env->GetMessageFactory();
142+
options.embedder_context = &embedder_context;
138143
CEL_PYTHON_ASSIGN_OR_RETURN(
139144
cel::Value result,
140-
cel_program_->Evaluate(arena->GetArena(), env->GetMessageFactory(),
141-
*activation.GetActivation()));
145+
cel_program_->Evaluate(arena->GetArena(), *activation.GetActivation(),
146+
std::move(options)));
142147
return PyCelValue(result, arena, std::move(env));
143148
}
144149

py_cel_function.cc

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "absl/strings/str_format.h"
2828
#include "absl/types/span.h"
2929
#include "common/value.h"
30+
#include "runtime/embedder_context.h"
3031
#include "runtime/function.h"
3132
#include "py_cel_env.h"
3233
#include "py_cel_type.h"
@@ -41,6 +42,17 @@ namespace cel_python {
4142

4243
namespace py = ::pybind11;
4344

45+
namespace {
46+
47+
static std::shared_ptr<PyCelEnv> GetEnvFromContext(
48+
const cel::Function::InvokeContext& context) {
49+
ABSL_CHECK(context.embedder_context()); // Crash OK: all call sites are local
50+
// to the library.
51+
return *context.embedder_context()->Get<std::shared_ptr<PyCelEnv>*>();
52+
}
53+
54+
} // namespace
55+
4456
void PyCelFunction::DefinePythonBindings(pybind11::module& m) {
4557
py::class_<PyCelFunction, std::shared_ptr<PyCelFunction>>(m, "Function")
4658
.def(py::init<std::string, std::vector<PyCelType>, bool, PyObject*>(),
@@ -65,12 +77,9 @@ PyCelFunction::~PyCelFunction() {
6577
PyGILState_Release(gil_state);
6678
};
6779

68-
PyCelFunctionAdapter::PyCelFunctionAdapter(const std::shared_ptr<PyCelEnv>& env,
69-
std::string function_name,
80+
PyCelFunctionAdapter::PyCelFunctionAdapter(std::string function_name,
7081
PyObject* py_function)
71-
: env_(env),
72-
function_name_(std::move(function_name)),
73-
py_function_(py_function) {
82+
: function_name_(std::move(function_name)), py_function_(py_function) {
7483
Py_XINCREF(py_function_);
7584
}
7685

@@ -85,12 +94,13 @@ absl::StatusOr<cel::Value> PyCelFunctionAdapter::Invoke(
8594
const cel::Function::InvokeContext& context) const {
8695
ABSL_CHECK(PyGILState_Check());
8796

97+
std::shared_ptr<PyCelEnv> env = GetEnvFromContext(context);
8898
PY_CEL_ASSIGN_OR_RETURN(auto py_arena,
8999
PyCelArena::FromProtoArena(context.arena()));
90100
PyObject* py_args = PyTuple_New(args.size());
91101
for (int i = 0; i < args.size(); ++i) {
92102
PyTuple_SetItem(py_args, i,
93-
CelValueToPyObject(args[i], env_, py_arena,
103+
CelValueToPyObject(args[i], env, py_arena,
94104
/*plain_value=*/true));
95105
}
96106
PyObject* result = PyObject_CallObject(py_function_, py_args);
@@ -105,7 +115,7 @@ absl::StatusOr<cel::Value> PyCelFunctionAdapter::Invoke(
105115
return absl::StrFormat("Python function '%s'",
106116
PyUnicode_AsUTF8(PyObject_Repr(py_function_)));
107117
},
108-
env_, context.arena());
118+
env, context.arena());
109119
};
110120

111121
} // namespace cel_python

py_cel_function.h

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -62,16 +62,14 @@ class PyCelFunction {
6262
// function.
6363
class PyCelFunctionAdapter : public cel::Function {
6464
public:
65-
PyCelFunctionAdapter(const std::shared_ptr<PyCelEnv>& env,
66-
std::string function_name, PyObject* py_function);
65+
PyCelFunctionAdapter(std::string function_name, PyObject* py_function);
6766
~PyCelFunctionAdapter() override;
6867

6968
absl::StatusOr<cel::Value> Invoke(
7069
absl::Span<const cel::Value> args,
7170
const cel::Function::InvokeContext& context) const final;
7271

7372
private:
74-
std::shared_ptr<PyCelEnv> env_;
7573
std::string function_name_;
7674
PyObject* py_function_;
7775
};

py_cel_python_extension.cc

Lines changed: 5 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,9 @@ namespace cel_python {
4646

4747
namespace py = ::pybind11;
4848

49+
static const cel::FunctionDescriptorOptions kFunctionDescriptorOptions = {
50+
.is_strict = true, .is_contextual = true};
51+
4952
void PyCelPythonExtension::DefinePythonBindings(py::module_& m) {
5053
py::class_<PyCelExtension>(m, "CelExtensionBase")
5154
.def(py::init<std::string>(), py::arg("name"));
@@ -59,18 +62,6 @@ PyCelPythonExtension::PyCelPythonExtension(
5962
std::string name, std::vector<PyCelFunctionDecl> functions)
6063
: PyCelExtension(std::move(name)), functions_(std::move(functions)) {}
6164

62-
// TODO(b/462745713): pass the env to the Invoke method instead of storing it
63-
// as a member variable and remove this method.
64-
absl::Status PyCelPythonExtension::SetEnv(
65-
const std::shared_ptr<PyCelEnv>& env) {
66-
if (env_ != nullptr && env_ != env) {
67-
return absl::FailedPreconditionError(
68-
"PyCelExtension already has an environment");
69-
}
70-
env_ = env;
71-
return absl::OkStatus();
72-
}
73-
7465
absl::Status PyCelPythonExtension::ConfigureCompiler(
7566
cel::CompilerBuilder& compiler_builder,
7667
const google::protobuf::DescriptorPool& descriptor_pool) {
@@ -115,12 +106,11 @@ absl::Status PyCelPythonExtension::ConfigureRuntime(
115106
}
116107

117108
cel::FunctionDescriptor descriptor(function.name(), overload.is_member(),
118-
types,
119-
/*is_strict=*/true);
109+
types, kFunctionDescriptorOptions);
120110
if (overload.py_function()) {
121111
PY_CEL_RETURN_IF_ERROR(runtime_builder.function_registry().Register(
122112
descriptor, std::make_unique<PyCelFunctionAdapter>(
123-
env_, function.name(), overload.py_function())));
113+
function.name(), overload.py_function())));
124114
} else {
125115
PY_CEL_RETURN_IF_ERROR(
126116
runtime_builder.function_registry().RegisterLazyFunction(

py_cel_python_extension.h

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,6 @@
1515
#ifndef THIRD_PARTY_CEL_PYTHON_PY_CEL_PYTHON_EXTENSION_H_
1616
#define THIRD_PARTY_CEL_PYTHON_PY_CEL_PYTHON_EXTENSION_H_
1717

18-
#include <memory>
1918
#include <string>
2019
#include <vector>
2120

@@ -30,18 +29,12 @@
3029

3130
namespace cel_python {
3231

33-
class PyCelEnv;
34-
3532
class PyCelPythonExtension : public PyCelExtension {
3633
public:
3734
static void DefinePythonBindings(pybind11::module& m);
3835
PyCelPythonExtension(std::string name,
3936
std::vector<PyCelFunctionDecl> functions);
4037

41-
// TODO(b/462745713): pass the env to the Invoke method instead of storing it
42-
// as a member variable.
43-
absl::Status SetEnv(const std::shared_ptr<PyCelEnv>& env);
44-
4538
protected:
4639
absl::Status ConfigureCompiler(
4740
cel::CompilerBuilder& compiler_builder,
@@ -52,7 +45,6 @@ class PyCelPythonExtension : public PyCelExtension {
5245

5346
private:
5447
std::vector<PyCelFunctionDecl> functions_;
55-
std::shared_ptr<PyCelEnv> env_;
5648
};
5749

5850
} // namespace cel_python

0 commit comments

Comments
 (0)