Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions gigl/common/utils/proto_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from tempfile import NamedTemporaryFile
from typing import Optional, Type, TypeVar
from typing import Optional, Type, TypeVar, cast

import yaml
from google.protobuf import message
Expand Down Expand Up @@ -28,7 +28,12 @@ def read_proto_from_yaml(self, uri: Uri, proto_cls: Type[T]) -> T:
omega_conf_obj = OmegaConf.create(raw_data)
tfh.close()
obj_dict = OmegaConf.to_object(omega_conf_obj)
proto = ParseDict(js_dict=obj_dict, message=proto_cls())
if not isinstance(obj_dict, dict):
raise TypeError(
f"ProtoUtils.read_proto_from_yaml expected a mapping at the YAML root for "
f"{uri}, got {type(obj_dict).__name__}."
)
proto = ParseDict(js_dict=cast(dict, obj_dict), message=proto_cls())
return proto

def read_proto_from_binary(self, uri: Uri, proto_cls: Type[T]) -> T:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import cast
from typing import Any, cast

import torch
from google.protobuf.json_format import ParseDict
Expand Down Expand Up @@ -70,8 +70,15 @@ def from_omegaconf(config: DictConfig) -> HeterogeneousGraphSparseEmbeddingConfi
assert graph_metadata is not None, "Graph metadata is required in the config."

graph_metadata_dict = OmegaConf.to_container(graph_metadata, resolve=True)
if not isinstance(graph_metadata_dict, dict):
raise TypeError(
f"HeterogeneousGraphSparseEmbeddingConfig.from_omegaconf expected "
f"dataset.metadata to resolve to a mapping, got "
f"{type(graph_metadata_dict).__name__}."
)
pb = ParseDict(
js_dict=graph_metadata_dict, message=graph_schema_pb2.GraphMetadata()
js_dict=cast(dict[str, Any], graph_metadata_dict),
message=graph_schema_pb2.GraphMetadata(),
)
Comment thread
kmontemayor2-sc marked this conversation as resolved.
graph_metadata = GraphMetadataPbWrapper(graph_metadata_pb=pb)

Expand Down
14 changes: 14 additions & 0 deletions tests/unit/common/utils/proto_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,20 @@ def test_can_read_gbml_config_from_yaml(self):
f"{expected_positive_label_date_range_start}:{expected_positive_label_date_range_end}",
)

def test_read_proto_from_yaml_raises_typeerror_when_root_is_not_a_mapping(self):
list_yaml = "- a\n- b\n- c\n"
tmp_file = NamedTemporaryFile(delete=False)
tmp_file.write(list_yaml.encode())
tmp_file.close()
try:
with self.assertRaises(TypeError):
self.proto_utils.read_proto_from_yaml(
uri=LocalUri(tmp_file.name),
proto_cls=gbml_config_pb2.GbmlConfig,
)
finally:
os.remove(tmp_file.name)


if __name__ == "__main__":
absltest.main()