Skip to content

Commit 4eb8e78

Browse files
dmitriplotnikovcopybara-github
authored andcommitted
Integrate cel::Config::ExtensionConfig into Python CEL environment
PiperOrigin-RevId: 892660001
1 parent 1ec5a73 commit 4eb8e78

5 files changed

Lines changed: 240 additions & 33 deletions

File tree

cel_expr_python/BUILD

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,6 +63,7 @@ pybind_extension(
6363
"@com_google_absl//absl/types:optional",
6464
"@com_google_absl//absl/types:span",
6565
"@com_google_cel_cpp//checker:type_checker_builder",
66+
"@com_google_cel_cpp//checker:type_checker_builder_factory",
6667
"@com_google_cel_cpp//checker:validation_result",
6768
"@com_google_cel_cpp//common:ast",
6869
"@com_google_cel_cpp//common:ast_proto",
@@ -78,8 +79,13 @@ pybind_extension(
7879
"@com_google_cel_cpp//compiler",
7980
"@com_google_cel_cpp//env",
8081
"@com_google_cel_cpp//env:config",
82+
"@com_google_cel_cpp//env:env_runtime",
83+
"@com_google_cel_cpp//env:env_std_extensions",
8184
"@com_google_cel_cpp//env:env_yaml",
85+
"@com_google_cel_cpp//env:runtime_std_extensions",
8286
"@com_google_cel_cpp//extensions/protobuf:runtime_adapter",
87+
"@com_google_cel_cpp//parser",
88+
"@com_google_cel_cpp//parser:options",
8389
"@com_google_cel_cpp//parser:parser_interface",
8490
"@com_google_cel_cpp//runtime",
8591
"@com_google_cel_cpp//runtime:activation",
@@ -88,7 +94,6 @@ pybind_extension(
8894
"@com_google_cel_cpp//runtime:reference_resolver",
8995
"@com_google_cel_cpp//runtime:runtime_builder",
9096
"@com_google_cel_cpp//runtime:runtime_options",
91-
"@com_google_cel_cpp//runtime:standard_runtime_builder_factory",
9297
"@com_google_cel_spec//proto/cel/expr:checked_cc_proto",
9398
"@com_google_cel_spec//proto/cel/expr:syntax_cc_proto",
9499
"@com_google_protobuf//:protobuf",
@@ -135,6 +140,7 @@ py_test(
135140
srcs = ["cel_env_test.py"],
136141
deps = [
137142
":cel",
143+
"//cel_expr_python/ext:ext_math",
138144
"//testing:proto2_test_all_types_py_pb2",
139145
"@com_google_absl_py//absl/testing:absltest",
140146
],

cel_expr_python/cel_env_test.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020

2121
from absl.testing import absltest
2222
from cel_expr_python import cel
23+
from cel_expr_python.ext import ext_math
2324
from cel.expr.conformance.proto2 import test_all_types_pb2 as test_all_types_pb
2425

2526

@@ -236,6 +237,75 @@ def test_config_variable_types(self):
236237
self.assertEqual(res.type(), cel.Type.INT)
237238
self.assertEqual(res.value(), 42)
238239

240+
def test_config_extensions(self):
241+
config = cel.NewEnvConfigFromYaml("""
242+
extensions:
243+
- name: math
244+
- name: strings
245+
""")
246+
env = cel.NewEnv(
247+
config=config,
248+
extensions=[TestCelExtension()],
249+
)
250+
yaml = env.config().to_yaml()
251+
self.assertEqual(
252+
normalize_yaml(yaml),
253+
normalize_yaml("""
254+
extensions:
255+
- name: "math"
256+
- name: "strings"
257+
- name: "test_cel_extension"
258+
"""),
259+
)
260+
res = env.compile("'%.4f'.format([math.sqrt(2)])").eval()
261+
self.assertEqual(res.value(), "1.4142")
262+
res = env.compile("hello('World')").eval()
263+
self.assertEqual(res.value(), "Hello, World!")
264+
265+
def test_config_extensions_override(self):
266+
# TODO(b/498655870): add assertion based on extension aliases once
267+
# supported.
268+
config = cel.NewEnvConfigFromYaml("""
269+
extensions:
270+
- name: cel.lib.ext.math
271+
version: 0
272+
- name: cel.lib.ext.strings
273+
""")
274+
with self.assertRaises(Exception) as e:
275+
cel.NewEnv(
276+
config=config,
277+
extensions=[ext_math.ExtMath()],
278+
)
279+
self.assertIn(
280+
"Extension 'cel.lib.ext.math' version 0 is already included. Cannot"
281+
" also include version 'latest'",
282+
str(e.exception),
283+
)
284+
285+
286+
class TestCelExtension(cel.CelExtension):
287+
"""An example CEL extension for testing."""
288+
289+
def __init__(self):
290+
super().__init__(
291+
"test_cel_extension",
292+
functions=[
293+
cel.FunctionDecl(
294+
"hello",
295+
[
296+
cel.Overload(
297+
"hello(string)",
298+
return_type=cel.Type.STRING,
299+
parameters=[
300+
cel.Type.STRING,
301+
],
302+
impl=lambda arg: f"Hello, {arg}!",
303+
)
304+
],
305+
),
306+
],
307+
)
308+
239309

240310
def normalize_yaml(yaml: str) -> str:
241311
lines = yaml.split("\n")

cel_expr_python/ext/ext_optional.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ namespace cel_python {
2525

2626
class ExtOptional : public CelExtension {
2727
public:
28-
explicit ExtOptional() : CelExtension("cel.lib.optional") {}
28+
explicit ExtOptional() : CelExtension("optional") {}
2929

3030
absl::Status ConfigureCompiler(
3131
cel::CompilerBuilder& compiler_builder,

cel_expr_python/py_cel_env_internal.cc

Lines changed: 151 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -25,42 +25,169 @@
2525
#include "absl/status/statusor.h"
2626
#include "absl/strings/str_cat.h"
2727
#include "checker/type_checker_builder.h"
28+
#include "checker/type_checker_builder_factory.h"
29+
#include "common/type.h"
2830
#include "compiler/compiler.h"
2931
#include "env/config.h"
3032
#include "env/env.h"
33+
#include "env/env_std_extensions.h"
34+
#include "env/runtime_std_extensions.h"
3135
#include "env/type_info.h"
36+
#include "parser/options.h"
37+
#include "parser/parser.h"
38+
#include "parser/parser_interface.h"
3239
#include "runtime/reference_resolver.h"
3340
#include "runtime/runtime.h"
3441
#include "runtime/runtime_builder.h"
3542
#include "runtime/runtime_options.h"
36-
#include "runtime/standard_runtime_builder_factory.h"
3743
#include "cel_expr_python/cel_extension.h"
3844
#include "cel_expr_python/py_cel_env_config.h"
3945
#include "cel_expr_python/py_cel_python_extension.h"
4046
#include "cel_expr_python/py_cel_type.h"
4147
#include "cel_expr_python/py_descriptor_database.h"
48+
#include "cel_expr_python/py_error_status.h"
49+
#include "cel_expr_python/py_message_factory.h"
4250
#include "cel_expr_python/status_macros.h"
4351
#include "google/protobuf/arena.h"
52+
#include "google/protobuf/descriptor.h"
4453
#include <pybind11/pybind11.h>
4554

4655
namespace cel_python {
4756

48-
PyCelEnvInternal::PyCelEnvInternal(const PyCelEnvConfig& env_config,
49-
PyObject* py_descriptor_pool,
50-
const std::vector<PyObject*>& extensions,
51-
const std::string& container)
57+
// A temporary adapter for cel::CompilerBuilder. It will be removed once the
58+
// CelExtension class is changed to return a cel::CompilerLibrary directly.
59+
//
60+
// This adapter allows `CelExtension::ConfigureCompiler`, which takes a
61+
// `cel::CompilerBuilder` to be used with the `cel::Env` API, which
62+
// is expressed in terms of the `cel::CompilerLibrary` framework. The adapter
63+
// splits the `cel::CompilerLibrary` into its constituent parts and feeds them
64+
// to the `cel::CompilerBuilder` interface.
65+
class CompilerBuilderAdapter : public cel::CompilerBuilder {
66+
public:
67+
CompilerBuilderAdapter(cel::ParserBuilder* parser_builder,
68+
cel::TypeCheckerBuilder* checker_builder)
69+
: parser_builder_(parser_builder), checker_builder_(checker_builder) {}
70+
71+
absl::Status AddLibrary(cel::CompilerLibrary library) override {
72+
if (library.configure_parser != nullptr) {
73+
CEL_PYTHON_RETURN_IF_ERROR(library.configure_parser(*parser_builder_));
74+
}
75+
if (library.configure_checker != nullptr) {
76+
CEL_PYTHON_RETURN_IF_ERROR(library.configure_checker(*checker_builder_));
77+
}
78+
return absl::OkStatus();
79+
}
80+
81+
absl::Status AddLibrarySubset(cel::CompilerLibrarySubset subset) override {
82+
return absl::UnimplementedError("Not implemented");
83+
}
84+
85+
cel::ParserBuilder& GetParserBuilder() override {
86+
ABSL_CHECK(parser_builder_ != nullptr);
87+
return *parser_builder_;
88+
}
89+
90+
cel::TypeCheckerBuilder& GetCheckerBuilder() override {
91+
ABSL_CHECK(checker_builder_ != nullptr);
92+
return *checker_builder_;
93+
}
94+
95+
absl::StatusOr<std::unique_ptr<cel::Compiler>> Build() override {
96+
return absl::UnimplementedError("Not implemented");
97+
}
98+
99+
private:
100+
cel::ParserBuilder* parser_builder_;
101+
cel::TypeCheckerBuilder* checker_builder_;
102+
};
103+
104+
// A temporary class to deal with the interface mismatch between CelExtension
105+
// and cel::CompilerLibrary.
106+
class AdapterCompilerLibrary : public cel::CompilerLibrary {
107+
public:
108+
AdapterCompilerLibrary(
109+
CelExtension* extension,
110+
const std::shared_ptr<google::protobuf::DescriptorPool>& descriptor_pool,
111+
cel::ParserBuilder* passive_parser_builder,
112+
cel::TypeCheckerBuilder* passive_checker_builder)
113+
: cel::CompilerLibrary(
114+
extension->name(),
115+
// Safe to capture passive_checker_builder because it outlives the
116+
// compiler library.
117+
[extension, descriptor_pool, passive_checker_builder](
118+
cel::ParserBuilder& parser_builder) -> absl::Status {
119+
return ConfigureParser(parser_builder, extension, descriptor_pool,
120+
passive_checker_builder);
121+
},
122+
// Safe to capture passive_parser_builder because it outlives the
123+
// compiler library.
124+
[extension, descriptor_pool, passive_parser_builder](
125+
cel::TypeCheckerBuilder& checker_builder) -> absl::Status {
126+
return ConfigureChecker(checker_builder, extension,
127+
descriptor_pool, passive_parser_builder);
128+
}) {};
129+
130+
private:
131+
static absl::Status ConfigureParser(
132+
cel::ParserBuilder& parser_builder, CelExtension* extension,
133+
const std::shared_ptr<google::protobuf::DescriptorPool>& descriptor_pool,
134+
cel::TypeCheckerBuilder* passive_checker_builder) {
135+
CompilerBuilderAdapter compiler_builder(&parser_builder,
136+
passive_checker_builder);
137+
return extension->ConfigureCompiler(compiler_builder, *descriptor_pool);
138+
}
139+
140+
static absl::Status ConfigureChecker(
141+
cel::TypeCheckerBuilder& checker_builder, CelExtension* extension,
142+
const std::shared_ptr<google::protobuf::DescriptorPool>& descriptor_pool,
143+
cel::ParserBuilder* passive_parser_builder) {
144+
CompilerBuilderAdapter compiler_builder(passive_parser_builder,
145+
&checker_builder);
146+
return extension->ConfigureCompiler(compiler_builder, *descriptor_pool);
147+
}
148+
};
149+
150+
PyCelEnvInternal::PyCelEnvInternal(
151+
const PyCelEnvConfig& env_config, PyObject* py_descriptor_pool,
152+
std::vector<std::unique_ptr<CelExtensionHandle>>& extension_handles,
153+
const std::string& container)
52154
: env_config_(env_config),
53155
py_descriptor_database_(py_descriptor_pool),
54156
descriptor_pool_(
55157
std::make_shared<google::protobuf::DescriptorPool>(&py_descriptor_database_)),
56158
message_factory_(descriptor_pool_.get()),
57159
py_message_factory_(
58160
std::make_shared<PyMessageFactory>(py_descriptor_pool)),
161+
extensions_(std::move(extension_handles)),
59162
container_(std::move(container)) {
60163
cel_env_.SetDescriptorPool(descriptor_pool_);
61164
cel_env_.SetConfig(env_config_.GetConfig());
62-
for (PyObject* ext : extensions) {
63-
extensions_.push_back(std::make_unique<CelExtensionHandle>(ext));
165+
cel::RegisterStandardExtensions(cel_env_);
166+
167+
cel_env_runtime_.SetDescriptorPool(descriptor_pool_);
168+
cel_env_runtime_.SetConfig(env_config_.GetConfig());
169+
cel::RegisterStandardExtensions(cel_env_runtime_);
170+
171+
passive_parser_builder_ = cel::NewParserBuilder(cel::ParserOptions());
172+
passive_checker_builder_ =
173+
ThrowIfError(cel::CreateTypeCheckerBuilder(descriptor_pool_.get()));
174+
for (std::unique_ptr<CelExtensionHandle>& extension_handle : extensions_) {
175+
// This should never fail because we have already called GetExtension() once
176+
// before calling this constructor.
177+
CelExtension* extension = ThrowIfError(extension_handle->GetExtension());
178+
cel_env_.RegisterCompilerLibrary(
179+
extension->name(), extension->name(), 0, [this, extension]() {
180+
return AdapterCompilerLibrary(extension, descriptor_pool_,
181+
passive_parser_builder_.get(),
182+
passive_checker_builder_.get());
183+
});
184+
cel_env_runtime_.RegisterExtensionFunctions(
185+
extension->name(), extension->name(), 0,
186+
[extension](
187+
cel::RuntimeBuilder& runtime_builder,
188+
const cel::RuntimeOptions& runtime_options) -> absl::Status {
189+
return extension->ConfigureRuntime(runtime_builder, runtime_options);
190+
});
64191
}
65192
}
66193

@@ -79,8 +206,19 @@ PyCelEnvInternal::NewCelEnvInternal(
79206
}));
80207
}
81208

82-
return std::shared_ptr<PyCelEnvInternal>(new PyCelEnvInternal(
83-
PyCelEnvConfig(config), py_descriptor_pool, extensions, container));
209+
std::vector<std::unique_ptr<CelExtensionHandle>> extension_handles;
210+
for (PyObject* ext : extensions) {
211+
auto extension_handle = std::make_unique<CelExtensionHandle>(ext);
212+
CEL_PYTHON_ASSIGN_OR_RETURN(CelExtension * extension,
213+
extension_handle->GetExtension());
214+
// TODO(b/498655870): support extension version.
215+
CEL_PYTHON_RETURN_IF_ERROR(config.AddExtensionConfig(extension->name()));
216+
extension_handles.push_back(std::move(extension_handle));
217+
}
218+
219+
return std::shared_ptr<PyCelEnvInternal>(
220+
new PyCelEnvInternal(PyCelEnvConfig(config), py_descriptor_pool,
221+
extension_handles, container));
84222
}
85223

86224
absl::StatusOr<const cel::Compiler*> PyCelEnvInternal::GetCompiler(
@@ -95,13 +233,6 @@ absl::StatusOr<const cel::Compiler*> PyCelEnvInternal::GetCompiler(
95233
std::unique_ptr<cel::CompilerBuilder> compiler_builder,
96234
env->cel_env_.NewCompilerBuilder());
97235
compiler_builder->GetCheckerBuilder().set_container(env->container_);
98-
for (std::unique_ptr<CelExtensionHandle>& extension_handle :
99-
env->extensions_) {
100-
CEL_PYTHON_ASSIGN_OR_RETURN(CelExtension * extension,
101-
extension_handle->GetExtension(env));
102-
CEL_PYTHON_RETURN_IF_ERROR(extension->ConfigureCompiler(
103-
*compiler_builder, *(env->descriptor_pool_.get())));
104-
}
105236

106237
// Convert variable types from cel::TypeInfo to PyCelType.
107238
google::protobuf::Arena* arena = compiler_builder->GetCheckerBuilder().arena();
@@ -125,7 +256,7 @@ absl::StatusOr<const cel::Runtime*> PyCelEnvInternal::GetRuntime(
125256
return it->second.get();
126257
}
127258

128-
cel::RuntimeOptions opts;
259+
cel::RuntimeOptions& opts = env->cel_env_runtime_.mutable_runtime_options();
129260
opts.container = env->container_;
130261
opts.enable_empty_wrapper_null_unboxing = true;
131262
opts.enable_qualified_type_identifiers = true;
@@ -137,17 +268,10 @@ absl::StatusOr<const cel::Runtime*> PyCelEnvInternal::GetRuntime(
137268
opts.fail_on_warnings = false;
138269
break;
139270
}
140-
CEL_PYTHON_ASSIGN_OR_RETURN(
141-
cel::RuntimeBuilder builder,
142-
cel::CreateStandardRuntimeBuilder(env->descriptor_pool_.get(), opts));
271+
CEL_PYTHON_ASSIGN_OR_RETURN(cel::RuntimeBuilder builder,
272+
env->cel_env_runtime_.CreateRuntimeBuilder());
143273
CEL_PYTHON_RETURN_IF_ERROR(cel::EnableReferenceResolver(
144274
builder, cel::ReferenceResolverEnabled::kAlways));
145-
for (std::unique_ptr<CelExtensionHandle>& extension_handle :
146-
env->extensions_) {
147-
CEL_PYTHON_ASSIGN_OR_RETURN(CelExtension * extension,
148-
extension_handle->GetExtension(env));
149-
CEL_PYTHON_RETURN_IF_ERROR(extension->ConfigureRuntime(builder, opts));
150-
}
151275
CEL_PYTHON_ASSIGN_OR_RETURN(std::unique_ptr<cel::Runtime> runtime,
152276
std::move(builder).Build());
153277
const cel::Runtime* runtime_ptr = runtime.get();
@@ -177,8 +301,7 @@ CelExtensionHandle::~CelExtensionHandle() {
177301
PyGILState_Release(gil_state);
178302
}
179303

180-
absl::StatusOr<CelExtension*> CelExtensionHandle::GetExtension(
181-
const std::shared_ptr<PyCelEnvInternal>& env) {
304+
absl::StatusOr<CelExtension*> CelExtensionHandle::GetExtension() {
182305
if (cel_extension_) {
183306
return cel_extension_;
184307
}

0 commit comments

Comments
 (0)