diff --git a/gigl/common/utils/proto_utils.py b/gigl/common/utils/proto_utils.py index 56b78b312..e61258756 100644 --- a/gigl/common/utils/proto_utils.py +++ b/gigl/common/utils/proto_utils.py @@ -1,5 +1,5 @@ from tempfile import NamedTemporaryFile -from typing import Optional, Type, TypeVar +from typing import Optional, Type, TypeVar, cast import yaml from google.protobuf import message @@ -28,7 +28,12 @@ def read_proto_from_yaml(self, uri: Uri, proto_cls: Type[T]) -> T: omega_conf_obj = OmegaConf.create(raw_data) tfh.close() obj_dict = OmegaConf.to_object(omega_conf_obj) - proto = ParseDict(js_dict=obj_dict, message=proto_cls()) + if not isinstance(obj_dict, dict): + raise TypeError( + f"ProtoUtils.read_proto_from_yaml expected a mapping at the YAML root for " + f"{uri}, got {type(obj_dict).__name__}." + ) + proto = ParseDict(js_dict=cast(dict, obj_dict), message=proto_cls()) return proto def read_proto_from_binary(self, uri: Uri, proto_cls: Type[T]) -> T: diff --git a/gigl/experimental/knowledge_graph_embedding/lib/config/__init__.py b/gigl/experimental/knowledge_graph_embedding/lib/config/__init__.py index 094f1375b..9caf7fe38 100644 --- a/gigl/experimental/knowledge_graph_embedding/lib/config/__init__.py +++ b/gigl/experimental/knowledge_graph_embedding/lib/config/__init__.py @@ -1,7 +1,7 @@ from __future__ import annotations from dataclasses import dataclass, field -from typing import cast +from typing import Any, cast import torch from google.protobuf.json_format import ParseDict @@ -70,8 +70,15 @@ def from_omegaconf(config: DictConfig) -> HeterogeneousGraphSparseEmbeddingConfi assert graph_metadata is not None, "Graph metadata is required in the config." graph_metadata_dict = OmegaConf.to_container(graph_metadata, resolve=True) + if not isinstance(graph_metadata_dict, dict): + raise TypeError( + f"HeterogeneousGraphSparseEmbeddingConfig.from_omegaconf expected " + f"dataset.metadata to resolve to a mapping, got " + f"{type(graph_metadata_dict).__name__}." + ) pb = ParseDict( - js_dict=graph_metadata_dict, message=graph_schema_pb2.GraphMetadata() + js_dict=cast(dict[str, Any], graph_metadata_dict), + message=graph_schema_pb2.GraphMetadata(), ) graph_metadata = GraphMetadataPbWrapper(graph_metadata_pb=pb) diff --git a/tests/unit/common/utils/proto_utils_test.py b/tests/unit/common/utils/proto_utils_test.py index 1e8dd54a9..793e0e7a6 100644 --- a/tests/unit/common/utils/proto_utils_test.py +++ b/tests/unit/common/utils/proto_utils_test.py @@ -68,6 +68,20 @@ def test_can_read_gbml_config_from_yaml(self): f"{expected_positive_label_date_range_start}:{expected_positive_label_date_range_end}", ) + def test_read_proto_from_yaml_raises_typeerror_when_root_is_not_a_mapping(self): + list_yaml = "- a\n- b\n- c\n" + tmp_file = NamedTemporaryFile(delete=False) + tmp_file.write(list_yaml.encode()) + tmp_file.close() + try: + with self.assertRaises(TypeError): + self.proto_utils.read_proto_from_yaml( + uri=LocalUri(tmp_file.name), + proto_cls=gbml_config_pb2.GbmlConfig, + ) + finally: + os.remove(tmp_file.name) + if __name__ == "__main__": absltest.main()