Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions cel_expr_python/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,25 @@ pybind_extension(
# For pybind11-based CEL extensions.
pybind_library(
name = "cel_extension",
hdrs = ["cel_extension.h"],
srcs = [
"py_error_status.cc",
],
hdrs = [
"cel_extension.h",
"py_error_status.h",
],
visibility = ["//visibility:public"],
deps = [
":status_macros",
"@com_google_absl//absl/base:no_destructor",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/log:absl_check",
"@com_google_absl//absl/log:absl_log",
"@com_google_absl//absl/status",
"@com_google_absl//absl/status:statusor",
"@com_google_cel_cpp//compiler",
"@com_google_cel_cpp//runtime:runtime_builder",
"@com_google_cel_cpp//runtime:runtime_options",
"@com_google_protobuf//:protobuf",
],
)

Expand Down Expand Up @@ -141,7 +152,10 @@ py_test(
srcs = ["cel_env_test.py"],
deps = [
":cel",
"//cel_expr_python/ext:ext_bindings",
"//cel_expr_python/ext:ext_math",
"//cel_expr_python/ext:ext_optional",
"//cel_expr_python/ext:ext_strings",
"//testing:proto2_test_all_types_py_pb2",
"@com_google_absl_py//absl/testing:absltest",
],
Expand Down
87 changes: 79 additions & 8 deletions cel_expr_python/cel_env_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,10 @@

from absl.testing import absltest
from cel_expr_python import cel
from cel_expr_python.ext import ext_bindings
from cel_expr_python.ext import ext_math
from cel_expr_python.ext import ext_optional
from cel_expr_python.ext import ext_strings
from cel.expr.conformance.proto2 import test_all_types_pb2 as test_all_types_pb


Expand Down Expand Up @@ -95,9 +98,7 @@ def test_invalid_yaml(self):
)

def test_config_export_container(self):
env = cel.NewEnv(
container="test.container"
)
env = cel.NewEnv(container="test.container")
yaml = env.config().to_yaml()
self.assertEqual(
normalize_yaml(yaml),
Expand Down Expand Up @@ -251,6 +252,52 @@ def test_config_variable_types(self):
self.assertEqual(res.type(), cel.Type.INT)
self.assertEqual(res.value(), 42)

def test_config_export_extension_version(self):
env = cel.NewEnv(
extensions=[
ext_math.ExtMath(0),
ext_optional.ExtOptional(1),
ext_strings.ExtStrings(2),
ext_bindings.ExtBindings(),
],
)
yaml = env.config().to_yaml()
self.assertEqual(
normalize_yaml(yaml),
normalize_yaml("""
extensions:
- name: "bindings"
- name: "math"
version: 0
- name: "optional"
version: 1
- name: "strings"
version: 2
"""),
)

def test_config_extension_version_out_of_range(self):
cases = [
[
lambda: ext_math.ExtMath(42),
r"'math' extension version: 42 not in range \[0, \d+\]",
],
[
lambda: ext_optional.ExtOptional(6),
r"'optional' extension version: 6 not in range \[0, \d+\]",
],
[
lambda: ext_strings.ExtStrings(18),
r"'strings' extension version: 18 not in range \[0, \d+\]",
],
]
for test_case in cases:
with self.assertRaises(Exception) as e:
cel.NewEnv(
extensions=[test_case[0]()],
)
self.assertRegex(str(e.exception), test_case[1])

def test_config_extensions(self):
config = cel.NewEnvConfigFromYaml("""
extensions:
Expand All @@ -276,23 +323,47 @@ def test_config_extensions(self):
res = env.compile("hello('World')").eval()
self.assertEqual(res.value(), "Hello, World!")

def test_config_extensions_override(self):
# TODO(b/498655870): add assertion based on extension aliases once
# supported.
def test_config_extension_override_same_version(self):
config = cel.NewEnvConfigFromYaml("""
extensions:
- name: cel.lib.ext.math
version: 1
- name: strings
version: 2
""")
env = cel.NewEnv(
config=config,
extensions=[ext_math.ExtMath(1), ext_strings.ExtStrings(2)],
)
res = env.compile("'%.3f'.format([math.floor(3.14)])").eval()
self.assertEqual(res.value(), "3.000")

def test_config_extension_override_different_version(self):
config = cel.NewEnvConfigFromYaml("""
extensions:
- name: math
version: 0
- name: cel.lib.ext.strings
version: 2
""")
with self.assertRaises(Exception) as e:
cel.NewEnv(
config=config,
extensions=[ext_math.ExtMath()],
)
self.assertIn(
"Extension 'cel.lib.ext.math' version 0 is already included. Cannot"
" also include version 'latest'",
"Extension 'math' version 0 is already included. Cannot"
" also include version 2",
str(e.exception),
)
with self.assertRaises(Exception) as e:
cel.NewEnv(
config=config,
extensions=[ext_strings.ExtStrings(1)],
)
self.assertIn(
"Extension 'cel.lib.ext.strings' version 2 is already included. Cannot"
" also include version 1",
str(e.exception),
)

Expand Down
15 changes: 14 additions & 1 deletion cel_expr_python/cel_extension.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,9 @@ namespace cel_python {
// Python.
class CelExtension {
public:
explicit CelExtension(std::string name) : name_(std::move(name)) {};
explicit CelExtension(std::string name, std::string alias = "",
int version = -1)
: name_(std::move(name)), alias_(std::move(alias)), version_(version) {}
virtual ~CelExtension() = default;

virtual cel::CompilerLibrary GetCompilerLibrary() {
Expand All @@ -51,9 +53,13 @@ class CelExtension {
}

std::string name() const { return name_; }
std::string alias() const { return alias_; }
int version() const { return version_; }

private:
std::string name_;
std::string alias_;
int version_;
};

#define CEL_MODULE_NAME "cel_expr_python.cel"
Expand All @@ -80,6 +86,13 @@ class CelExtension {
.def(pybind11::init<>()); \
}

#define CEL_VERSIONED_EXTENSION_MODULE(module_name, class_name) \
PYBIND11_MODULE(module_name, m) { \
pybind11::module_::import(CEL_MODULE_NAME); \
pybind11::class_<class_name, cel_python::CelExtension>(m, #class_name) \
.def(pybind11::init<>()) \
.def(pybind11::init<int>(), pybind11::arg("version")); \
}
} // namespace cel_python

#endif // THIRD_PARTY_CEL_PYTHON_CEL_EXTENSION_H_
8 changes: 6 additions & 2 deletions cel_expr_python/ext/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@ pybind_extension(
deps = [
"//cel_expr_python:cel_extension",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_cel_cpp//compiler",
"@com_google_cel_cpp//extensions:math_ext",
"@com_google_cel_cpp//extensions:math_ext_decls",
Expand All @@ -69,6 +70,8 @@ pybind_extension(
deps = [
"//cel_expr_python:cel_extension",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_cel_cpp//checker:optional",
"@com_google_cel_cpp//compiler",
"@com_google_cel_cpp//compiler:optional",
"@com_google_cel_cpp//runtime:optional_types",
Expand All @@ -94,9 +97,9 @@ pybind_extension(
)

pybind_extension(
name = "ext_string",
name = "ext_strings",
srcs = [
"ext_string.cc",
"ext_strings.cc",
],
data = [
"//cel_expr_python:cel",
Expand All @@ -105,6 +108,7 @@ pybind_extension(
deps = [
"//cel_expr_python:cel_extension",
"@com_google_absl//absl/status",
"@com_google_absl//absl/strings",
"@com_google_cel_cpp//compiler",
"@com_google_cel_cpp//extensions:strings",
"@com_google_cel_cpp//runtime:runtime_builder",
Expand Down
3 changes: 2 additions & 1 deletion cel_expr_python/ext/ext_bindings.cc
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,8 @@ namespace cel_python {

class ExtBindings : public CelExtension {
public:
explicit ExtBindings() : CelExtension("cel.lib.ext.cel.bindings") {}
explicit ExtBindings()
: CelExtension("cel.lib.ext.cel.bindings", "bindings") {}

cel::CompilerLibrary GetCompilerLibrary() override {
return cel::extensions::BindingsCompilerLibrary();
Expand Down
2 changes: 1 addition & 1 deletion cel_expr_python/ext/ext_encoders.cc
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ namespace cel_python {

class ExtEncoders : public CelExtension {
public:
explicit ExtEncoders() : CelExtension("cel.lib.ext.encoders") {}
explicit ExtEncoders() : CelExtension("cel.lib.ext.encoders", "encoders") {}

cel::CompilerLibrary GetCompilerLibrary() override {
return cel::extensions::EncodersCompilerLibrary();
Expand Down
19 changes: 15 additions & 4 deletions cel_expr_python/ext/ext_math.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,30 +13,41 @@
// limitations under the License.

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "compiler/compiler.h"
#include "extensions/math_ext.h"
#include "extensions/math_ext_decls.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "cel_expr_python/cel_extension.h"
#include "cel_expr_python/py_error_status.h"

namespace cel_python {

class ExtMath : public CelExtension {
public:
explicit ExtMath() : CelExtension("cel.lib.ext.math") {}
explicit ExtMath(int version)
: CelExtension("cel.lib.ext.math", "math", version) {
if (version < 0 || version > cel::extensions::kMathExtensionLatestVersion) {
throw StatusToException(absl::InvalidArgumentError(absl::StrCat(
"'math' extension version: ", version, " not in range [0, ",
cel::extensions::kMathExtensionLatestVersion, "]")));
}
}

ExtMath() : ExtMath(cel::extensions::kMathExtensionLatestVersion) {}

cel::CompilerLibrary GetCompilerLibrary() override {
return cel::extensions::MathCompilerLibrary();
return cel::extensions::MathCompilerLibrary(version());
}

absl::Status ConfigureRuntime(cel::RuntimeBuilder& runtime_builder,
const cel::RuntimeOptions& opts) override {
return cel::extensions::RegisterMathExtensionFunctions(
runtime_builder.function_registry(), opts);
runtime_builder.function_registry(), opts, version());
}
};

CEL_EXTENSION_MODULE(ext_math, ExtMath);
CEL_VERSIONED_EXTENSION_MODULE(ext_math, ExtMath);

} // namespace cel_python
17 changes: 14 additions & 3 deletions cel_expr_python/ext/ext_optional.cc
Original file line number Diff line number Diff line change
Expand Up @@ -13,21 +13,32 @@
// limitations under the License.

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "checker/optional.h"
#include "compiler/compiler.h"
#include "compiler/optional.h"
#include "runtime/optional_types.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "cel_expr_python/cel_extension.h"
#include "cel_expr_python/py_error_status.h"

namespace cel_python {

class ExtOptional : public CelExtension {
public:
explicit ExtOptional() : CelExtension("optional") {}
explicit ExtOptional(int version) : CelExtension("optional", "", version) {
if (version < 0 || version > cel::kOptionalExtensionLatestVersion) {
throw StatusToException(absl::InvalidArgumentError(absl::StrCat(
"'optional' extension version: ", version, " not in range [0, ",
cel::kOptionalExtensionLatestVersion, "]")));
}
}

ExtOptional() : ExtOptional(cel::kOptionalExtensionLatestVersion) {}

cel::CompilerLibrary GetCompilerLibrary() override {
return cel::OptionalCompilerLibrary();
return cel::OptionalCompilerLibrary(version());
}

absl::Status ConfigureRuntime(cel::RuntimeBuilder& runtime_builder,
Expand All @@ -40,6 +51,6 @@ class ExtOptional : public CelExtension {
}
};

CEL_EXTENSION_MODULE(ext_optional, ExtOptional);
CEL_VERSIONED_EXTENSION_MODULE(ext_optional, ExtOptional);

} // namespace cel_python
Original file line number Diff line number Diff line change
Expand Up @@ -13,29 +13,42 @@
// limitations under the License.

#include "absl/status/status.h"
#include "absl/strings/str_cat.h"
#include "compiler/compiler.h"
#include "extensions/strings.h"
#include "runtime/runtime_builder.h"
#include "runtime/runtime_options.h"
#include "cel_expr_python/cel_extension.h"
#include "cel_expr_python/py_error_status.h"

namespace cel_python {

class ExtString : public CelExtension {
class ExtStrings : public CelExtension {
public:
explicit ExtString() : CelExtension("cel.lib.ext.string") {}
explicit ExtStrings(int version)
: CelExtension("cel.lib.ext.strings", "strings", version) {
if (version < 0 ||
version > cel::extensions::kStringsExtensionLatestVersion) {
throw StatusToException(absl::InvalidArgumentError(absl::StrCat(
"'strings' extension version: ", version, " not in range [0, ",
cel::extensions::kStringsExtensionLatestVersion, "]")));
}
}

ExtStrings() : ExtStrings(cel::extensions::kStringsExtensionLatestVersion) {}

cel::CompilerLibrary GetCompilerLibrary() override {
return cel::extensions::StringsCompilerLibrary();
return cel::extensions::StringsCompilerLibrary(version());
}

absl::Status ConfigureRuntime(cel::RuntimeBuilder& runtime_builder,
const cel::RuntimeOptions& opts) override {
return cel::extensions::RegisterStringsFunctions(
runtime_builder.function_registry(), opts);
runtime_builder.function_registry(), opts,
cel::extensions::StringsExtensionOptions{.version = version()});
}
};

CEL_EXTENSION_MODULE(ext_string, ExtString);
CEL_VERSIONED_EXTENSION_MODULE(ext_strings, ExtStrings);

} // namespace cel_python
Loading
Loading