diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index 3be9662f0..1125f1f9f 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -3,7 +3,11 @@ import torch from graphlearn_torch.channel import SampleMessage -from graphlearn_torch.distributed import DistLoader, MpDistSamplingWorkerOptions +from graphlearn_torch.distributed import ( + DistLoader, + MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, +) from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType @@ -13,7 +17,10 @@ from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( + DatasetSchema, + SamplingClusterSetup, labeled_to_homogeneous, patch_fanout_for_sampling, set_missing_features, @@ -37,10 +44,15 @@ class DistNeighborLoader(DistLoader): def __init__( self, - dataset: DistDataset, + dataset: Union[DistDataset, RemoteDistDataset], num_neighbors: Union[list[int], dict[EdgeType, list[int]]], input_nodes: Optional[ - Union[torch.Tensor, Tuple[NodeType, torch.Tensor]] + Union[ + torch.Tensor, + Tuple[NodeType, torch.Tensor], + list[torch.Tensor], + Tuple[NodeType, list[torch.Tensor]], + ] ] = None, num_workers: int = 1, batch_size: int = 1, @@ -62,7 +74,8 @@ def __init__( https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/distributed/dist_neighbor_loader.html#DistNeighborLoader Args: - dataset (DistDataset): The dataset to sample from. + dataset (DistDataset | RemoteDistDataset): The dataset to sample from. + If this is a `RemoteDistDataset`, then we assumed to be in "Graph Store" mode. num_neighbors (list[int] or dict[Tuple[str, str, str], list[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. @@ -71,12 +84,17 @@ def __init__( context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon)(int): Required if context provided. The total number of processes within a node. - input_nodes (torch.Tensor or Tuple[str, torch.Tensor]): The - indices of seed nodes to start sampling from. + input_nodes (Tensor | Tuple[NodeType, Tensor] | list[Tensor] | Tuple[NodeType, list[Tensor]]): + The nodes to start sampling from. It is of type `torch.LongTensor` for homogeneous graphs. If set to `None` for homogeneous settings, all nodes will be considered. In heterogeneous graphs, this flag must be passed in as a tuple that holds the node type and node indices. (default: `None`) + For Graph Store mode, this must be a tuple of (NodeType, list[Tenor]) or list[Tenor]. + Where each Tensor in the list is the node ids to sample from, for each server. + e.g. [[10, 20], [30, 40]] means sample from nodes 10 and 20 on server 0, and nodes 30 and 40 on server 1. + If a Graph Store input (e.g. list[Tensor]) is provided to colocated mode, or colocated input (e.g. 
Tensor) is provided to Graph Store mode, + then an error will be raised. num_workers (int): How many workers to use (subprocesses to spwan) for distributed neighbor sampling of the current process. (default: ``1``). batch_size (int, optional): how many samples per batch to load @@ -180,7 +198,11 @@ def __init__( local_process_rank, local_process_world_size, ) # delete deprecated vars so we don't accidentally use them. - + if isinstance(dataset, RemoteDistDataset): + self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE + else: + self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED + logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}") device = ( pin_memory_device if pin_memory_device @@ -188,10 +210,256 @@ def __init__( local_process_rank=local_rank ) ) + + # Determines if the node ids passed in are heterogeneous or homogeneous. + if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: + assert isinstance( + dataset, DistDataset + ), "When using colocated mode, dataset must be a DistDataset." + input_data, worker_options, dataset_metadata = self._setup_for_colocated( + input_nodes, + dataset, + local_rank, + local_world_size, + device, + master_ip_address, + node_rank, + node_world_size, + process_start_gap_seconds, + num_workers, + worker_concurrency, + channel_size, + num_cpu_threads, + ) + else: # Graph Store mode + assert isinstance( + dataset, RemoteDistDataset + ), "When using Graph Store mode, dataset must be a RemoteDistDataset." + input_data, worker_options, dataset_metadata = self._setup_for_graph_store( + input_nodes, + dataset, + num_workers, + ) + + self._is_labeled_heterogeneous = dataset_metadata.is_labeled_heterogeneous + self._node_feature_info = dataset_metadata.node_feature_info + self._edge_feature_info = dataset_metadata.edge_feature_info + + logger.info(f"num_neighbors before patch: {num_neighbors}") + num_neighbors = patch_fanout_for_sampling( + edge_types=dataset_metadata.edge_types, + num_neighbors=num_neighbors, + ) + logger.info( + f"num_neighbors: {num_neighbors}, edge_types: {dataset_metadata.edge_types}" + ) + sampling_config = SamplingConfig( + sampling_type=SamplingType.NODE, + num_neighbors=num_neighbors, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + with_edge=True, + collect_features=True, + with_neg=False, + with_weight=False, + edge_dir=dataset_metadata.edge_dir, + seed=None, # it's actually optional - None means random. + ) + + if should_cleanup_distributed_context and torch.distributed.is_initialized(): + logger.info( + f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." + ) + torch.distributed.destroy_process_group() + + if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: + super().__init__( + dataset, # Pass in the dataset for colocated mode. + input_data, + sampling_config, + device, + worker_options, + ) + else: + # For Graph Store mode, we need to start the communcation between compute and storage nodes sequentially, by compute node. + # E.g. intialize connections between compute node 0 and storage nodes 0, 1, 2, 3, then compute node 1 and storage nodes 0, 1, 2, 3, etc. + # Note that each compute node may have multiple connections to each storage node, once per compute process. + # It's important to distinguish "compute node" (e.g. physical compute machine) from "compute process" (e.g. process running on the compute node). 
+ # Since in practice we have multiple compute processes per compute node, and each compute process needs to initialize the connection to the storage nodes. + # E.g. if there are 4 gpus per compute node, then there will be 4 connections from each compute node to each storage node. + # We need to this because if we don't, then there is a race condition when initalizing the samplers on the storage nodes [1] + # Where since the lock is per *server* (e.g. per storage node), if we try to start one connection from compute node 0, and compute node 1 + # Then we deadlock and fail. + # Specifically, the race condition happens in `DistLoader.__init__` when it initializes the sampling producers on the storage nodes. [2] + # [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L129-L167 + # [2]: https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_loader.py#L187-L193 + + # See below for a connection setup. + # ╔═══════════════════════════════════════════════════════════════════════════════════════╗ + # ║ COMPUTE TO STORAGE NODE CONNECTIONS ║ + # ╚═══════════════════════════════════════════════════════════════════════════════════════╝ + + # COMPUTE NODES STORAGE NODES + # ═════════════ ═════════════ + + # ┌──────────────────────┐ (1) ┌───────────────┐ + # │ COMPUTE NODE 0 │ │ │ + # │ ┌────┬────┬────┬────┤ ══════════════════════════════════│ STORAGE 0 │ + # │ │GPU │GPU │GPU │GPU │ ╱ │ │ + # │ │ 0 │ 1 │ 2 │ 3 │ ════════════════════╲ ╱ └───────────────┘ + # │ └────┴────┴────┴────┤ (2) ╲ ╱ + # └──────────────────────┘ ╲ ╱ + # ╳ + # (3) ╱ ╲ (4) + # ┌──────────────────────┐ ╱ ╲ ┌───────────────┐ + # │ COMPUTE NODE 1 │ ╱ ╲ │ │ + # │ ┌────┬────┬────┬────┤ ═════════════════╱ ═│ STORAGE 1 │ + # │ │GPU │GPU │GPU │GPU │ │ │ + # │ │ 0 │ 1 │ 2 │ 3 │ ══════════════════════════════════│ │ + # │ └────┴────┴────┴────┤ └───────────────┘ + # └──────────────────────┘ + + # ┌─────────────────────────────────────────────────────────────────────────────┐ + # │ (1) Compute Node 0 → Storage 0 (4 connections, one per GPU) │ + # │ (2) Compute Node 0 → Storage 1 (4 connections, one per GPU) │ + # │ (3) Compute Node 1 → Storage 0 (4 connections, one per GPU) │ + # │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ + # └─────────────────────────────────────────────────────────────────────────────┘ + node_rank = dataset.cluster_info.compute_node_rank + for target_node_rank in range(dataset.cluster_info.num_compute_nodes): + if node_rank == target_node_rank: + super().__init__( + None, # Pass in None for Graph Store mode. 
+ input_data, + sampling_config, + device, + worker_options, + ) + print(f"node_rank {node_rank} initialized the dist loader") + torch.distributed.barrier() + torch.distributed.barrier() + + def _setup_for_graph_store( + self, + input_nodes: Optional[ + Union[ + torch.Tensor, + Tuple[NodeType, torch.Tensor], + list[torch.Tensor], + Tuple[NodeType, list[torch.Tensor]], + ] + ], + dataset: RemoteDistDataset, + num_workers: int, + ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, DatasetSchema]: + if input_nodes is None: + raise ValueError( + f"When using Graph Store mode, input nodes must be provided, received {input_nodes}" + ) + elif isinstance(input_nodes, torch.Tensor): + raise ValueError( + f"When using Graph Store mode, input nodes must be of type (list[Tensor] | (NodeType, list[torch.Tensor]), received {type(input_nodes)}" + ) + elif isinstance(input_nodes, tuple) and isinstance( + input_nodes[1], torch.Tensor + ): + raise ValueError( + f"When using Graph Store mode, input nodes must be of type (list[torch.Tensor] | (NodeType, list[torch.Tensor])), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})" + ) + + is_labeled_heterogeneous = False + node_feature_info = dataset.get_node_feature_info() + edge_feature_info = dataset.get_edge_feature_info() + edge_types = dataset.get_edge_types() + node_rank = dataset.cluster_info.compute_node_rank + + # Get sampling ports for compute-storage connections. + sampling_ports = dataset.get_free_ports_on_storage_cluster( + num_ports=dataset.cluster_info.num_processes_per_compute + ) + sampling_port = sampling_ports[node_rank] + + worker_options = RemoteDistSamplingWorkerOptions( + server_rank=list(range(dataset.cluster_info.num_storage_nodes)), + num_workers=num_workers, + worker_devices=[torch.device("cpu") for i in range(num_workers)], + master_addr=dataset.cluster_info.storage_cluster_master_ip, + master_port=sampling_port, + worker_key=f"compute_rank_{node_rank}", + ) logger.info( - f"Dataset Building started on {node_rank} of {node_world_size} nodes, using following node as main: {master_ip_address}" + f"Rank {torch.distributed.get_rank()}! init for sampling rpc: {f'tcp://{dataset.cluster_info.storage_cluster_master_ip}:{sampling_port}'}" ) + # Setup input data for the dataloader. + + # Determine nodes list and fallback input_type based on input_nodes structure + if isinstance(input_nodes, list): + nodes = input_nodes + fallback_input_type = None + require_edge_feature_info = False + elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], list): + nodes = input_nodes[1] + fallback_input_type = input_nodes[0] + require_edge_feature_info = True + else: + raise ValueError( + f"When using Graph Store mode, input nodes must be of type (list[torch.Tensor] | (NodeType, list[torch.Tensor])), received {type(input_nodes)}" + ) + + # Determine input_type based on edge_feature_info + if isinstance(edge_types, list): + if edge_types == [DEFAULT_HOMOGENEOUS_EDGE_TYPE]: + input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE + else: + input_type = fallback_input_type + elif require_edge_feature_info: + raise ValueError( + "When using Graph Store mode, edge types must be provided for heterogeneous graphs." 
+ ) + else: + input_type = None + + input_data = [ + NodeSamplerInput(node=node, input_type=input_type) for node in nodes + ] + + return ( + input_data, + worker_options, + DatasetSchema( + is_labeled_heterogeneous=is_labeled_heterogeneous, + edge_types=edge_types, + node_feature_info=node_feature_info, + edge_feature_info=edge_feature_info, + edge_dir=dataset.get_edge_dir(), + ), + ) + + def _setup_for_colocated( + self, + input_nodes: Optional[ + Union[ + torch.Tensor, + Tuple[NodeType, torch.Tensor], + list[torch.Tensor], + Tuple[NodeType, list[torch.Tensor]], + ] + ], + dataset: DistDataset, + local_rank: int, + local_world_size: int, + device: torch.device, + master_ip_address: str, + node_rank: int, + node_world_size: int, + process_start_gap_seconds: float, + num_workers: int, + worker_concurrency: int, + channel_size: str, + num_cpu_threads: Optional[int], + ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema]: if input_nodes is None: if dataset.node_ids is None: raise ValueError( @@ -202,9 +470,15 @@ def __init__( f"input_nodes must be provided for heterogeneous datasets, received node_ids of type: {dataset.node_ids.keys()}" ) input_nodes = dataset.node_ids - - # Determines if the node ids passed in are heterogeneous or homogeneous. - self._is_labeled_heterogeneous = False + if isinstance(input_nodes, list): + raise ValueError( + f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)}" + ) + elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], list): + raise ValueError( + f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})" + ) + is_labeled_heterogeneous = False if isinstance(input_nodes, torch.Tensor): node_ids = input_nodes @@ -216,7 +490,7 @@ def __init__( and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids ): node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE - self._is_labeled_heterogeneous = True + is_labeled_heterogeneous = True else: raise ValueError( f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}" @@ -229,19 +503,12 @@ def __init__( dataset.node_ids, abc.Mapping ), "Dataset must be heterogeneous if provided input nodes are a tuple." - num_neighbors = patch_fanout_for_sampling( - dataset.get_edge_types(), num_neighbors - ) - curr_process_nodes = shard_nodes_by_process( input_nodes=node_ids, local_process_rank=local_rank, local_process_world_size=local_world_size, ) - self._node_feature_info = dataset.node_feature_info - self._edge_feature_info = dataset.edge_feature_info - input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type) # Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize @@ -305,28 +572,23 @@ def __init__( pin_memory=device.type == "cuda", ) - sampling_config = SamplingConfig( - sampling_type=SamplingType.NODE, - num_neighbors=num_neighbors, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - with_edge=True, - collect_features=True, - with_neg=False, - with_weight=False, - edge_dir=dataset.edge_dir, - seed=None, # it's actually optional - None means random. 
+ if isinstance(dataset.graph, dict): + edge_types = list(dataset.graph.keys()) + else: + edge_types = None + + return ( + input_data, + worker_options, + DatasetSchema( + is_labeled_heterogeneous=is_labeled_heterogeneous, + edge_types=edge_types, + node_feature_info=dataset.node_feature_info, + edge_feature_info=dataset.edge_feature_info, + edge_dir=dataset.edge_dir, + ), ) - if should_cleanup_distributed_context and torch.distributed.is_initialized(): - logger.info( - f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." - ) - torch.distributed.destroy_process_group() - - super().__init__(dataset, input_data, sampling_config, device, worker_options) - def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: data = super()._collate_fn(msg) data = set_missing_features( diff --git a/python/gigl/distributed/graph_store/compute.py b/python/gigl/distributed/graph_store/compute.py index 36a3b66dd..cd556d941 100644 --- a/python/gigl/distributed/graph_store/compute.py +++ b/python/gigl/distributed/graph_store/compute.py @@ -36,30 +36,34 @@ def init_compute_process( cluster_info.compute_node_rank * cluster_info.num_processes_per_compute + local_rank ) + cluster_master_ip = cluster_info.storage_cluster_master_ip logger.info( - f"Initializing compute process group {compute_cluster_rank} / {cluster_info.compute_cluster_world_size}. on {cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port} with backend {compute_world_backend}." - f" OS rank: {os.environ['RANK']}, local client rank: {local_rank}" - ) - torch.distributed.init_process_group( - backend=compute_world_backend, - world_size=cluster_info.compute_cluster_world_size, - rank=compute_cluster_rank, - init_method=f"tcp://{cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}", - ) - logger.info( - f"Initializing RPC client for compute node {compute_cluster_rank} / {cluster_info.compute_cluster_world_size} on {cluster_info.cluster_master_ip}:{cluster_info.cluster_master_port}." + f"Initializing RPC client for compute node {compute_cluster_rank} / {cluster_info.compute_cluster_world_size} on {cluster_master_ip}:{cluster_info.rpc_master_port}." f" OS rank: {os.environ['RANK']}, local compute rank: {local_rank}" f" num_servers: {cluster_info.num_storage_nodes}, num_clients: {cluster_info.compute_cluster_world_size}" ) + # Initialize the GLT client before starting the Torch Distributed process group. + # Otherwise, we saw intermittent hangs when initializing the client. glt.distributed.init_client( num_servers=cluster_info.num_storage_nodes, num_clients=cluster_info.compute_cluster_world_size, client_rank=compute_cluster_rank, - master_addr=cluster_info.cluster_master_ip, - master_port=cluster_info.cluster_master_port, + master_addr=cluster_master_ip, + master_port=cluster_info.rpc_master_port, client_group_name="gigl_client_rpc", ) + logger.info( + f"Initializing compute process group {compute_cluster_rank} / {cluster_info.compute_cluster_world_size}. on {cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port} with backend {compute_world_backend}." 
+ f" OS rank: {os.environ['RANK']}, local client rank: {local_rank}" + ) + torch.distributed.init_process_group( + backend=compute_world_backend, + world_size=cluster_info.compute_cluster_world_size, + rank=compute_cluster_rank, + init_method=f"tcp://{cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}", + ) + def shutdown_compute_proccess() -> None: """ diff --git a/python/gigl/distributed/graph_store/storage_main.py b/python/gigl/distributed/graph_store/storage_main.py index 0cfdef957..3dc8040a5 100644 --- a/python/gigl/distributed/graph_store/storage_main.py +++ b/python/gigl/distributed/graph_store/storage_main.py @@ -30,24 +30,32 @@ def _run_storage_process( storage_world_backend: Optional[str], ) -> None: register_dataset(dataset) + cluster_master_ip = cluster_info.storage_cluster_master_ip logger.info( - f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {cluster_info.cluster_master_ip}:{torch_process_port}" - ) - torch.distributed.init_process_group( - backend=storage_world_backend, - world_size=cluster_info.num_storage_nodes, - rank=storage_rank, - init_method=f"tcp://{cluster_info.cluster_master_ip}:{torch_process_port}", + f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}" ) + # Initialize the GLT server before starting the Torch Distributed process group. + # Otherwise, we saw intermittent hangs when initializing the server. glt.distributed.init_server( num_servers=cluster_info.num_storage_nodes, server_rank=storage_rank, dataset=dataset, - master_addr=cluster_info.cluster_master_ip, - master_port=cluster_info.cluster_master_port, + master_addr=cluster_master_ip, + master_port=cluster_info.rpc_master_port, num_clients=cluster_info.compute_cluster_world_size, ) + init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{torch_process_port}" + logger.info( + f"Initializing storage node process group {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {init_method}" + ) + torch.distributed.init_process_group( + backend=storage_world_backend, + world_size=cluster_info.num_storage_nodes, + rank=storage_rank, + init_method=init_method, + ) + logger.info( f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit" ) @@ -59,7 +67,7 @@ def storage_node_process( storage_rank: int, cluster_info: GraphStoreInfo, task_config_uri: Uri, - is_inference: bool, + is_inference: bool = True, tf_record_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$", storage_world_backend: Optional[str] = None, ) -> None: @@ -71,7 +79,7 @@ def storage_node_process( storage_rank (int): The rank of the storage node. cluster_info (GraphStoreInfo): The cluster information. task_config_uri (Uri): The task config URI. - is_inference (bool): Whether the process is an inference process. + is_inference (bool): Whether the process is an inference process. Defaults to True. tf_record_uri_pattern (str): The TF Record URI pattern. storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group. 
""" @@ -95,6 +103,7 @@ def storage_node_process( _tfrecord_uri_pattern=tf_record_uri_pattern, ) torch_process_port = get_free_ports_from_master_node(num_ports=1)[0] + torch.distributed.destroy_process_group() server_processes = [] mp_context = torch.multiprocessing.get_context("spawn") # TODO(kmonte): Enable more than one server process per machine @@ -120,18 +129,21 @@ def storage_node_process( parser = argparse.ArgumentParser() parser.add_argument("--task_config_uri", type=str, required=True) parser.add_argument("--resource_config_uri", type=str, required=True) - parser.add_argument("--is_inference", action="store_true") + parser.add_argument("--job_name", type=str, required=True) args = parser.parse_args() logger.info(f"Running storage node with arguments: {args}") is_inference = args.is_inference - torch.distributed.init_process_group() + torch.distributed.init_process_group(backend="gloo") cluster_info = get_graph_store_info() + logger.info(f"Cluster info: {cluster_info}") + logger.info( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) # Tear down the """"global""" process group so we can have a server-specific process group. torch.distributed.destroy_process_group() storage_node_process( storage_rank=cluster_info.storage_node_rank, cluster_info=cluster_info, task_config_uri=UriFactory.create_uri(args.task_config_uri), - is_inference=is_inference, ) diff --git a/python/gigl/distributed/utils/neighborloader.py b/python/gigl/distributed/utils/neighborloader.py index 29a09d71a..54b687c98 100644 --- a/python/gigl/distributed/utils/neighborloader.py +++ b/python/gigl/distributed/utils/neighborloader.py @@ -1,7 +1,9 @@ """Utils for Neighbor loaders.""" from collections import abc from copy import deepcopy -from typing import Optional, TypeVar, Union +from dataclasses import dataclass +from enum import Enum +from typing import Literal, Optional, TypeVar, Union import torch from torch_geometric.data import Data, HeteroData @@ -15,6 +17,33 @@ _GraphType = TypeVar("_GraphType", Data, HeteroData) +class SamplingClusterSetup(Enum): + """ + The setup of the sampling cluster. + """ + + COLOCATED = "colocated" + GRAPH_STORE = "graph_store" + + +@dataclass(frozen=True) +class DatasetSchema: + """ + Shared metadata between the local and remote datasets. + """ + + # If the dataset is labeled heterogeneous. E.g. one node type, one edge type, and "label" edges. + is_labeled_heterogeneous: bool + # List of all edge types in the graph. + edge_types: Optional[list[EdgeType]] + # Node feature info. + node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] + # Edge feature info. + edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]] + # Edge direction. + edge_dir: Union[str, Literal["in", "out"]] + + def patch_fanout_for_sampling( edge_types: Optional[list[EdgeType]], num_neighbors: Union[list[int], dict[EdgeType, list[int]]], diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index 7d2ba46b9..e2e1f320e 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -115,6 +115,13 @@ def get_internal_ip_from_master_node( ) -> str: """ Get the internal IP address of the master node in a distributed setup. + + Args: + _global_rank_override (Optional[int]): Override for the global rank, + useful for testing or if global rank is not accurately available. 
+ + Returns: + str: The internal IP address of the master node. """ return get_internal_ip_from_node( node_rank=0, _global_rank_override=_global_rank_override @@ -131,6 +138,12 @@ def get_internal_ip_from_node( i.e. when using :py:obj:`gigl.distributed.dataset_factory` + Args: + node_rank (int): Rank of the node, to fetch the internal IP address of. + device (Optional[torch.device]): Device to use for communication. Defaults to None, which will use the default device. + _global_rank_override (Optional[int]): Override for the global rank, + useful for testing or if global rank is not accurately available. + Returns: str: The internal IP address of the node. """ @@ -158,7 +171,9 @@ def get_internal_ip_from_node( device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.distributed.broadcast_object_list(ip_list, src=node_rank, device=device) node_ip = ip_list[0] - logger.info(f"Rank {rank} received master node's internal IP: {node_ip}") + logger.info( + f"Rank {rank} received master node's internal IP: {node_ip} on device {device}" + ) assert node_ip is not None, "Could not retrieve master node's internal IP" return node_ip @@ -230,12 +245,21 @@ def get_graph_store_info() -> GraphStoreInfo: compute_cluster_master_ip = cluster_master_ip storage_cluster_master_ip = get_internal_ip_from_node(node_rank=num_compute_nodes) - cluster_master_port, compute_cluster_master_port = get_free_ports_from_node( - num_ports=2, node_rank=0 - ) - storage_cluster_master_port = get_free_ports_from_node( - num_ports=1, node_rank=num_compute_nodes - )[0] + # Cluster master is by convention rank 0. + cluster_master_rank = 0 + ( + cluster_master_port, + compute_cluster_master_port, + ) = get_free_ports_from_node(num_ports=2, node_rank=cluster_master_rank) + + # Since we structure the cluster as [compute0, ..., computeN, storage0, ..., storageN], the storage master is the first storage node. + # And it's rank is the number of compute nodes. + storage_master_rank = num_compute_nodes + ( + storage_cluster_master_port, + storage_rpc_port, + storage_rpc_wait_port, + ) = get_free_ports_from_node(num_ports=3, node_rank=storage_master_rank) num_processes_per_compute = int( os.environ.get(COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY, "1") @@ -251,6 +275,8 @@ def get_graph_store_info() -> GraphStoreInfo: cluster_master_port=cluster_master_port, storage_cluster_master_port=storage_cluster_master_port, compute_cluster_master_port=compute_cluster_master_port, + rpc_master_port=storage_rpc_port, + rpc_wait_port=storage_rpc_wait_port, ) diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py index 3c3bc6465..990569059 100644 --- a/python/gigl/env/distributed.py +++ b/python/gigl/env/distributed.py @@ -55,6 +55,13 @@ class GraphStoreInfo: # https://snapchat.github.io/GiGL/docs/api/snapchat/research/gbml/gigl_resource_config_pb2/index.html#snapchat.research.gbml.gigl_resource_config_pb2.VertexAiGraphStoreConfig num_processes_per_compute: int + # Port of the master node for the RPC communication. + # NOTE: This should be on the *storage* master node, not the compute master node. + rpc_master_port: int + # Port of the master node for the RPC wait communication. + # NOTE: This should be on the *storage* master node, not the compute master node. 
+ rpc_wait_port: int + @property def num_cluster_nodes(self) -> int: return self.num_storage_nodes + self.num_compute_nodes diff --git a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py index 487504c65..4d66c761b 100644 --- a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -1,14 +1,17 @@ import collections import os +import socket import unittest from typing import Optional from unittest import mock import torch import torch.multiprocessing as mp +from torch_geometric.data import Data from gigl.common import Uri from gigl.common.logger import Logger +from gigl.distributed.distributed_neighborloader import DistNeighborLoader from gigl.distributed.graph_store.compute import ( init_compute_process, shutdown_compute_proccess, @@ -108,10 +111,37 @@ def _run_client_process( ).get_node_ids() _assert_sampler_input(cluster_info, simple_sampler_input, expected_sampler_input) + # Check that the edge types are correct assert ( remote_dist_dataset.get_edge_types() == expected_edge_types ), f"Expected edge types {expected_edge_types}, got {remote_dist_dataset.get_edge_types()}" + torch.distributed.barrier() + + # Test the DistNeighborLoader + loader = DistNeighborLoader( + dataset=remote_dist_dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + input_nodes=sampler_input, + num_workers=2, + worker_concurrency=2, + ) + count = 0 + for datum in loader: + assert isinstance(datum, Data) + count += 1 + torch.distributed.barrier() + logger.info(f"Rank {torch.distributed.get_rank()} loaded {count} batches") + # Verify that we sampled all nodes. + count_tensor = torch.tensor(count, dtype=torch.int64) + all_node_count = 0 + for rank_expected_sampler_input in expected_sampler_input.values(): + all_node_count += sum(len(nodes) for nodes in rank_expected_sampler_input) + torch.distributed.all_reduce(count_tensor, op=torch.distributed.ReduceOp.SUM) + assert ( + count_tensor.item() == all_node_count + ), f"Expected {all_node_count} total nodes, got {count_tensor.item()}" shutdown_compute_proccess() @@ -186,6 +216,7 @@ def _get_expected_input_nodes_by_rank( ], } + Args: num_nodes (int): The number of nodes in the graph. cluster_info (GraphStoreInfo): The cluster information. 
@@ -220,17 +251,22 @@ def test_graph_store_locally(self): storage_cluster_master_port, compute_cluster_master_port, master_port, - ) = get_free_ports(num_ports=4) + rpc_master_port, + rpc_wait_port, + ) = get_free_ports(num_ports=6) + host_ip = socket.gethostbyname(socket.gethostname()) cluster_info = GraphStoreInfo( num_storage_nodes=2, num_compute_nodes=2, num_processes_per_compute=2, - cluster_master_ip="localhost", - storage_cluster_master_ip="localhost", - compute_cluster_master_ip="localhost", + cluster_master_ip=host_ip, + storage_cluster_master_ip=host_ip, + compute_cluster_master_ip=host_ip, cluster_master_port=cluster_master_port, storage_cluster_master_port=storage_cluster_master_port, compute_cluster_master_port=compute_cluster_master_port, + rpc_master_port=rpc_master_port, + rpc_wait_port=rpc_wait_port, ) num_cora_nodes = 2708 @@ -244,7 +280,7 @@ def test_graph_store_locally(self): with mock.patch.dict( os.environ, { - "MASTER_ADDR": "localhost", + "MASTER_ADDR": host_ip, "MASTER_PORT": str(master_port), "RANK": str(i), "WORLD_SIZE": str(cluster_info.compute_cluster_world_size), @@ -271,7 +307,7 @@ def test_graph_store_locally(self): with mock.patch.dict( os.environ, { - "MASTER_ADDR": "localhost", + "MASTER_ADDR": host_ip, "MASTER_PORT": str(master_port), "RANK": str(i + cluster_info.num_compute_nodes), "WORLD_SIZE": str(cluster_info.compute_cluster_world_size), diff --git a/python/tests/unit/env/distributed_test.py b/python/tests/unit/env/distributed_test.py index 793ffdc42..d1ba90391 100644 --- a/python/tests/unit/env/distributed_test.py +++ b/python/tests/unit/env/distributed_test.py @@ -27,6 +27,8 @@ def setUp(self) -> None: cluster_master_port=1234, storage_cluster_master_port=1235, compute_cluster_master_port=1236, + rpc_master_port=1237, + rpc_wait_port=1238, ) def test_num_cluster_nodes(self):
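
A minimal usage sketch of the new Graph Store (remote sampling) mode, mirroring the integration test in this diff. The `RemoteDistDataset` construction is elided (it is assumed to have been built on an already-initialized compute process, as in the test), and the seed node ids are purely illustrative; the per-server seed-list convention and the loader keyword arguments are taken from the changes above.

```python
import torch

from gigl.distributed.distributed_neighborloader import DistNeighborLoader
from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset

# Assumption: `remote_dist_dataset` was created the same way as in the
# integration test, after init_compute_process(...) connected this compute
# process to the storage cluster.
remote_dist_dataset: RemoteDistDataset = ...

# In Graph Store mode, seeds are provided per storage server: element i of the
# list holds the node ids that storage server i should start sampling from.
input_nodes = [
    torch.tensor([10, 20]),  # seeds served by storage node 0 (illustrative ids)
    torch.tensor([30, 40]),  # seeds served by storage node 1 (illustrative ids)
]

loader = DistNeighborLoader(
    dataset=remote_dist_dataset,
    num_neighbors=[2, 2],
    input_nodes=input_nodes,  # or (node_type, input_nodes) for heterogeneous graphs
    num_workers=2,
    worker_concurrency=2,
    pin_memory_device=torch.device("cpu"),
)
for batch in loader:
    ...  # each mini-batch is a torch_geometric Data (or HeteroData) object

# Passing colocated-style seeds here (a plain Tensor, or (NodeType, Tensor))
# raises a ValueError, as enforced in _setup_for_graph_store.
```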
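
For contrast, a colocated-mode sketch using the same (unchanged) public API; the dataset construction is elided and the `num_neighbors`/`batch_size` values are assumptions for illustration.

```python
import torch

from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.distributed_neighborloader import DistNeighborLoader

# Assumption: `dataset` was built/loaded elsewhere (e.g. via dataset_factory).
dataset: DistDataset = ...

loader = DistNeighborLoader(
    dataset=dataset,
    num_neighbors=[10, 10],
    input_nodes=torch.arange(1_000),  # plain Tensor; or ("user", tensor) for heterogeneous graphs
    batch_size=32,
)

# Passing Graph Store style seeds here (list[Tensor] or (NodeType, list[Tensor]))
# raises a ValueError, as enforced in _setup_for_colocated.
```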
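
And a sketch of how the two new `GraphStoreInfo` RPC fields are populated, mirroring the updated unit test; all addresses and port numbers below are placeholders. In a real cluster, `get_graph_store_info()` now allocates `rpc_master_port` and `rpc_wait_port` as free ports on the *storage* master node rather than the compute master.

```python
from gigl.env.distributed import GraphStoreInfo

cluster_info = GraphStoreInfo(
    num_storage_nodes=2,
    num_compute_nodes=2,
    num_processes_per_compute=2,
    cluster_master_ip="10.0.0.1",            # illustrative addresses
    storage_cluster_master_ip="10.0.0.2",
    compute_cluster_master_ip="10.0.0.1",
    cluster_master_port=1234,
    storage_cluster_master_port=1235,
    compute_cluster_master_port=1236,
    rpc_master_port=1237,  # GLT client/server RPC port, on the storage master node
    rpc_wait_port=1238,    # RPC wait port, also on the storage master node
)
```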