Skip to content

Commit 95915c5

Browse files
dmitriplotnikovcopybara-github
authored andcommitted
Make protobuf dependency optional
PiperOrigin-RevId: 864950333
1 parent e16a44c commit 95915c5

4 files changed

Lines changed: 139 additions & 10 deletions

File tree

py_cel_env.cc

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616

1717
#include <Python.h> // IWYU pragma: keep - Needed for PyObject
1818

19+
#include <exception>
1920
#include <memory>
2021
#include <optional>
2122
#include <string>
@@ -45,12 +46,18 @@ void PyCelEnv::DefinePythonBindings(pybind11::module& m) {
4546
std::optional<std::unordered_map<std::string, PyCelType>> variables,
4647
std::optional<std::vector<py::object>> extensions,
4748
const std::optional<std::string>& container) {
48-
PyObject* pool_ptr = nullptr;
49+
PyObject* pool_ptr;
4950
if (descriptor_pool.is_none()) {
5051
// Replicates python's `descriptor_pool.Default()`
51-
pool_ptr = py::module::import("google.protobuf.descriptor_pool")
52-
.attr("Default")()
53-
.ptr();
52+
try {
53+
pool_ptr = py::module::import("google.protobuf.descriptor_pool")
54+
.attr("Default")()
55+
.ptr();
56+
} catch (const std::exception& e) {
57+
// google.protobuf.descriptor_pool is not available.
58+
pool_ptr = nullptr;
59+
PyErr_Clear(); // Clear the Python error state.
60+
}
5461
} else {
5562
pool_ptr = descriptor_pool.ptr();
5663
}

py_cel_test.py

Lines changed: 105 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@
1616

1717
import datetime
1818
import gc
19+
import importlib
20+
import importlib.abc
21+
import sys
1922

2023
from google.protobuf import duration_pb2 as duration_pb
2124
from google.protobuf import timestamp_pb2 as timestamp_pb
@@ -730,8 +733,10 @@ def testActivationAndOtherArgs(self):
730733
self.env.Activation(data={"var_str": "World!"}),
731734
data={"var_str": "World!"},
732735
)
733-
self.assertIn("Cannot provide both activation and any other arguments",
734-
str(e.exception))
736+
self.assertIn(
737+
"Cannot provide both activation and any other arguments",
738+
str(e.exception),
739+
)
735740

736741
def testCompilationErrorHandling(self):
737742
# Check parser error.
@@ -799,5 +804,103 @@ def FindFileContainingSymbol(self, symbol_name: str): # pylint: disable=invalid
799804
raise LookupError("Could not find file containing symbol: %s" % symbol_name)
800805

801806

807+
class PyCelWithoutProtoSupportTest(absltest.TestCase):
808+
"""Test that the environment can be created without proto support."""
809+
810+
def setUp(self):
811+
super().setUp()
812+
self.msg = test_all_types_pb.TestAllTypes()
813+
self.msg.single_string = "Hey"
814+
815+
# "Unimport" descriptor_pool if it is already imported.
816+
if "google.protobuf.descriptor_pool" in sys.modules:
817+
del sys.modules["google.protobuf.descriptor_pool"]
818+
819+
# Make it impossible to import descriptor_pool.
820+
class UnluckyFinder(importlib.abc.MetaPathFinder):
821+
822+
def find_spec(self, fullname, unused_path, unused_target=None):
823+
if fullname == "google.protobuf.descriptor_pool":
824+
raise ImportError("Not found")
825+
return None
826+
827+
sys.meta_path.insert(0, UnluckyFinder())
828+
829+
def tearDown(self):
830+
# Remove the unlucky finder from the meta path.
831+
sys.meta_path.pop(0)
832+
super().tearDown()
833+
834+
def testEvalWithNonProtoTypes(self):
835+
cel_env = cel.NewEnv(
836+
descriptor_pool=None,
837+
variables={
838+
"var_str": cel.Type.STRING,
839+
"var_map": cel.Type.Map(cel.Type.STRING, cel.Type.STRING),
840+
"var_list": cel.Type.List(cel.Type.STRING),
841+
},
842+
)
843+
data = {
844+
"var_str": "foo",
845+
"var_map": {"key": "bar"},
846+
"var_list": ["foo", "bar", "baz"],
847+
}
848+
res = cel_env.compile("var_str").eval(data=data)
849+
self.assertEqual(res.value(), "foo")
850+
851+
res = cel_env.compile("var_map['key']").eval(data=data)
852+
self.assertEqual(res.value(), "bar")
853+
854+
res = cel_env.compile("var_list[2]").eval(data=data)
855+
self.assertEqual(res.value(), "baz")
856+
857+
def testErrorOnProtoAccess(self):
858+
cel_env = cel.NewEnv(
859+
descriptor_pool=None,
860+
variables={
861+
"var_proto": cel.Type.DYN,
862+
},
863+
)
864+
res = cel_env.compile("var_proto.single_string").eval(
865+
data={"var_proto": self.msg}
866+
)
867+
self.assertEqual(res.type(), cel.Type.ERROR)
868+
self.assertIn(
869+
"Descriptor not found for message type"
870+
" 'cel.expr.conformance.proto2.TestAllTypes'",
871+
str(res.value()),
872+
)
873+
874+
with self.assertRaises(Exception) as e:
875+
cel_env.compile(
876+
"cel.expr.conformance.proto2.TestAllTypes{single_string: 'hello'}"
877+
).eval()
878+
self.assertIn(
879+
"undeclared reference to 'cel.expr.conformance.proto2.TestAllTypes'",
880+
str(e.exception),
881+
)
882+
883+
def testErrorOnProtoCreation(self):
884+
cel_env = cel.NewEnv(
885+
descriptor_pool=None,
886+
variables={
887+
"var_proto": cel.Type.DYN,
888+
},
889+
)
890+
# Disable type checking to allow the compilation to succeed.
891+
expr = cel_env.compile(
892+
"cel.expr.conformance.proto2.TestAllTypes{single_string: 'hello'}",
893+
disable_check=True,
894+
)
895+
896+
with self.assertRaises(Exception) as e:
897+
expr.eval()
898+
self.assertIn(
899+
"Invalid struct creation: missing type info for"
900+
" 'cel.expr.conformance.proto2.TestAllTypes'",
901+
str(e.exception),
902+
)
903+
904+
802905
if __name__ == "__main__":
803906
absltest.main()

