Skip to content

Commit 2a91a16

Browse files
dmitriplotnikovcopybara-github
authored andcommitted
Make protobuf dependency optional. Part II
PiperOrigin-RevId: 865503087
1 parent 95915c5 commit 2a91a16

2 files changed

Lines changed: 17 additions & 4 deletions

File tree

py_cel_test.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -812,15 +812,16 @@ def setUp(self):
812812
self.msg = test_all_types_pb.TestAllTypes()
813813
self.msg.single_string = "Hey"
814814

815-
# "Unimport" descriptor_pool if it is already imported.
816-
if "google.protobuf.descriptor_pool" in sys.modules:
817-
del sys.modules["google.protobuf.descriptor_pool"]
815+
# "Unimport" any google.protobuf modules if they are already imported.
816+
for module_name in list(sys.modules):
817+
if module_name.startswith("google.protobuf"):
818+
del sys.modules[module_name]
818819

819820
# Make it impossible to import descriptor_pool.
820821
class UnluckyFinder(importlib.abc.MetaPathFinder):
821822

822823
def find_spec(self, fullname, unused_path, unused_target=None):
823-
if fullname == "google.protobuf.descriptor_pool":
824+
if fullname.startswith("google.protobuf."):
824825
raise ImportError("Not found")
825826
return None
826827

py_message_factory.cc

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,10 @@ namespace cel_python {
2626

2727
PyMessageFactory::PyMessageFactory(PyObject* descriptor_pool) {
2828
py_descriptor_pool_ = descriptor_pool;
29+
if (py_descriptor_pool_ == nullptr) {
30+
return;
31+
}
32+
2933
Py_XINCREF(py_descriptor_pool_);
3034
PyObject* pName =
3135
PyUnicode_DecodeFSDefault("google.protobuf.message_factory");
@@ -44,6 +48,10 @@ PyMessageFactory::PyMessageFactory(PyObject* descriptor_pool) {
4448
}
4549

4650
PyMessageFactory::~PyMessageFactory() {
51+
if (py_descriptor_pool_ == nullptr) {
52+
return;
53+
}
54+
4755
auto gil_state = PyGILState_Ensure();
4856
Py_XDECREF(py_descriptor_pool_);
4957
Py_XDECREF(py_func_GetMessageClass_);
@@ -91,6 +99,10 @@ PyObject* PyMessageFactory::GetMessageClass(const std::string& message_type) {
9199

92100
PyObject* PyMessageFactory::FromString(const std::string& message_type,
93101
const std::string& serialized_proto) {
102+
if (py_descriptor_pool_ == nullptr) {
103+
return nullptr;
104+
}
105+
94106
ABSL_CHECK(PyGILState_Check());
95107
PyObject* message_class = GetMessageClass(message_type);
96108
if (!message_class) {

0 commit comments

Comments
 (0)