diff --git a/.github/cloud_builder/run_command_on_active_checkout.yaml b/.github/cloud_builder/run_command_on_active_checkout.yaml index 91135a88e..4bb02a915 100644 --- a/.github/cloud_builder/run_command_on_active_checkout.yaml +++ b/.github/cloud_builder/run_command_on_active_checkout.yaml @@ -5,6 +5,10 @@ options: steps: - name: us-central1-docker.pkg.dev/external-snap-ci-github-gigl/gigl-base-images/gigl-builder:51af343c1c298ab465a96ecffd4e50ea6dffacb7.88.1 entrypoint: /bin/bash + env: + # This is used to determine if the test is running on Google Cloud Build. + # See: tests/test_assets/distributed/utils.py + - "IS_GIGL_CLOUD_BUILD=true" args: - -c - | diff --git a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py index 4d66c761b..ff4c66dd7 100644 --- a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -2,12 +2,12 @@ import os import socket import unittest -from typing import Optional +from typing import Optional, Union from unittest import mock import torch import torch.multiprocessing as mp -from torch_geometric.data import Data +from torch_geometric.data import Data, HeteroData from gigl.common import Uri from gigl.common.logger import Logger @@ -25,12 +25,16 @@ COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY, GraphStoreInfo, ) -from gigl.src.common.types.graph_data import EdgeType +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation from gigl.src.mocking.lib.versioning import get_mocked_dataset_artifact_metadata from gigl.src.mocking.mocking_assets.mocked_datasets_for_pipeline_tests import ( CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO, + DBLP_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO, +) +from tests.test_assets.distributed.utils import ( + assert_tensor_equality, + on_google_cloud_build, ) -from tests.test_assets.distributed.utils import assert_tensor_equality logger = Logger() @@ -62,6 +66,7 @@ def _run_client_process( client_rank: int, cluster_info: GraphStoreInfo, mp_sharing_dict: dict[str, torch.Tensor], + node_type: Optional[NodeType], expected_sampler_input: dict[int, list[torch.Tensor]], expected_edge_types: Optional[list[EdgeType]], ) -> None: @@ -100,7 +105,7 @@ def _run_client_process( torch.distributed.barrier() logger.info("Verified that all ranks received the same free ports") - sampler_input = remote_dist_dataset.get_node_ids() + sampler_input = remote_dist_dataset.get_node_ids(node_type=node_type) _assert_sampler_input(cluster_info, sampler_input, expected_sampler_input) # test "simple" case where we don't have mp sharing dict too @@ -108,28 +113,36 @@ def _run_client_process( cluster_info=cluster_info, local_rank=client_rank, mp_sharing_dict=None, - ).get_node_ids() + ).get_node_ids(node_type=node_type) _assert_sampler_input(cluster_info, simple_sampler_input, expected_sampler_input) - # Check that the edge types are correct assert ( remote_dist_dataset.get_edge_types() == expected_edge_types ), f"Expected edge types {expected_edge_types}, got {remote_dist_dataset.get_edge_types()}" torch.distributed.barrier() - + if node_type is not None: + input_nodes: Union[list[torch.Tensor], tuple[NodeType, list[torch.Tensor]]] = ( + node_type, + sampler_input, + ) + else: + input_nodes = sampler_input # Test the DistNeighborLoader loader = DistNeighborLoader( dataset=remote_dist_dataset, num_neighbors=[2, 2], pin_memory_device=torch.device("cpu"), - input_nodes=sampler_input, + input_nodes=input_nodes, num_workers=2, worker_concurrency=2, ) count = 0 for datum in loader: - assert isinstance(datum, Data) + if node_type is not None: + assert isinstance(datum, HeteroData) + else: + assert isinstance(datum, Data) count += 1 torch.distributed.barrier() logger.info(f"Rank {torch.distributed.get_rank()} loaded {count} batches") @@ -148,6 +161,7 @@ def _run_client_process( def _client_process( client_rank: int, cluster_info: GraphStoreInfo, + node_type: Optional[NodeType], expected_sampler_input: dict[int, list[torch.Tensor]], expected_edge_types: Optional[list[EdgeType]], ) -> None: @@ -166,6 +180,7 @@ def _client_process( i, # client_rank cluster_info, # cluster_info mp_sharing_dict, # mp_sharing_dict + node_type, # node_type expected_sampler_input, # expected_sampler_input expected_edge_types, # expected_edge_types ], @@ -239,7 +254,7 @@ def _get_expected_input_nodes_by_rank( class GraphStoreIntegrationTest(unittest.TestCase): - def test_graph_store_locally(self): + def _test_graph_store_homogeneous(self): # Simulating two server machine, two compute machines. # Each machine has one process. cora_supervised_info = get_mocked_dataset_artifact_metadata()[ @@ -295,6 +310,7 @@ def test_graph_store_locally(self): args=[ i, # client_rank cluster_info, # cluster_info + None, # node_type - None for homogeneous dataset expected_sampler_input, # expected_sampler_input None, # expected_edge_types - None for homogeneous dataset ], @@ -332,3 +348,106 @@ def test_graph_store_locally(self): client_process.join() for server_process in server_processes: server_process.join() + + # TODO: (mkolodner-sc) - Figure out why this test is failing on Google Cloud Build + @unittest.skipIf( + on_google_cloud_build(), "Failing on Google Cloud Build - skiping for now" + ) + def test_graph_store_heterogeneous(self): + # Simulating two server machine, two compute machines. + # Each machine has one process. + dblp_supervised_info = get_mocked_dataset_artifact_metadata()[ + DBLP_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO.name + ] + task_config_uri = dblp_supervised_info.frozen_gbml_config_uri + ( + cluster_master_port, + storage_cluster_master_port, + compute_cluster_master_port, + master_port, + rpc_master_port, + rpc_wait_port, + ) = get_free_ports(num_ports=6) + host_ip = socket.gethostbyname(socket.gethostname()) + cluster_info = GraphStoreInfo( + num_storage_nodes=2, + num_compute_nodes=2, + num_processes_per_compute=2, + cluster_master_ip=host_ip, + storage_cluster_master_ip=host_ip, + compute_cluster_master_ip=host_ip, + cluster_master_port=cluster_master_port, + storage_cluster_master_port=storage_cluster_master_port, + compute_cluster_master_port=compute_cluster_master_port, + rpc_master_port=rpc_master_port, + rpc_wait_port=rpc_wait_port, + ) + + num_dblp_nodes = 4057 + expected_sampler_input = _get_expected_input_nodes_by_rank( + num_dblp_nodes, cluster_info + ) + expected_edge_types = [ + EdgeType(NodeType("author"), Relation("to"), NodeType("paper")), + EdgeType(NodeType("paper"), Relation("to"), NodeType("author")), + EdgeType(NodeType("term"), Relation("to"), NodeType("paper")), + ] + ctx = mp.get_context("spawn") + client_processes: list = [] + for i in range(cluster_info.num_compute_nodes): + with mock.patch.dict( + os.environ, + { + "MASTER_ADDR": host_ip, + "MASTER_PORT": str(master_port), + "RANK": str(i), + "WORLD_SIZE": str(cluster_info.compute_cluster_world_size), + COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY: str( + cluster_info.num_processes_per_compute + ), + }, + clear=False, + ): + client_process = ctx.Process( + target=_client_process, + args=[ + i, # client_rank + cluster_info, # cluster_info + NodeType("author"), # node_type + expected_sampler_input, # expected_sampler_input + expected_edge_types, # expected_edge_types + ], + ) + client_process.start() + client_processes.append(client_process) + # Start server process + server_processes = [] + for i in range(cluster_info.num_storage_nodes): + with mock.patch.dict( + os.environ, + { + "MASTER_ADDR": host_ip, + "MASTER_PORT": str(master_port), + "RANK": str(i + cluster_info.num_compute_nodes), + "WORLD_SIZE": str(cluster_info.compute_cluster_world_size), + COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY: str( + cluster_info.num_processes_per_compute + ), + }, + clear=False, + ): + server_process = ctx.Process( + target=_run_server_processes, + args=[ + cluster_info, # cluster_info + task_config_uri, # task_config_uri + True, # is_inference + ], + ) + server_process.start() + server_processes.append(server_process) + + for client_process in client_processes: + client_process.join() + for server_process in server_processes: + server_process.join() diff --git a/python/tests/test_assets/distributed/utils.py b/python/tests/test_assets/distributed/utils.py index 15a6f3a95..a7f56b338 100644 --- a/python/tests/test_assets/distributed/utils.py +++ b/python/tests/test_assets/distributed/utils.py @@ -1,3 +1,4 @@ +import os from typing import Callable, Optional import torch @@ -67,3 +68,10 @@ def create_test_process_group() -> None: world_size=1, init_method=get_process_group_init_method(), ) + + +def on_google_cloud_build() -> bool: + """ + Returns True if the test is running on Google Cloud Build. + """ + return os.environ.get("IS_GIGL_CLOUD_BUILD", "false").lower() == "true" diff --git a/python/tests/unit/distributed/dist_ablp_neighborloader_test.py b/python/tests/unit/distributed/dist_ablp_neighborloader_test.py index 791dee288..e6e24aa2b 100644 --- a/python/tests/unit/distributed/dist_ablp_neighborloader_test.py +++ b/python/tests/unit/distributed/dist_ablp_neighborloader_test.py @@ -38,6 +38,7 @@ from tests.test_assets.distributed.utils import ( assert_tensor_equality, create_test_process_group, + on_google_cloud_build, ) _POSITIVE_EDGE_TYPE = message_passing_to_positive_label(DEFAULT_HOMOGENEOUS_EDGE_TYPE) @@ -565,7 +566,9 @@ def test_cora_supervised(self): ) # TODO: (mkolodner-sc) - Figure out why this test is failing on Google Cloud Build - @unittest.skip("Failing on Google Cloud Build - skiping for now") + @unittest.skipIf( + on_google_cloud_build(), "Failing on Google Cloud Build - skiping for now" + ) def test_dblp_supervised(self): create_test_process_group() dblp_supervised_info = get_mocked_dataset_artifact_metadata()[ diff --git a/python/tests/unit/distributed/distributed_neighborloader_test.py b/python/tests/unit/distributed/distributed_neighborloader_test.py index 74ecd13eb..2e495bd5c 100644 --- a/python/tests/unit/distributed/distributed_neighborloader_test.py +++ b/python/tests/unit/distributed/distributed_neighborloader_test.py @@ -39,6 +39,7 @@ from tests.test_assets.distributed.utils import ( assert_tensor_equality, create_test_process_group, + on_google_cloud_build, ) _POSITIVE_EDGE_TYPE = message_passing_to_positive_label(DEFAULT_HOMOGENEOUS_EDGE_TYPE) @@ -349,7 +350,9 @@ def test_infinite_distributed_neighbor_loader(self): ) # TODO: (svij) - Figure out why this test is failing on Google Cloud Build - @unittest.skip("Failing on Google Cloud Build - skiping for now") + @unittest.skipIf( + on_google_cloud_build(), "Failing on Google Cloud Build - skiping for now" + ) def test_distributed_neighbor_loader_heterogeneous(self): expected_data_count = 4057