py_descriptor_database.cc

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,12 +32,12 @@ PyDescriptorDatabase::PyDescriptorDatabase(PyObject* py_descriptor_pool)
3232
: py_descriptor_pool_(py_descriptor_pool),
3333
standard_pool_(cel::GetMinimalDescriptorPool()) {
3434
ABSL_CHECK(PyGILState_Check());
35-
Py_INCREF(py_descriptor_pool_);
35+
Py_XINCREF(py_descriptor_pool_);
3636
}
3737

3838
PyDescriptorDatabase::~PyDescriptorDatabase() {
3939
auto gil_state = PyGILState_Ensure();
40-
Py_DECREF(py_descriptor_pool_);
40+
Py_XDECREF(py_descriptor_pool_);
4141
PyGILState_Release(gil_state);
4242
}
4343

@@ -52,6 +52,10 @@ bool PyDescriptorDatabase::FindFileByName(StringViewArg filename,
5252
return true;
5353
}
5454

55+
if (py_descriptor_pool_ == nullptr) {
56+
return false;
57+
}
58+
5559
PyObject* pyfile = PyObject_CallMethod(
5660
py_descriptor_pool_, "FindFileByName", "s#", filename.data(),
5761
static_cast<Py_ssize_t>(filename.size()));
@@ -98,6 +102,10 @@ bool PyDescriptorDatabase::FindFileContainingSymbol(
98102
return true;
99103
}
100104

105+
if (py_descriptor_pool_ == nullptr) {
106+
return false;
107+
}
108+
101109
PyObject* pyfile = PyObject_CallMethod(
102110
py_descriptor_pool_, "FindFileContainingSymbol", "s#", symbol_name.data(),
103111
static_cast<Py_ssize_t>(symbol_name.size()));
@@ -137,6 +145,10 @@ bool PyDescriptorDatabase::FindFileContainingSymbol(
137145
bool PyDescriptorDatabase::FindFileContainingExtension(
138146
StringViewArg containing_type, int field_number,
139147
google::protobuf::FileDescriptorProto* output) {
148+
if (py_descriptor_pool_ == nullptr) {
149+
return false;
150+
}
151+
140152
ABSL_CHECK(PyGILState_Check());
141153
PyObject* py_containing_type = PyObject_CallMethod(
142154
py_descriptor_pool_, "FindMessageTypeByName", "s#",

py_message_factory.cc

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ namespace cel_python {
2626

2727
PyMessageFactory::PyMessageFactory(PyObject* descriptor_pool) {
2828
py_descriptor_pool_ = descriptor_pool;
29-
Py_INCREF(py_descriptor_pool_);
29+
Py_XINCREF(py_descriptor_pool_);
3030
PyObject* pName =
3131
PyUnicode_DecodeFSDefault("google.protobuf.message_factory");
3232
PyObject* pModule = PyImport_Import(pName);
@@ -45,7 +45,7 @@ PyMessageFactory::PyMessageFactory(PyObject* descriptor_pool) {
4545

4646
PyMessageFactory::~PyMessageFactory() {
4747
auto gil_state = PyGILState_Ensure();
48-
Py_DECREF(py_descriptor_pool_);
48+
Py_XDECREF(py_descriptor_pool_);
4949
Py_XDECREF(py_func_GetMessageClass_);
5050
Py_XDECREF(py_func_MergeFromString_);
5151
for (auto const& [key, py_obj] : message_classes_) {
@@ -55,6 +55,13 @@ PyMessageFactory::~PyMessageFactory() {
5555
}
5656

5757
PyObject* PyMessageFactory::GetMessageClass(const std::string& message_type) {
58+
if (py_descriptor_pool_ == nullptr) {
59+
PyErr_Format(PyExc_TypeError,
60+
"Message type not found: %s, descriptor pool is unavailable.",
61+
message_type.c_str());
62+
return nullptr;
63+
}
64+
5865
auto it = message_classes_.find(message_type);
5966
if (it != message_classes_.end()) {
6067
return it->second;

0 commit comments

Comments
 (0)