From 4233f92c13c273d6f25fca7b8e9b03d00b60d4ac Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Sat, 10 Jan 2026 00:07:40 +0000 Subject: [PATCH 01/16] Setup DistNeighborloader for graph store sampling --- .../distributed/distributed_neighborloader.py | 299 +++++++++++++++--- .../gigl/distributed/graph_store/compute.py | 32 +- .../distributed/graph_store/storage_main.py | 47 ++- python/gigl/distributed/utils/networking.py | 33 +- python/gigl/env/distributed.py | 7 + .../graph_store_integration_test.py | 46 ++- python/tests/unit/env/distributed_test.py | 2 + 7 files changed, 382 insertions(+), 84 deletions(-) diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index 3be9662f0..fd094d9ad 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -1,9 +1,14 @@ from collections import Counter, abc -from typing import Optional, Tuple, Union +from dataclasses import dataclass +from typing import Literal, Optional, Tuple, Union import torch from graphlearn_torch.channel import SampleMessage -from graphlearn_torch.distributed import DistLoader, MpDistSamplingWorkerOptions +from graphlearn_torch.distributed import ( + DistLoader, + MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, +) from graphlearn_torch.sampler import NodeSamplerInput, SamplingConfig, SamplingType from torch_geometric.data import Data, HeteroData from torch_geometric.typing import EdgeType @@ -13,6 +18,7 @@ from gigl.distributed.constants import DEFAULT_MASTER_INFERENCE_PORT from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( labeled_to_homogeneous, patch_fanout_for_sampling, @@ -26,6 +32,7 @@ from gigl.types.graph import ( DEFAULT_HOMOGENEOUS_EDGE_TYPE, DEFAULT_HOMOGENEOUS_NODE_TYPE, + FeatureInfo, ) logger = Logger() @@ -34,13 +41,27 @@ DEFAULT_NUM_CPU_THREADS = 2 +# Shared metadata between the local and remote datasets. +@dataclass(frozen=True) +class _DatasetMetadata: + is_labeled_heterogeneous: bool + node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] + edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]] + edge_dir: Union[str, Literal["in", "out"]] + + class DistNeighborLoader(DistLoader): def __init__( self, - dataset: DistDataset, + dataset: Union[DistDataset, RemoteDistDataset], num_neighbors: Union[list[int], dict[EdgeType, list[int]]], input_nodes: Optional[ - Union[torch.Tensor, Tuple[NodeType, torch.Tensor]] + Union[ + torch.Tensor, + Tuple[NodeType, torch.Tensor], + list[torch.Tensor], + Tuple[NodeType, list[torch.Tensor]], + ] ] = None, num_workers: int = 1, batch_size: int = 1, @@ -62,7 +83,7 @@ def __init__( https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/distributed/dist_neighbor_loader.html#DistNeighborLoader Args: - dataset (DistDataset): The dataset to sample from. + dataset (DistDataset | RemoteDistDataset): The dataset to sample from. Must be a "RemoteDistDataset" if using Graph Store mode. num_neighbors (list[int] or dict[Tuple[str, str, str], list[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. 
@@ -71,12 +92,15 @@ def __init__( context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process. local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node. local_process_world_size (deprecated - will be removed soon)(int): Required if context provided. The total number of processes within a node. - input_nodes (torch.Tensor or Tuple[str, torch.Tensor]): The - indices of seed nodes to start sampling from. + input_nodes (Tenor | Tuple[NodeType, Tenor] | list[Tenor] | Tuple[NodeType, list[Tenor]]): + The nodes to start sampling from. It is of type `torch.LongTensor` for homogeneous graphs. If set to `None` for homogeneous settings, all nodes will be considered. In heterogeneous graphs, this flag must be passed in as a tuple that holds the node type and node indices. (default: `None`) + For Graph Store mode, this must be a tuple of (NodeType, list[Tenor]) or list[Tenor]. + Where each Tensor in the list is the node ids to sample from, for each server. + e.g. [[10, 20], [30, 40]] means sample from nodes 10 and 20 on server 0, and nodes 30 and 40 on server 1. num_workers (int): How many workers to use (subprocesses to spwan) for distributed neighbor sampling of the current process. (default: ``1``). batch_size (int, optional): how many samples per batch to load @@ -188,10 +212,219 @@ def __init__( local_process_rank=local_rank ) ) + + # Determines if the node ids passed in are heterogeneous or homogeneous. + self._is_labeled_heterogeneous = False + if isinstance(dataset, DistDataset): + input_data, worker_options, dataset_metadata = self._setup_for_colocated( + input_nodes, + dataset, + local_rank, + local_world_size, + device, + master_ip_address, + node_rank, + node_world_size, + process_start_gap_seconds, + num_workers, + worker_concurrency, + channel_size, + num_cpu_threads, + ) + else: # RemoteDistDataset + input_data, worker_options, dataset_metadata = self._setup_for_graph_store( + input_nodes, + dataset, + num_workers, + ) + + self._is_labeled_heterogeneous = dataset_metadata.is_labeled_heterogeneous + self._node_feature_info = dataset_metadata.node_feature_info + self._edge_feature_info = dataset_metadata.edge_feature_info + + num_neighbors = patch_fanout_for_sampling( + list(dataset_metadata.edge_feature_info.keys()) + if isinstance(dataset_metadata.edge_feature_info, dict) + else None, + num_neighbors, + ) + + sampling_config = SamplingConfig( + sampling_type=SamplingType.NODE, + num_neighbors=num_neighbors, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + with_edge=True, + collect_features=True, + with_neg=False, + with_weight=False, + edge_dir=dataset_metadata.edge_dir, + seed=None, # it's actually optional - None means random. + ) + + if should_cleanup_distributed_context and torch.distributed.is_initialized(): + logger.info( + f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." + ) + torch.distributed.destroy_process_group() + + if isinstance(dataset, DistDataset): + super().__init__( + dataset if isinstance(dataset, DistDataset) else None, + input_data, + sampling_config, + device, + worker_options, + ) + else: + # For Graph Store mode, we need to start the communcation between compute and storage nodes sequentially, by compute node. + # E.g. intialize connections between compute node 0 and storage nodes 0, 1, 2, 3, then compute node 1 and storage nodes 0, 1, 2, 3, etc. 
+ # Note that each compute node may have multiple connections to each storage node, once per compute process. + # E.g. if there are 4 gpus per compute node, then there will be 4 connections to each storage node. + # We need to this because if we don't, then there is a race condition when initalizing the samplers on the storage nodes [1] + # Where since the lock is per *server* (e.g. per storage node), if we try to start one connection from compute node 0, and compute node 1 + # Then we deadlock and fail. + # [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L129-L167 + node_rank = dataset.cluster_info.compute_node_rank + for target_node_rank in range(dataset.cluster_info.num_compute_nodes): + if node_rank == target_node_rank: + super().__init__( + dataset if isinstance(dataset, DistDataset) else None, + input_data, + sampling_config, + device, + worker_options, + ) + print(f"node_rank {node_rank} initialized the dist loader") + torch.distributed.barrier() + torch.distributed.barrier() + + def _setup_for_graph_store( + self, + input_nodes: Optional[ + Union[ + torch.Tensor, + Tuple[NodeType, torch.Tensor], + list[torch.Tensor], + Tuple[NodeType, list[torch.Tensor]], + ] + ], + dataset: RemoteDistDataset, + num_workers: int, + ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, _DatasetMetadata]: + if input_nodes is None: + raise ValueError( + f"When using Graph Store mode, input nodes must be provided, received {input_nodes}" + ) + elif isinstance(input_nodes, torch.Tensor): + raise ValueError( + f"When using Graph Store mode, input nodes must be of type (list[Tensor] | (NodeType, list[torch.Tensor]), received {type(input_nodes)}" + ) + elif isinstance(input_nodes, tuple) and isinstance( + input_nodes[1], torch.Tensor + ): + raise ValueError( + f"When using Graph Store mode, input nodes must be of type (list[torch.Tensor] | (NodeType, list[torch.Tensor])), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})" + ) + + is_labeled_heterogeneous = False + node_feature_info = dataset.get_node_feature_info() + edge_feature_info = dataset.get_edge_feature_info() + node_rank = dataset.cluster_info.compute_node_rank + + # Get sampling ports for compute-storage connections. + sampling_ports = dataset.get_free_ports_on_storage_cluster( + num_ports=dataset.cluster_info.num_processes_per_compute + ) + sampling_port = sampling_ports[node_rank] + + worker_options = RemoteDistSamplingWorkerOptions( + server_rank=list(range(dataset.cluster_info.num_storage_nodes)), + num_workers=num_workers, + worker_devices=[torch.device("cpu") for i in range(num_workers)], + master_addr=dataset.cluster_info.storage_cluster_master_ip, + master_port=sampling_port, + worker_key=f"compute_rank_{node_rank}", + ) logger.info( - f"Dataset Building started on {node_rank} of {node_world_size} nodes, using following node as main: {master_ip_address}" + f"Rank {torch.distributed.get_rank()}! init for sampling rpc: {f'tcp://{dataset.cluster_info.storage_cluster_master_ip}:{sampling_port}'}" + ) + + # Setup input data for the dataloader. 
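        # A minimal illustration of the expected shape (node ids and server count are
        # assumed purely for the example): with two storage servers and a homogeneous
        # graph, a caller might pass
        #   input_nodes = [torch.tensor([10, 20]), torch.tensor([30, 40])]
        # and the logic below would build one sampler input per server, roughly
        #   [NodeSamplerInput(node=torch.tensor([10, 20]), input_type=input_type),
        #    NodeSamplerInput(node=torch.tensor([30, 40]), input_type=input_type)]
        # where input_type is resolved from the edge feature info (None for a plain
        # homogeneous dataset, the node type from the (NodeType, list[Tensor]) tuple
        # for a heterogeneous one).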
+ + # Determine nodes list and fallback input_type based on input_nodes structure + if isinstance(input_nodes, list): + nodes = input_nodes + fallback_input_type = None + require_edge_feature_info = False + elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], list): + nodes = input_nodes[1] + fallback_input_type = input_nodes[0] + require_edge_feature_info = True + else: + raise ValueError( + f"When using Graph Store mode, input nodes must be of type (list[torch.Tensor] | (NodeType, list[torch.Tensor])), received {type(input_nodes)}" + ) + + # Determine input_type based on edge_feature_info + if isinstance(edge_feature_info, dict): + if len(edge_feature_info) == 0: + raise ValueError( + "When using Graph Store mode, edge feature info must be provided for heterogeneous graphs." + ) + elif ( + len(edge_feature_info) == 1 + and DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_feature_info + ): + input_type: NodeType | None = DEFAULT_HOMOGENEOUS_NODE_TYPE + else: + input_type = fallback_input_type + elif require_edge_feature_info: + raise ValueError( + "When using Graph Store mode, edge feature info must be provided for heterogeneous graphs." + ) + else: + input_type = None + + input_data = [ + NodeSamplerInput(node=node, input_type=input_type) for node in nodes + ] + + return ( + input_data, + worker_options, + _DatasetMetadata( + is_labeled_heterogeneous=is_labeled_heterogeneous, + node_feature_info=node_feature_info, + edge_feature_info=edge_feature_info, + edge_dir=dataset.get_edge_dir(), + ), ) + def _setup_for_colocated( + self, + input_nodes: Optional[ + Union[ + torch.Tensor, + Tuple[NodeType, torch.Tensor], + list[torch.Tensor], + Tuple[NodeType, list[torch.Tensor]], + ] + ], + dataset: DistDataset, + local_rank: int, + local_world_size: int, + device: torch.device, + master_ip_address: str, + node_rank: int, + node_world_size: int, + process_start_gap_seconds: float, + num_workers: int, + worker_concurrency: int, + channel_size: str, + num_cpu_threads: Optional[int], + ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, _DatasetMetadata]: if input_nodes is None: if dataset.node_ids is None: raise ValueError( @@ -202,9 +435,15 @@ def __init__( f"input_nodes must be provided for heterogeneous datasets, received node_ids of type: {dataset.node_ids.keys()}" ) input_nodes = dataset.node_ids - - # Determines if the node ids passed in are heterogeneous or homogeneous. - self._is_labeled_heterogeneous = False + if isinstance(input_nodes, list): + raise ValueError( + f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)}" + ) + elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], list): + raise ValueError( + f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})" + ) + is_labeled_heterogeneous = False if isinstance(input_nodes, torch.Tensor): node_ids = input_nodes @@ -216,7 +455,7 @@ def __init__( and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids ): node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE - self._is_labeled_heterogeneous = True + is_labeled_heterogeneous = True else: raise ValueError( f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. 
Received node types: {dataset.node_ids.keys()}" @@ -229,19 +468,12 @@ def __init__( dataset.node_ids, abc.Mapping ), "Dataset must be heterogeneous if provided input nodes are a tuple." - num_neighbors = patch_fanout_for_sampling( - dataset.get_edge_types(), num_neighbors - ) - curr_process_nodes = shard_nodes_by_process( input_nodes=node_ids, local_process_rank=local_rank, local_process_world_size=local_world_size, ) - self._node_feature_info = dataset.node_feature_info - self._edge_feature_info = dataset.edge_feature_info - input_data = NodeSamplerInput(node=curr_process_nodes, input_type=node_type) # Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize @@ -305,28 +537,17 @@ def __init__( pin_memory=device.type == "cuda", ) - sampling_config = SamplingConfig( - sampling_type=SamplingType.NODE, - num_neighbors=num_neighbors, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - with_edge=True, - collect_features=True, - with_neg=False, - with_weight=False, - edge_dir=dataset.edge_dir, - seed=None, # it's actually optional - None means random. + return ( + input_data, + worker_options, + _DatasetMetadata( + is_labeled_heterogeneous=is_labeled_heterogeneous, + node_feature_info=dataset.node_feature_info, + edge_feature_info=dataset.edge_feature_info, + edge_dir=dataset.edge_dir, + ), ) - if should_cleanup_distributed_context and torch.distributed.is_initialized(): - logger.info( - f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." - ) - torch.distributed.destroy_process_group() - - super().__init__(dataset, input_data, sampling_config, device, worker_options) - def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]: data = super()._collate_fn(msg) data = set_missing_features( diff --git a/python/gigl/distributed/graph_store/compute.py b/python/gigl/distributed/graph_store/compute.py index 36a3b66dd..6039eaef9 100644 --- a/python/gigl/distributed/graph_store/compute.py +++ b/python/gigl/distributed/graph_store/compute.py @@ -1,8 +1,8 @@ import os from typing import Optional -import graphlearn_torch as glt import torch +from graphlearn_torch.distributed.dist_client import init_client, shutdown_client from gigl.common.logger import Logger from gigl.env.distributed import GraphStoreInfo @@ -36,6 +36,21 @@ def init_compute_process( cluster_info.compute_node_rank * cluster_info.num_processes_per_compute + local_rank ) + cluster_master_ip = cluster_info.storage_cluster_master_ip + logger.info( + f"Initializing RPC client for compute node {compute_cluster_rank} / {cluster_info.compute_cluster_world_size} on {cluster_master_ip}:{cluster_info.rpc_master_port}." + f" OS rank: {os.environ['RANK']}, local compute rank: {local_rank}" + f" num_servers: {cluster_info.num_storage_nodes}, num_clients: {cluster_info.compute_cluster_world_size}" + ) + init_client( + num_servers=cluster_info.num_storage_nodes, + num_clients=cluster_info.compute_cluster_world_size, + client_rank=compute_cluster_rank, + master_addr=cluster_master_ip, + master_port=cluster_info.rpc_master_port, + client_group_name="gigl_client_rpc", + ) + logger.info( f"Initializing compute process group {compute_cluster_rank} / {cluster_info.compute_cluster_world_size}. on {cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port} with backend {compute_world_backend}." 
f" OS rank: {os.environ['RANK']}, local client rank: {local_rank}" @@ -46,19 +61,6 @@ def init_compute_process( rank=compute_cluster_rank, init_method=f"tcp://{cluster_info.compute_cluster_master_ip}:{cluster_info.compute_cluster_master_port}", ) - logger.info( - f"Initializing RPC client for compute node {compute_cluster_rank} / {cluster_info.compute_cluster_world_size} on {cluster_info.cluster_master_ip}:{cluster_info.cluster_master_port}." - f" OS rank: {os.environ['RANK']}, local compute rank: {local_rank}" - f" num_servers: {cluster_info.num_storage_nodes}, num_clients: {cluster_info.compute_cluster_world_size}" - ) - glt.distributed.init_client( - num_servers=cluster_info.num_storage_nodes, - num_clients=cluster_info.compute_cluster_world_size, - client_rank=compute_cluster_rank, - master_addr=cluster_info.cluster_master_ip, - master_port=cluster_info.cluster_master_port, - client_group_name="gigl_client_rpc", - ) def shutdown_compute_proccess() -> None: @@ -70,5 +72,5 @@ def shutdown_compute_proccess() -> None: Args: None """ - glt.distributed.shutdown_client() + shutdown_client() torch.distributed.destroy_process_group() diff --git a/python/gigl/distributed/graph_store/storage_main.py b/python/gigl/distributed/graph_store/storage_main.py index 0cfdef957..222613af8 100644 --- a/python/gigl/distributed/graph_store/storage_main.py +++ b/python/gigl/distributed/graph_store/storage_main.py @@ -7,8 +7,11 @@ import os from typing import Optional -import graphlearn_torch as glt import torch +from graphlearn_torch.distributed.dist_server import ( + init_server, + wait_and_shutdown_server, +) from gigl.common import Uri, UriFactory from gigl.common.logger import Logger @@ -30,28 +33,34 @@ def _run_storage_process( storage_world_backend: Optional[str], ) -> None: register_dataset(dataset) + cluster_master_ip = cluster_info.storage_cluster_master_ip logger.info( - f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {cluster_info.cluster_master_ip}:{torch_process_port}" + f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}" ) - torch.distributed.init_process_group( - backend=storage_world_backend, - world_size=cluster_info.num_storage_nodes, - rank=storage_rank, - init_method=f"tcp://{cluster_info.cluster_master_ip}:{torch_process_port}", - ) - glt.distributed.init_server( + init_server( num_servers=cluster_info.num_storage_nodes, server_rank=storage_rank, dataset=dataset, - master_addr=cluster_info.cluster_master_ip, - master_port=cluster_info.cluster_master_port, + master_addr=cluster_master_ip, + master_port=cluster_info.rpc_master_port, num_clients=cluster_info.compute_cluster_world_size, ) + init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{torch_process_port}" + logger.info( + f"Initializing storage node process group {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {init_method}" + ) + torch.distributed.init_process_group( + backend=storage_world_backend, + world_size=cluster_info.num_storage_nodes, + rank=storage_rank, + init_method=init_method, + ) + logger.info( f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit" ) - glt.distributed.wait_and_shutdown_server() + wait_and_shutdown_server() logger.info(f"Storage node {storage_rank} exited") @@ -59,7 +68,7 @@ def storage_node_process( storage_rank: int, cluster_info: 
GraphStoreInfo, task_config_uri: Uri, - is_inference: bool, + is_inference: bool = True, tf_record_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$", storage_world_backend: Optional[str] = None, ) -> None: @@ -71,7 +80,7 @@ def storage_node_process( storage_rank (int): The rank of the storage node. cluster_info (GraphStoreInfo): The cluster information. task_config_uri (Uri): The task config URI. - is_inference (bool): Whether the process is an inference process. + is_inference (bool): Whether the process is an inference process. Defaults to True. tf_record_uri_pattern (str): The TF Record URI pattern. storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group. """ @@ -95,6 +104,7 @@ def storage_node_process( _tfrecord_uri_pattern=tf_record_uri_pattern, ) torch_process_port = get_free_ports_from_master_node(num_ports=1)[0] + torch.distributed.destroy_process_group() server_processes = [] mp_context = torch.multiprocessing.get_context("spawn") # TODO(kmonte): Enable more than one server process per machine @@ -120,18 +130,21 @@ def storage_node_process( parser = argparse.ArgumentParser() parser.add_argument("--task_config_uri", type=str, required=True) parser.add_argument("--resource_config_uri", type=str, required=True) - parser.add_argument("--is_inference", action="store_true") + parser.add_argument("--job_name", type=str, required=True) args = parser.parse_args() logger.info(f"Running storage node with arguments: {args}") is_inference = args.is_inference - torch.distributed.init_process_group() + torch.distributed.init_process_group(backend="gloo") cluster_info = get_graph_store_info() + logger.info(f"Cluster info: {cluster_info}") + logger.info( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) # Tear down the """"global""" process group so we can have a server-specific process group. torch.distributed.destroy_process_group() storage_node_process( storage_rank=cluster_info.storage_node_rank, cluster_info=cluster_info, task_config_uri=UriFactory.create_uri(args.task_config_uri), - is_inference=is_inference, ) diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index 7d2ba46b9..e9d33921c 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -115,6 +115,13 @@ def get_internal_ip_from_master_node( ) -> str: """ Get the internal IP address of the master node in a distributed setup. + + Args: + _global_rank_override (Optional[int]): Override for the global rank, + useful for testing or if global rank is not accurately available. + + Returns: + str: The internal IP address of the master node. """ return get_internal_ip_from_node( node_rank=0, _global_rank_override=_global_rank_override @@ -131,6 +138,12 @@ def get_internal_ip_from_node( i.e. when using :py:obj:`gigl.distributed.dataset_factory` + Args: + node_rank (int): Rank of the node, to fetch the internal IP address of. + device (Optional[torch.device]): Device to use for communication. Defaults to None, which will use the default device. + _global_rank_override (Optional[int]): Override for the global rank, + useful for testing or if global rank is not accurately available. + Returns: str: The internal IP address of the node. 
""" @@ -155,7 +168,8 @@ def get_internal_ip_from_node( # Other nodes will receive the master's IP via broadcast ip_list = [None] - device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") + torch.distributed.broadcast_object_list(ip_list, src=node_rank, device=device) node_ip = ip_list[0] logger.info(f"Rank {rank} received master node's internal IP: {node_ip}") @@ -230,12 +244,15 @@ def get_graph_store_info() -> GraphStoreInfo: compute_cluster_master_ip = cluster_master_ip storage_cluster_master_ip = get_internal_ip_from_node(node_rank=num_compute_nodes) - cluster_master_port, compute_cluster_master_port = get_free_ports_from_node( - num_ports=2, node_rank=0 - ) - storage_cluster_master_port = get_free_ports_from_node( - num_ports=1, node_rank=num_compute_nodes - )[0] + ( + cluster_master_port, + compute_cluster_master_port, + ) = get_free_ports_from_node(num_ports=2, node_rank=0) + ( + storage_cluster_master_port, + storage_rpc_port, + storage_rpc_wait_port, + ) = get_free_ports_from_node(num_ports=3, node_rank=num_compute_nodes) num_processes_per_compute = int( os.environ.get(COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY, "1") @@ -251,6 +268,8 @@ def get_graph_store_info() -> GraphStoreInfo: cluster_master_port=cluster_master_port, storage_cluster_master_port=storage_cluster_master_port, compute_cluster_master_port=compute_cluster_master_port, + rpc_master_port=storage_rpc_port, + rpc_wait_port=storage_rpc_wait_port, ) diff --git a/python/gigl/env/distributed.py b/python/gigl/env/distributed.py index 3c3bc6465..990569059 100644 --- a/python/gigl/env/distributed.py +++ b/python/gigl/env/distributed.py @@ -55,6 +55,13 @@ class GraphStoreInfo: # https://snapchat.github.io/GiGL/docs/api/snapchat/research/gbml/gigl_resource_config_pb2/index.html#snapchat.research.gbml.gigl_resource_config_pb2.VertexAiGraphStoreConfig num_processes_per_compute: int + # Port of the master node for the RPC communication. + # NOTE: This should be on the *storage* master node, not the compute master node. + rpc_master_port: int + # Port of the master node for the RPC wait communication. + # NOTE: This should be on the *storage* master node, not the compute master node. 
+ rpc_wait_port: int + @property def num_cluster_nodes(self) -> int: return self.num_storage_nodes + self.num_compute_nodes diff --git a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py index c267b7fed..5a1cc8369 100644 --- a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -1,13 +1,16 @@ import collections import os +import socket import unittest from unittest import mock import torch import torch.multiprocessing as mp +from torch_geometric.data import Data from gigl.common import Uri from gigl.common.logger import Logger +from gigl.distributed.distributed_neighborloader import DistNeighborLoader from gigl.distributed.graph_store.compute import ( init_compute_process, shutdown_compute_proccess, @@ -103,7 +106,32 @@ def _run_client_process( mp_sharing_dict=None, ).get_node_ids() _assert_sampler_input(cluster_info, simple_sampler_input, expected_sampler_input) + torch.distributed.barrier() + # Test the DistNeighborLoader + loader = DistNeighborLoader( + dataset=remote_dist_dataset, + num_neighbors=[2, 2], + pin_memory_device=torch.device("cpu"), + input_nodes=sampler_input, + num_workers=2, + worker_concurrency=2, + ) + count = 0 + for datum in loader: + assert isinstance(datum, Data) + count += 1 + torch.distributed.barrier() + logger.info(f"Rank {torch.distributed.get_rank()} loaded {count} batches") + # Verify that we sampled all nodes. + count_tensor = torch.tensor(count, dtype=torch.int64) + all_node_count = 0 + for rank_expected_sampler_input in expected_sampler_input.values(): + all_node_count += sum(len(nodes) for nodes in rank_expected_sampler_input) + torch.distributed.all_reduce(count_tensor, op=torch.distributed.ReduceOp.SUM) + assert ( + count_tensor.item() == all_node_count + ), f"Expected {all_node_count} total nodes, got {count_tensor.item()}" shutdown_compute_proccess() @@ -176,6 +204,7 @@ def _get_expected_input_nodes_by_rank( ], } + Args: num_nodes (int): The number of nodes in the graph. cluster_info (GraphStoreInfo): The cluster information. 
@@ -212,17 +241,22 @@ def test_graph_store_locally(self): storage_cluster_master_port, compute_cluster_master_port, master_port, - ) = get_free_ports(num_ports=4) + rpc_master_port, + rpc_wait_port, + ) = get_free_ports(num_ports=6) + host_ip = socket.gethostbyname(socket.gethostname()) cluster_info = GraphStoreInfo( num_storage_nodes=2, num_compute_nodes=2, num_processes_per_compute=2, - cluster_master_ip="localhost", - storage_cluster_master_ip="localhost", - compute_cluster_master_ip="localhost", + cluster_master_ip=host_ip, + storage_cluster_master_ip=host_ip, + compute_cluster_master_ip=host_ip, cluster_master_port=cluster_master_port, storage_cluster_master_port=storage_cluster_master_port, compute_cluster_master_port=compute_cluster_master_port, + rpc_master_port=rpc_master_port, + rpc_wait_port=rpc_wait_port, ) num_cora_nodes = 2708 @@ -236,7 +270,7 @@ def test_graph_store_locally(self): with mock.patch.dict( os.environ, { - "MASTER_ADDR": "localhost", + "MASTER_ADDR": host_ip, "MASTER_PORT": str(master_port), "RANK": str(i), "WORLD_SIZE": str(cluster_info.compute_cluster_world_size), @@ -262,7 +296,7 @@ def test_graph_store_locally(self): with mock.patch.dict( os.environ, { - "MASTER_ADDR": "localhost", + "MASTER_ADDR": host_ip, "MASTER_PORT": str(master_port), "RANK": str(i + cluster_info.num_compute_nodes), "WORLD_SIZE": str(cluster_info.compute_cluster_world_size), diff --git a/python/tests/unit/env/distributed_test.py b/python/tests/unit/env/distributed_test.py index 793ffdc42..d1ba90391 100644 --- a/python/tests/unit/env/distributed_test.py +++ b/python/tests/unit/env/distributed_test.py @@ -27,6 +27,8 @@ def setUp(self) -> None: cluster_master_port=1234, storage_cluster_master_port=1235, compute_cluster_master_port=1236, + rpc_master_port=1237, + rpc_wait_port=1238, ) def test_num_cluster_nodes(self): From fa0bd9adb269abb4ca483533882f33207d0d55ff Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Sat, 10 Jan 2026 00:19:32 +0000 Subject: [PATCH 02/16] cleanup --- .../distributed/distributed_neighborloader.py | 22 +++++-------------- .../gigl/distributed/graph_store/compute.py | 8 ++++--- .../distributed/graph_store/storage_main.py | 11 +++++----- .../gigl/distributed/utils/neighborloader.py | 19 +++++++++++++++- 4 files changed, 34 insertions(+), 26 deletions(-) diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index fd094d9ad..bcf4c2d5e 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -1,6 +1,5 @@ from collections import Counter, abc -from dataclasses import dataclass -from typing import Literal, Optional, Tuple, Union +from typing import Optional, Tuple, Union import torch from graphlearn_torch.channel import SampleMessage @@ -20,6 +19,7 @@ from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( + DatasetMetadata, labeled_to_homogeneous, patch_fanout_for_sampling, set_missing_features, @@ -32,7 +32,6 @@ from gigl.types.graph import ( DEFAULT_HOMOGENEOUS_EDGE_TYPE, DEFAULT_HOMOGENEOUS_NODE_TYPE, - FeatureInfo, ) logger = Logger() @@ -41,15 +40,6 @@ DEFAULT_NUM_CPU_THREADS = 2 -# Shared metadata between the local and remote datasets. 
-@dataclass(frozen=True) -class _DatasetMetadata: - is_labeled_heterogeneous: bool - node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] - edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]] - edge_dir: Union[str, Literal["in", "out"]] - - class DistNeighborLoader(DistLoader): def __init__( self, @@ -312,7 +302,7 @@ def _setup_for_graph_store( ], dataset: RemoteDistDataset, num_workers: int, - ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, _DatasetMetadata]: + ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, DatasetMetadata]: if input_nodes is None: raise ValueError( f"When using Graph Store mode, input nodes must be provided, received {input_nodes}" @@ -394,7 +384,7 @@ def _setup_for_graph_store( return ( input_data, worker_options, - _DatasetMetadata( + DatasetMetadata( is_labeled_heterogeneous=is_labeled_heterogeneous, node_feature_info=node_feature_info, edge_feature_info=edge_feature_info, @@ -424,7 +414,7 @@ def _setup_for_colocated( worker_concurrency: int, channel_size: str, num_cpu_threads: Optional[int], - ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, _DatasetMetadata]: + ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, DatasetMetadata]: if input_nodes is None: if dataset.node_ids is None: raise ValueError( @@ -540,7 +530,7 @@ def _setup_for_colocated( return ( input_data, worker_options, - _DatasetMetadata( + DatasetMetadata( is_labeled_heterogeneous=is_labeled_heterogeneous, node_feature_info=dataset.node_feature_info, edge_feature_info=dataset.edge_feature_info, diff --git a/python/gigl/distributed/graph_store/compute.py b/python/gigl/distributed/graph_store/compute.py index 6039eaef9..cd556d941 100644 --- a/python/gigl/distributed/graph_store/compute.py +++ b/python/gigl/distributed/graph_store/compute.py @@ -1,8 +1,8 @@ import os from typing import Optional +import graphlearn_torch as glt import torch -from graphlearn_torch.distributed.dist_client import init_client, shutdown_client from gigl.common.logger import Logger from gigl.env.distributed import GraphStoreInfo @@ -42,7 +42,9 @@ def init_compute_process( f" OS rank: {os.environ['RANK']}, local compute rank: {local_rank}" f" num_servers: {cluster_info.num_storage_nodes}, num_clients: {cluster_info.compute_cluster_world_size}" ) - init_client( + # Initialize the GLT client before starting the Torch Distributed process group. + # Otherwise, we saw intermittent hangs when initializing the client. 
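    # As a rough worked example (cluster sizes assumed, not taken from any real config):
    # with num_compute_nodes=2 and num_processes_per_compute=2, the client ranks are
    #   compute node 0, local_rank 0/1 -> compute_cluster_rank 0/1
    #   compute node 1, local_rank 0/1 -> compute_cluster_rank 2/3
    # following compute_cluster_rank = compute_node_rank * num_processes_per_compute + local_rank
    # computed above; that value is passed below as client_rank, and every client dials the
    # same RPC endpoint on the storage master node (storage_cluster_master_ip:rpc_master_port).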
+ glt.distributed.init_client( num_servers=cluster_info.num_storage_nodes, num_clients=cluster_info.compute_cluster_world_size, client_rank=compute_cluster_rank, @@ -72,5 +74,5 @@ def shutdown_compute_proccess() -> None: Args: None """ - shutdown_client() + glt.distributed.shutdown_client() torch.distributed.destroy_process_group() diff --git a/python/gigl/distributed/graph_store/storage_main.py b/python/gigl/distributed/graph_store/storage_main.py index 222613af8..3dc8040a5 100644 --- a/python/gigl/distributed/graph_store/storage_main.py +++ b/python/gigl/distributed/graph_store/storage_main.py @@ -7,11 +7,8 @@ import os from typing import Optional +import graphlearn_torch as glt import torch -from graphlearn_torch.distributed.dist_server import ( - init_server, - wait_and_shutdown_server, -) from gigl.common import Uri, UriFactory from gigl.common.logger import Logger @@ -37,7 +34,9 @@ def _run_storage_process( logger.info( f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}" ) - init_server( + # Initialize the GLT server before starting the Torch Distributed process group. + # Otherwise, we saw intermittent hangs when initializing the server. + glt.distributed.init_server( num_servers=cluster_info.num_storage_nodes, server_rank=storage_rank, dataset=dataset, @@ -60,7 +59,7 @@ def _run_storage_process( logger.info( f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit" ) - wait_and_shutdown_server() + glt.distributed.wait_and_shutdown_server() logger.info(f"Storage node {storage_rank} exited") diff --git a/python/gigl/distributed/utils/neighborloader.py b/python/gigl/distributed/utils/neighborloader.py index 29a09d71a..9aa587cfd 100644 --- a/python/gigl/distributed/utils/neighborloader.py +++ b/python/gigl/distributed/utils/neighborloader.py @@ -1,7 +1,8 @@ """Utils for Neighbor loaders.""" from collections import abc from copy import deepcopy -from typing import Optional, TypeVar, Union +from dataclasses import dataclass +from typing import Literal, Optional, TypeVar, Union import torch from torch_geometric.data import Data, HeteroData @@ -15,6 +16,22 @@ _GraphType = TypeVar("_GraphType", Data, HeteroData) +@dataclass(frozen=True) +class DatasetMetadata: + """ + Shared metadata between the local and remote datasets. + """ + + # If the dataset is labeled heterogeneous. E.g. one node type, one edge type, and "label" edges. + is_labeled_heterogeneous: bool + # Node feature info. + node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] + # Edge feature info. + edge_feature_info: Optional[Union[FeatureInfo, dict[EdgeType, FeatureInfo]]] + # Edge direction. 
+ edge_dir: Union[str, Literal["in", "out"]] + + def patch_fanout_for_sampling( edge_types: Optional[list[EdgeType]], num_neighbors: Union[list[int], dict[EdgeType, list[int]]], From 1853bc1bc9bd1b1a439dce4f8aec7a75e012d4eb Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 13 Jan 2026 18:03:14 +0000 Subject: [PATCH 03/16] setup e2e gs hom inference --- .../configs/e2e_glt_gs_resource_config.yaml | 43 ++ .../link_prediction/graph_store/__init__.py | 0 .../e2e_hom_cora_sup_gs_task_config.yaml | 42 ++ .../graph_store/homogeneous_inference.py | 446 ++++++++++++++++++ testing/e2e_tests/e2e_tests.yaml | 3 + 5 files changed, 534 insertions(+) create mode 100644 deployment/configs/e2e_glt_gs_resource_config.yaml create mode 100644 examples/link_prediction/graph_store/__init__.py create mode 100644 examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml create mode 100644 examples/link_prediction/graph_store/homogeneous_inference.py diff --git a/deployment/configs/e2e_glt_gs_resource_config.yaml b/deployment/configs/e2e_glt_gs_resource_config.yaml new file mode 100644 index 000000000..0bab9f54f --- /dev/null +++ b/deployment/configs/e2e_glt_gs_resource_config.yaml @@ -0,0 +1,43 @@ +shared_resource_config: + resource_labels: + cost_resource_group_tag: dev_experiments_COMPONENT + cost_resource_group: gigl_platform + common_compute_config: + project: "external-snap-ci-github-gigl" + region: "us-central1" + temp_assets_bucket: "gs://gigl-cicd-temp" + temp_regional_assets_bucket: "gs://gigl-cicd-temp" + perm_assets_bucket: "gs://gigl-cicd-perm" + temp_assets_bq_dataset_name: "gigl_temp_assets" + embedding_bq_dataset_name: "gigl_embeddings" + gcp_service_account_email: "untrusted-external-github-gigl@external-snap-ci-github-gigl.iam.gserviceaccount.com" + dataflow_runner: "DataflowRunner" +preprocessor_config: + edge_preprocessor_config: + num_workers: 1 + max_num_workers: 128 + machine_type: "n2d-highmem-32" + disk_size_gb: 300 + node_preprocessor_config: + num_workers: 1 + max_num_workers: 128 + machine_type: "n2d-highmem-64" + disk_size_gb: 300 +trainer_resource_config: + vertex_ai_trainer_config: + machine_type: n1-highmem-32 + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 +inferencer_resource_config: + vertex_ai_graph_store_inferencer_config: + graph_store_pool: + machine_type: n2-highmem-32 + gpu_type: ACCELERATOR_TYPE_UNSPECIFIED + gpu_limit: 0 + num_replicas: 2 + compute_pool: + machine_type: n1-standard-16 + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 diff --git a/examples/link_prediction/graph_store/__init__.py b/examples/link_prediction/graph_store/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml new file mode 100644 index 000000000..812d9e408 --- /dev/null +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml @@ -0,0 +1,42 @@ +# This config is used to run homogeneous CORA supervised training and inference using in memory GiGL SGS. This can be run with `make run_hom_cora_sup_test`. 
+graphMetadata: + edgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper + nodeTypes: + - paper +datasetConfig: + dataPreprocessorConfig: + dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets + dataPreprocessorArgs: + # This argument is specific for the `PassthroughPreprocessorConfigForMockedAssets` preprocessor to indicate which dataset we should be using + mocked_dataset_name: 'cora_homogeneous_node_anchor_edge_features_user_defined_labels' +# TODO(kmonte): Add GS trainer +trainerConfig: + trainerArgs: + # Example argument to trainer + log_every_n_batch: "50" # Frequency in which we log batch information + num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case + command: python -m examples.link_prediction.homogeneous_training +# TODO(kmonte): Move to user-defined server code +inferencerConfig: + inferencerArgs: + # Example argument to inferencer + log_every_n_batch: "50" # Frequency in which we log batch information + num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case + inferenceBatchSize: 512 + command: python -m examples.link_prediction.graph_store.homogeneous_inference +sharedConfig: + shouldSkipInference: false + # Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. This will soon be added for in-mem SGS GiGL pipelines. + shouldSkipModelEvaluation: true +taskMetadata: + nodeAnchorBasedLinkPredictionTaskMetadata: + supervisionEdgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper +featureFlags: + should_run_glt_backend: 'True' + data_preprocessor_num_shards: '2' diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py new file mode 100644 index 000000000..b6c8dbc97 --- /dev/null +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -0,0 +1,446 @@ +""" +This file contains an example for how to run homogeneous inference on pretrained torch.nn.Module in GiGL (or elsewhere) using new +GLT (GraphLearn-for-PyTorch) bindings that GiGL has. Note that example should be applied to use cases which already have +some pretrained `nn.Module` and are looking to utilize cost-savings with distributed inference. While `run_example_inference` is coupled with +GiGL orchestration, the `_inference_process` function is generic and can be used as references +for writing inference for pipelines not dependent on GiGL orchestration. + +To run this file with GiGL orchestration, set the fields similar to below: + +inferencerConfig: + inferencerArgs: + # Example argument to inferencer + log_every_n_batch: "50" + inferenceBatchSize: 512 + command: python -m examples.link_prediction.homogeneous_inference +featureFlags: + should_run_glt_backend: 'True' + +You can run this example in a full pipeline with `make run_hom_cora_sup_test` from GiGL root. 
+""" + +import argparse +import gc +import os +import sys +import time + +import torch +import torch.multiprocessing as mp +from examples.link_prediction.models import init_example_gigl_homogeneous_model + +import gigl.distributed +import gigl.distributed.utils +from gigl.common import GcsUri, UriFactory +from gigl.common.data.export import EmbeddingExporter, load_embeddings_to_bigquery +from gigl.common.logger import Logger +from gigl.common.utils.gcs import GcsUtils +from gigl.distributed.graph_store.compute import init_compute_process +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.distributed.utils import get_graph_store_info +from gigl.env.distributed import GraphStoreInfo +from gigl.nn import LinkPredictionGNN +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.graph_data import NodeType +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.bq import BqUtils +from gigl.src.common.utils.model import load_state_dict_from_uri +from gigl.src.inference.lib.assets import InferenceAssets +from gigl.utils.sampling import parse_fanout + +logger = Logger() + +# Default number of inference processes per machine incase one isnt provided in inference args +# i.e. `local_world_size` is not provided, and we can't infer automatically. +# If there are GPUs attached to the machine, we automatically infer to setting +# LOCAL_WORLD_SIZE == # of gpus on the machine. +DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE = 4 + + +@torch.no_grad() +def _inference_process( + # When spawning processes, each process will be assigned a rank ranging + # from [0, num_processes). + local_rank: int, + local_world_size: int, + embedding_gcs_path: GcsUri, + model_state_dict_uri: GcsUri, + inference_batch_size: int, + hid_dim: int, + out_dim: int, + cluster_info: GraphStoreInfo, + inferencer_args: dict[str, str], + inference_node_type: NodeType, + node_feature_dim: int, + edge_feature_dim: int, + gbml_config_pb_wrapper: GbmlConfigPbWrapper, + mp_sharing_dict: dict[str, torch.Tensor], +): + """ + This function is spawned by multiple processes per machine and is responsible for: + 1. Initializing the dataLoader + 2. Running the inference loop to get the embeddings for each anchor node + 3. Writing embeddings to GCS + + Args: + local_rank (int): Process number on the current machine + local_world_size (int): Number of inference processes spawned by each machine + distributed_context (DistributedContext): Distributed context containing information for master_ip_address, rank, and world size + embedding_gcs_path (GcsUri): GCS path to load embeddings from + model_state_dict_uri (GcsUri): GCS path to load model from + inference_batch_size (int): Batch size to use for inference + hid_dim (int): Hidden dimension of the model + out_dim (int): Output dimension of the model + dataset (DistDataset): Loaded Distributed Dataset for inference + inferencer_args (dict[str, str]): Additional arguments for inferencer + inference_node_type (NodeType): Node Type that embeddings should be generated for. This is used to + tag the embeddings written to GCS. 
+ node_feature_dim (int): Input node feature dimension for the model + edge_feature_dim (int): Input edge feature dimension for the model + """ + + device = gigl.distributed.utils.get_available_device( + local_process_rank=local_rank, + ) # The device is automatically inferred based off the local process rank and the available devices + if torch.cuda.is_available(): + # If using GPU, we set the device to the local process rank's GPU + logger.info( + f"Using GPU {device} with index {device.index} on local rank: {local_rank} for inference" + ) + torch.cuda.set_device(device) + # Parses the fanout as a string. For the homogeneous case, the fanouts should be specified as a string of a list of integers, such as "[10, 10]". + fanout = inferencer_args.get("num_neighbors", "[10, 10]") + num_neighbors = parse_fanout(fanout) + + # While the ideal value for `sampling_workers_per_inference_process` has been identified to be between `2` and `4`, this may need some tuning depending on the + # pipeline. We default this value to `4` here for simplicity. A `sampling_workers_per_process` which is too small may not have enough parallelization for + # sampling, which would slow down inference, while a value which is too large may slow down each sampling process due to competing resources, which would also + # then slow down inference. + sampling_workers_per_inference_process: int = int( + inferencer_args.get("sampling_workers_per_inference_process", "4") + ) + + # This value represents the the shared-memory buffer size (bytes) allocated for the channel during sampling, and + # is the place to store pre-fetched data, so if it is too small then prefetching is limited, causing sampling slowdown. This parameter is a string + # with `{numeric_value}{storage_size}`, where storage size could be `MB`, `GB`, etc. We default this value to 4GB, + # but in production may need some tuning. + sampling_worker_shared_channel_size: str = inferencer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50")) + + init_compute_process(local_rank, cluster_info) + dataset = RemoteDistDataset( + cluster_info, local_rank, mp_sharing_dict=mp_sharing_dict + ) + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() + logger.info( + f"Local rank {local_rank} in machine {cluster_info.compute_node_rank} has rank {rank}/{world_size} and using device {device} for inference" + ) + input_nodes = dataset.get_node_ids() + print( + f"input_nodes: {[(node.shape, f'{node[0],node[-1]}') for node in input_nodes]}" + ) + logger.info( + f"Rank {rank} got input nodes of shapes: {[node.shape for node in input_nodes]}" + ) + sys.stdout.flush() + + data_loader = gigl.distributed.DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + local_process_rank=local_rank, + local_process_world_size=local_world_size, + input_nodes=input_nodes, # Since homogeneous, `None` defaults to using all nodes for inference loop + num_workers=sampling_workers_per_inference_process, + batch_size=inference_batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_inference_process, + channel_size=sampling_worker_shared_channel_size, + # For large-scale settings, consider setting this field to 30-60 seconds to ensure dataloaders + # don't compete for memory during initialization, causing OOM + process_start_gap_seconds=0, + ) + # Initialize a LinkPredictionGNN model and load parameters from + # the saved model. 
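    # For reference when reading the loop further down (a description of the existing
    # contract, not an extra code path): each batch the loader yields is a PyG `Data`
    # object whose `.batch` field carries the anchor-node ids and whose first
    # `data.batch_size` output rows are the anchor embeddings, which is why the loop
    # keeps only `output[: data.batch_size]`.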
+ model_state_dict = load_state_dict_from_uri( + load_from_uri=model_state_dict_uri, device=device + ) + model: LinkPredictionGNN = init_example_gigl_homogeneous_model( + node_feature_dim=node_feature_dim, + edge_feature_dim=edge_feature_dim, + hid_dim=hid_dim, + out_dim=out_dim, + device=device, + state_dict=model_state_dict, + ) + + # Set the model to evaluation mode for inference. + model.eval() + + logger.info(f"Model initialized on device {device}") + + embedding_filename = ( + f"machine_{cluster_info.compute_node_rank}_local_process_{local_rank}" + ) + + # Get temporary GCS folder to write outputs of inference to. GiGL orchestration automatic cleans this, but + # if running manually, you will need to clean this directory so that retries don't end up with stale files. + gcs_utils = GcsUtils() + gcs_base_uri = GcsUri.join(embedding_gcs_path, embedding_filename) + num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path(gcs_base_uri) + if num_files_at_gcs_path > 0: + logger.warning( + f"{num_files_at_gcs_path} files already detected at base gcs path. Cleaning up files at path ... " + ) + gcs_utils.delete_files_in_bucket_dir(gcs_base_uri) + + # GiGL class for exporting embeddings to GCS. This is achieved by writing ids and embeddings to an in-memory buffer which gets + # flushed to GCS. Setting the min_shard_size_threshold_bytes field of this class sets the frequency of flushing to GCS, and defaults + # to only flushing when flush_embeddings() is called explicitly or after exiting via a context manager. + exporter = EmbeddingExporter(export_dir=gcs_base_uri) + + # We add a barrier here so that all machines and processes have initialized their dataloader at the start of the inference loop. Otherwise, on-the-fly subgraph + # sampling may fail. + sys.stdout.flush() + torch.distributed.barrier() + + t = time.time() + data_loading_start_time = time.time() + inference_start_time = time.time() + cumulative_data_loading_time = 0.0 + cumulative_inference_time = 0.0 + + # Begin inference loop + + # Iterating through the dataloader yields a `torch_geometric.data.Data` type + for batch_idx, data in enumerate(data_loader): + cumulative_data_loading_time += time.time() - data_loading_start_time + + inference_start_time = time.time() + + # These arguments to forward are specific to the GiGL LinkPredictionGNN model. + # If just using a nn.Module, you can just use output = model(data) + output = model(data=data, device=device) + + # The anchor node IDs are contained inside of the .batch field of the data + node_ids = data.batch.cpu() + + # Only the first `batch_size` rows of the node embeddings contain the embeddings of the anchor nodes + node_embeddings = output[: data.batch_size].cpu() + + # We add ids and embeddings to the in-memory buffer + exporter.add_embedding( + id_batch=node_ids, + embedding_batch=node_embeddings, + embedding_type=str(inference_node_type), + ) + + cumulative_inference_time += time.time() - inference_start_time + + if batch_idx > 0 and batch_idx % log_every_n_batch == 0: + logger.info( + f"rank {rank} processed {batch_idx} batches. " + f"{log_every_n_batch} batches took {time.time() - t:.2f} seconds. " + f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds " + f"and model inference took {cumulative_inference_time:.2f} seconds." 
+ ) + sys.stdout.flush() + t = time.time() + cumulative_data_loading_time = 0 + cumulative_inference_time = 0 + + data_loading_start_time = time.time() + + logger.info(f"--- Rank {rank} finished inference.") + + write_embedding_start_time = time.time() + # Flushes all remaining embeddings to GCS + exporter.flush_records() + + logger.info( + f"--- Rank {rank} finished writing embeddings to GCS, which took {time.time()-write_embedding_start_time:.2f} seconds" + ) + logger.info( + f"--- Rank {rank} wrote embeddings to GCS at {gcs_base_uri} over batches" + ) + sys.stdout.flush() + # We first call barrier to ensure that all machines and processes have finished inference. + # Only once all machines have finished inference is it safe to shutdown the data loader. + # Otherwise, processes which are still sampling *will* fail as the loaders they are trying to communicatate with will be shutdown. + # We then call `gc.collect()` to cleanup the memory used by the data_loader on the current machine. + + torch.distributed.barrier() + + data_loader.shutdown() + gc.collect() + + logger.info( + f"--- All machines local rank {local_rank} finished inference. Deleted data loader" + ) + output_bq_table_path = InferenceAssets.get_enumerated_embedding_table_path( + gbml_config_pb_wrapper, inference_node_type + ) + bq_project_id, bq_dataset_id, bq_table_name = BqUtils.parse_bq_table_path( + bq_table_path=output_bq_table_path + ) + # After inference is finished, we use the process on the Machine 0 to load embeddings from GCS to BQ. + if rank == 0: + logger.info("--- Machine 0 triggers loading embeddings from GCS to BigQuery") + + # The `load_embeddings_to_bigquery` API returns a BigQuery LoadJob object + # representing the load operation, which allows user to monitor and retrieve + # details about the job status and result. + _ = load_embeddings_to_bigquery( + gcs_folder=embedding_gcs_path, + project_id=bq_project_id, + dataset_id=bq_dataset_id, + table_id=bq_table_name, + ) + + sys.stdout.flush() + + +def _run_example_inference( + job_name: str, + task_config_uri: str, +) -> None: + """ + Runs an example inference pipeline using GiGL Orchestration. + Args: + job_name (str): Name of current job + task_config_uri (str): Path to frozen GBMLConfigPbWrapper + """ + program_start_time = time.time() + + # The main process per machine needs to be able to talk with each other to partition and synchronize the graph data. + # Thus, the user is responsible here for 1. spinning up a single process per machine, + # and 2. init_process_group amongst these processes. + # Assuming this is spinning up inside VAI; it already sets up the env:// init method for us; thus we don't need anything + # special here. 
+ torch.distributed.init_process_group(backend="gloo") + + logger.info( + f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" + ) + logger.info( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) + + cluster_info = get_graph_store_info() + logger.info(f"Cluster info: {cluster_info}") + torch.distributed.destroy_process_group() + # Read from GbmlConfig for preprocessed data metadata, GNN model uri, and bigquery embedding table path, and additional inference args + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=UriFactory.create_uri(task_config_uri) + ) + model_uri = UriFactory.create_uri( + gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri + ) + graph_metadata = gbml_config_pb_wrapper.graph_metadata_pb_wrapper + output_bq_table_path = InferenceAssets.get_enumerated_embedding_table_path( + gbml_config_pb_wrapper, graph_metadata.homogeneous_node_type + ) + bq_project_id, bq_dataset_id, bq_table_name = BqUtils.parse_bq_table_path( + bq_table_path=output_bq_table_path + ) + # We write embeddings to a temporary GCS path during the inference loop, since writing directly to bigquery for each embedding is slow. + # After inference has finished, we then load all embeddings to bigquery from GCS. + embedding_output_gcs_folder = InferenceAssets.get_gcs_asset_write_path_prefix( + applied_task_identifier=AppliedTaskIdentifier(job_name), + bq_table_path=output_bq_table_path, + ) + node_feature_dim = gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map[ + graph_metadata.homogeneous_condensed_node_type + ] + edge_feature_dim = gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper.condensed_edge_type_to_feature_dim_map[ + graph_metadata.homogeneous_condensed_edge_type + ] + + inferencer_args = dict(gbml_config_pb_wrapper.inferencer_config.inferencer_args) + inference_batch_size = gbml_config_pb_wrapper.inferencer_config.inference_batch_size + + hid_dim = int(inferencer_args.get("hid_dim", "16")) + out_dim = int(inferencer_args.get("out_dim", "16")) + + local_world_size: int + arg_local_world_size = inferencer_args.get("local_world_size") + if arg_local_world_size is not None: + local_world_size = int(arg_local_world_size) + logger.info(f"Using local_world_size from inferencer_args: {local_world_size}") + if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): + logger.warning( + f"local_world_size {local_world_size} does not match the number of GPUs {torch.cuda.device_count()}. " + "This may lead to unexpected failures with NCCL communication incase GPUs are being used for " + + "training/inference. Consider setting local_world_size to the number of GPUs." + ) + else: + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + # If GPUs are available, we set the local_world_size to the number of GPUs + local_world_size = torch.cuda.device_count() + logger.info( + f"Detected {local_world_size} GPUs. Thus, setting local_world_size to {local_world_size}" + ) + else: + # If no GPUs are available, we set the local_world_size to the number of inference processes per machine + logger.info( + f"No GPUs detected. 
Thus, setting local_world_size to `{DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE}`" + ) + local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE + + mp_sharing_dict = mp.Manager().dict() + + inference_start_time = time.time() + sys.stdout.flush() + # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. + mp.spawn( + fn=_inference_process, + args=( + local_world_size, # local_world_size + embedding_output_gcs_folder, # embedding_gcs_path + model_uri, # model_state_dict_uri + inference_batch_size, # inference_batch_size + hid_dim, # hid_dim + out_dim, # out_dim + cluster_info, # cluster_info + inferencer_args, # inferencer_args + graph_metadata.homogeneous_node_type, # inference_node_type + node_feature_dim, # node_feature_dim + edge_feature_dim, # edge_feature_dim + gbml_config_pb_wrapper, # gbml_config_pb_wrapper + mp_sharing_dict, # mp_sharing_dict + ), + nprocs=local_world_size, + join=True, + ) + + logger.info( + f"--- Inference finished, which took {time.time()-inference_start_time:.2f} seconds" + ) + + logger.info( + f"--- Program finished, which took {time.time()-program_start_time:.2f} seconds" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Arguments for distributed model inference on VertexAI" + ) + parser.add_argument( + "--job_name", + type=str, + help="Inference job name", + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + + # We use parse_known_args instead of parse_args since we only need job_name and task_config_uri for distributed inference + args, unused_args = parser.parse_known_args() + logger.info(f"Unused arguments: {unused_args}") + # We only need `job_name` and `task_config_uri` for running inference + _run_example_inference( + job_name=args.job_name, + task_config_uri=args.task_config_uri, + ) diff --git a/testing/e2e_tests/e2e_tests.yaml b/testing/e2e_tests/e2e_tests.yaml index 44b4445f0..3303b5855 100644 --- a/testing/e2e_tests/e2e_tests.yaml +++ b/testing/e2e_tests/e2e_tests.yaml @@ -19,3 +19,6 @@ tests: het_dblp_sup_test: task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" + hom_cora_sup_gs_test: + task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_task_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From 1bccd1a3a382f92d7f72d60d2f4a445b626c51b8 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 13 Jan 2026 18:03:40 +0000 Subject: [PATCH 04/16] only gs test --- testing/e2e_tests/e2e_tests.yaml | 36 ++++++++++++++++---------------- 1 file changed, 18 insertions(+), 18 deletions(-) diff --git a/testing/e2e_tests/e2e_tests.yaml b/testing/e2e_tests/e2e_tests.yaml index 3303b5855..18084b743 100644 --- a/testing/e2e_tests/e2e_tests.yaml +++ b/testing/e2e_tests/e2e_tests.yaml @@ -1,24 +1,24 @@ # Combined e2e test configurations for GiGL # This file contains all the test specifications that can be run via the e2e test script tests: - cora_nalp_test: - task_config_uri: "python/gigl/src/mocking/configs/e2e_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - cora_snc_test: - task_config_uri: 
"python/gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - cora_udl_test: - task_config_uri: "python/gigl/src/mocking/configs/e2e_udl_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - dblp_nalp_test: - task_config_uri: "python/gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - hom_cora_sup_test: - task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" - het_dblp_sup_test: - task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" - resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" + # cora_nalp_test: + # task_config_uri: "python/gigl/src/mocking/configs/e2e_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # cora_snc_test: + # task_config_uri: "python/gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # cora_udl_test: + # task_config_uri: "python/gigl/src/mocking/configs/e2e_udl_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # dblp_nalp_test: + # task_config_uri: "python/gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + # hom_cora_sup_test: + # task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" + # het_dblp_sup_test: + # task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" + # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" hom_cora_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From 3678d705bf59bc27afead38455706e8b4a199382 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 13 Jan 2026 18:24:48 +0000 Subject: [PATCH 05/16] fix path --- testing/e2e_tests/e2e_tests.yaml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/testing/e2e_tests/e2e_tests.yaml b/testing/e2e_tests/e2e_tests.yaml index 18084b743..4fb8a2c69 100644 --- a/testing/e2e_tests/e2e_tests.yaml +++ b/testing/e2e_tests/e2e_tests.yaml @@ -20,5 +20,5 @@ tests: # task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" # resource_config_uri: 
"${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" hom_cora_sup_gs_test: - task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_task_config.yaml" + task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From bcb93e623eb64b022b3f54d32bad1baed68b116b Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 13 Jan 2026 21:14:59 +0000 Subject: [PATCH 06/16] logs --- .../link_prediction/homogeneous_training.py | 24 ++++++++++++++++++- python/gigl/common/logger.py | 2 ++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/examples/link_prediction/homogeneous_training.py b/examples/link_prediction/homogeneous_training.py index 85c8c4e48..6c7887080 100644 --- a/examples/link_prediction/homogeneous_training.py +++ b/examples/link_prediction/homogeneous_training.py @@ -26,6 +26,7 @@ import statistics import time from collections.abc import Iterator +import sys from typing import Literal, Optional import torch @@ -319,7 +320,10 @@ def _training_process( logger.info(f"---Rank {rank} training process set device {device}") logger.info(f"---Rank {rank} training process group initialized") - + logger.info(f"graph: {dataset.graph=}") + logger.info(f"Node IDs: {dataset.node_ids=}") + logger.info(f"Num neighbors: {num_neighbors=}") + sys.stdout.flush() loss_fn = RetrievalLoss( loss=torch.nn.CrossEntropyLoss(reduction="mean"), temperature=0.07, @@ -709,6 +713,24 @@ def _run_example_training( num_val_batches={num_val_batches}, \ val_every_n_batch={val_every_n_batch}" ) + print( + f"Got training args local_world_size={local_world_size}, \ + num_neighbors={num_neighbors}, \ + sampling_workers_per_process={sampling_workers_per_process}, \ + main_batch_size={main_batch_size}, \ + random_batch_size={random_batch_size}, \ + hid_dim={hid_dim}, \ + out_dim={out_dim}, \ + sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, \ + process_start_gap_seconds={process_start_gap_seconds}, \ + log_every_n_batch={log_every_n_batch}, \ + learning_rate={learning_rate}, \ + weight_decay={weight_decay}, \ + num_max_train_batches={num_max_train_batches}, \ + num_val_batches={num_val_batches}, \ + val_every_n_batch={val_every_n_batch}" + ) + # This `init_process_group` is only called to get the master_ip_address, master port, and rank/world_size fields which help with partitioning, sampling, # and distributed training/testing. We can use `gloo` here since these fields we are extracting don't require GPU capabilities provided by `nccl`. 
diff --git a/python/gigl/common/logger.py b/python/gigl/common/logger.py index 52102f3e9..175e0f2f6 100644 --- a/python/gigl/common/logger.py +++ b/python/gigl/common/logger.py @@ -1,6 +1,7 @@ import logging import os import pathlib +import sys from datetime import datetime from typing import Any, MutableMapping, Optional @@ -71,4 +72,5 @@ def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Any: return msg, kwargs def __getattr__(self, name: str): + sys.stdout.flush() return getattr(self._logger, name) From 359398f9315133c3e5971ad84db8b9fda9c71203 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 13 Jan 2026 21:16:52 +0000 Subject: [PATCH 07/16] other test --- testing/e2e_tests/e2e_tests.yaml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/testing/e2e_tests/e2e_tests.yaml b/testing/e2e_tests/e2e_tests.yaml index 4fb8a2c69..f65f8944d 100644 --- a/testing/e2e_tests/e2e_tests.yaml +++ b/testing/e2e_tests/e2e_tests.yaml @@ -13,9 +13,9 @@ tests: # dblp_nalp_test: # task_config_uri: "python/gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - # hom_cora_sup_test: - # task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" - # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" + hom_cora_sup_test: + task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" # het_dblp_sup_test: # task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" From c804750fb75a41b740002f250e30c23d65b64d96 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 14 Jan 2026 21:26:33 +0000 Subject: [PATCH 08/16] fixes --- python/gigl/common/logger.py | 2 - .../distributed/distributed_neighborloader.py | 104 ++++++++++++++---- .../graph_store/remote_dist_dataset.py | 12 ++ .../distributed/graph_store/storage_utils.py | 14 +++ .../gigl/distributed/utils/neighborloader.py | 14 ++- python/gigl/distributed/utils/networking.py | 17 ++- 6 files changed, 136 insertions(+), 27 deletions(-) diff --git a/python/gigl/common/logger.py b/python/gigl/common/logger.py index 175e0f2f6..52102f3e9 100644 --- a/python/gigl/common/logger.py +++ b/python/gigl/common/logger.py @@ -1,7 +1,6 @@ import logging import os import pathlib -import sys from datetime import datetime from typing import Any, MutableMapping, Optional @@ -72,5 +71,4 @@ def process(self, msg: str, kwargs: MutableMapping[str, Any]) -> Any: return msg, kwargs def __getattr__(self, name: str): - sys.stdout.flush() return getattr(self._logger, name) diff --git a/python/gigl/distributed/distributed_neighborloader.py b/python/gigl/distributed/distributed_neighborloader.py index bcf4c2d5e..1706628ae 100644 --- a/python/gigl/distributed/distributed_neighborloader.py +++ b/python/gigl/distributed/distributed_neighborloader.py @@ -19,7 +19,8 @@ from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.utils.neighborloader import ( - DatasetMetadata, + DatasetSchema, + SamplingClusterSetup, 
    labeled_to_homogeneous,
     patch_fanout_for_sampling,
     set_missing_features,
@@ -73,7 +74,8 @@ def __init__(
         https://pytorch-geometric.readthedocs.io/en/2.5.2/_modules/torch_geometric/distributed/dist_neighbor_loader.html#DistNeighborLoader
 
         Args:
-            dataset (DistDataset | RemoteDistDataset): The dataset to sample from. Must be a "RemoteDistDataset" if using Graph Store mode.
+            dataset (DistDataset | RemoteDistDataset): The dataset to sample from.
+                If this is a `RemoteDistDataset`, then "Graph Store" mode is assumed.
             num_neighbors (list[int] or dict[Tuple[str, str, str], list[int]]):
                 The number of neighbors to sample for each node in each iteration.
                 If an entry is set to `-1`, all neighbors will be included.
@@ -82,7 +84,7 @@ def __init__(
             context (deprecated - will be removed soon) (DistributedContext): Distributed context information of the current process.
             local_process_rank (deprecated - will be removed soon) (int): Required if context provided. The local rank of the current process within a node.
             local_process_world_size (deprecated - will be removed soon)(int): Required if context provided. The total number of processes within a node.
-            input_nodes (Tenor | Tuple[NodeType, Tenor] | list[Tenor] | Tuple[NodeType, list[Tenor]]):
+            input_nodes (Tensor | Tuple[NodeType, Tensor] | list[Tensor] | Tuple[NodeType, list[Tensor]]):
                 The nodes to start sampling from.
                 It is of type `torch.LongTensor` for homogeneous graphs.
                 If set to `None` for homogeneous settings, all nodes will be considered.
@@ -91,6 +93,8 @@ def __init__(
                 For Graph Store mode, this must be a tuple of (NodeType, list[Tenor]) or list[Tenor].
                 Where each Tensor in the list is the node ids to sample from, for each server.
                 e.g. [[10, 20], [30, 40]] means sample from nodes 10 and 20 on server 0, and nodes 30 and 40 on server 1.
+                If a Graph Store input (e.g. list[Tensor]) is provided in colocated mode, or a colocated input (e.g. Tensor) is provided in Graph Store mode,
+                then an error will be raised.
             num_workers (int): How many workers to use (subprocesses to spwan) for distributed neighbor sampling of the current process.
                 (default: ``1``).
             batch_size (int, optional): how many samples per batch to load
@@ -194,7 +198,11 @@ def __init__(
             local_process_rank,
             local_process_world_size,
         )  # delete deprecated vars so we don't accidentally use them.
-
+        if isinstance(dataset, RemoteDistDataset):
+            self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE
+        else:
+            self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED
+        logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}")
         device = (
             pin_memory_device
             if pin_memory_device
@@ -204,8 +212,10 @@ def __init__(
         )
 
         # Determines if the node ids passed in are heterogeneous or homogeneous.
-        self._is_labeled_heterogeneous = False
-        if isinstance(dataset, DistDataset):
+        if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
+            assert isinstance(
+                dataset, DistDataset
+            ), "When using colocated mode, dataset must be a DistDataset."
             input_data, worker_options, dataset_metadata = self._setup_for_colocated(
                 input_nodes,
                 dataset,
@@ -221,7 +231,10 @@ def __init__(
                 channel_size,
                 num_cpu_threads,
             )
-        else:  # RemoteDistDataset
+        else:  # Graph Store mode
+            assert isinstance(
+                dataset, RemoteDistDataset
+            ), "When using Graph Store mode, dataset must be a RemoteDistDataset."
            input_data, worker_options, dataset_metadata = self._setup_for_graph_store(
                input_nodes,
                dataset,
@@ -232,13 +245,14 @@ def __init__(
         self._node_feature_info = dataset_metadata.node_feature_info
         self._edge_feature_info = dataset_metadata.edge_feature_info
 
+        logger.info(f"num_neighbors before patch: {num_neighbors}")
         num_neighbors = patch_fanout_for_sampling(
-            list(dataset_metadata.edge_feature_info.keys())
-            if isinstance(dataset_metadata.edge_feature_info, dict)
-            else None,
-            num_neighbors,
+            edge_types=dataset_metadata.edge_types,
+            num_neighbors=num_neighbors,
+        )
+        logger.info(
+            f"num_neighbors: {num_neighbors}, edge_types: {dataset_metadata.edge_types}"
         )
-
         sampling_config = SamplingConfig(
             sampling_type=SamplingType.NODE,
             num_neighbors=num_neighbors,
@@ -259,7 +273,7 @@ def __init__(
             )
             torch.distributed.destroy_process_group()
 
-        if isinstance(dataset, DistDataset):
+        if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
             super().__init__(
                 dataset if isinstance(dataset, DistDataset) else None,
                 input_data,
@@ -271,11 +285,47 @@ def __init__(
             # For Graph Store mode, we need to start the communication between compute and storage nodes sequentially, by compute node.
             # E.g. initialize connections between compute node 0 and storage nodes 0, 1, 2, 3, then compute node 1 and storage nodes 0, 1, 2, 3, etc.
             # Note that each compute node may have multiple connections to each storage node, once per compute process.
-            # E.g. if there are 4 gpus per compute node, then there will be 4 connections to each storage node.
+            # It's important to distinguish a "compute node" (a physical compute machine) from a "compute process" (a process running on that node):
+            # in practice there are multiple compute processes per compute node, and each compute process needs to initialize its own connection to the storage nodes.
+            # E.g. if there are 4 gpus per compute node, then there will be 4 connections from each compute node to each storage node.
             # We need to do this because otherwise there is a race condition when initializing the samplers on the storage nodes [1]:
             # since the lock is per *server* (e.g. per storage node), if we try to start connections from compute node 0 and compute node 1 at the same time,
             # then we deadlock and fail.
+            # Specifically, the race condition happens in `DistLoader.__init__` when it initializes the sampling producers on the storage nodes. [2]
             # [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L129-L167
+            # [2]: https://github.com/alibaba/graphlearn-for-pytorch/blob/88ff111ac0d9e45c6c9d2d18cfc5883dca07e9f9/graphlearn_torch/python/distributed/dist_loader.py#L187-L193
+
+            # See below for a diagram of the connection setup.
+ # ╔═══════════════════════════════════════════════════════════════════════════════════════╗ + # ║ COMPUTE TO STORAGE NODE CONNECTIONS ║ + # ╚═══════════════════════════════════════════════════════════════════════════════════════╝ + + # COMPUTE NODES STORAGE NODES + # ═════════════ ═════════════ + + # ┌──────────────────────┐ (1) ┌───────────────┐ + # │ COMPUTE NODE 0 │ │ │ + # │ ┌────┬────┬────┬────┤ ══════════════════════════════════│ STORAGE 0 │ + # │ │GPU │GPU │GPU │GPU │ ╱ │ │ + # │ │ 0 │ 1 │ 2 │ 3 │ ════════════════════╲ ╱ └───────────────┘ + # │ └────┴────┴────┴────┤ (2) ╲ ╱ + # └──────────────────────┘ ╲ ╱ + # ╳ + # (3) ╱ ╲ (4) + # ┌──────────────────────┐ ╱ ╲ ┌───────────────┐ + # │ COMPUTE NODE 1 │ ╱ ╲ │ │ + # │ ┌────┬────┬────┬────┤ ═════════════════╱ ═│ STORAGE 1 │ + # │ │GPU │GPU │GPU │GPU │ │ │ + # │ │ 0 │ 1 │ 2 │ 3 │ ══════════════════════════════════│ │ + # │ └────┴────┴────┴────┤ └───────────────┘ + # └──────────────────────┘ + + # ┌─────────────────────────────────────────────────────────────────────────────┐ + # │ (1) Compute Node 0 → Storage 0 (4 connections, one per GPU) │ + # │ (2) Compute Node 0 → Storage 1 (4 connections, one per GPU) │ + # │ (3) Compute Node 1 → Storage 0 (4 connections, one per GPU) │ + # │ (4) Compute Node 1 → Storage 1 (4 connections, one per GPU) │ + # └─────────────────────────────────────────────────────────────────────────────┘ node_rank = dataset.cluster_info.compute_node_rank for target_node_rank in range(dataset.cluster_info.num_compute_nodes): if node_rank == target_node_rank: @@ -302,7 +352,7 @@ def _setup_for_graph_store( ], dataset: RemoteDistDataset, num_workers: int, - ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, DatasetMetadata]: + ) -> tuple[NodeSamplerInput, RemoteDistSamplingWorkerOptions, DatasetSchema]: if input_nodes is None: raise ValueError( f"When using Graph Store mode, input nodes must be provided, received {input_nodes}" @@ -367,7 +417,7 @@ def _setup_for_graph_store( len(edge_feature_info) == 1 and DEFAULT_HOMOGENEOUS_EDGE_TYPE in edge_feature_info ): - input_type: NodeType | None = DEFAULT_HOMOGENEOUS_NODE_TYPE + input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE else: input_type = fallback_input_type elif require_edge_feature_info: @@ -377,6 +427,15 @@ def _setup_for_graph_store( else: input_type = None + if ( + input_type is not None + and isinstance(node_feature_info, dict) + and input_type not in node_feature_info.keys() + ): + raise ValueError( + f"Input type {input_type} is not in node node types: {node_feature_info.keys()}" + ) + input_data = [ NodeSamplerInput(node=node, input_type=input_type) for node in nodes ] @@ -384,8 +443,9 @@ def _setup_for_graph_store( return ( input_data, worker_options, - DatasetMetadata( + DatasetSchema( is_labeled_heterogeneous=is_labeled_heterogeneous, + edge_types=dataset.get_edge_types(), node_feature_info=node_feature_info, edge_feature_info=edge_feature_info, edge_dir=dataset.get_edge_dir(), @@ -414,7 +474,7 @@ def _setup_for_colocated( worker_concurrency: int, channel_size: str, num_cpu_threads: Optional[int], - ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, DatasetMetadata]: + ) -> tuple[NodeSamplerInput, MpDistSamplingWorkerOptions, DatasetSchema]: if input_nodes is None: if dataset.node_ids is None: raise ValueError( @@ -527,11 +587,17 @@ def _setup_for_colocated( pin_memory=device.type == "cuda", ) + if isinstance(dataset.graph, dict): + edge_types = list(dataset.graph.keys()) + else: + edge_types = None + return ( input_data, 
worker_options, - DatasetMetadata( + DatasetSchema( is_labeled_heterogeneous=is_labeled_heterogeneous, + edge_types=edge_types, node_feature_info=dataset.node_feature_info, edge_feature_info=dataset.edge_feature_info, edge_dir=dataset.edge_dir, diff --git a/python/gigl/distributed/graph_store/remote_dist_dataset.py b/python/gigl/distributed/graph_store/remote_dist_dataset.py index 18f4177bc..4a534a864 100644 --- a/python/gigl/distributed/graph_store/remote_dist_dataset.py +++ b/python/gigl/distributed/graph_store/remote_dist_dataset.py @@ -9,6 +9,7 @@ from gigl.distributed.graph_store.storage_utils import ( get_edge_dir, get_edge_feature_info, + get_edge_types, get_node_feature_info, get_node_ids_for_rank, ) @@ -213,3 +214,14 @@ def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]: torch.distributed.broadcast_object_list(ports, src=0) logger.info(f"Compute rank {compute_cluster_rank} received free ports: {ports}") return ports + + def get_edge_types(self) -> Optional[list[EdgeType]]: + """Get the edge types from the registered dataset. + + Returns: + The edge types. + """ + return request_server( + 0, + get_edge_types, + ) diff --git a/python/gigl/distributed/graph_store/storage_utils.py b/python/gigl/distributed/graph_store/storage_utils.py index ef0b055ec..331eebd7c 100644 --- a/python/gigl/distributed/graph_store/storage_utils.py +++ b/python/gigl/distributed/graph_store/storage_utils.py @@ -141,3 +141,17 @@ def get_node_ids_for_rank( f"Node ids must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(_dataset.node_ids)}" ) return shard_nodes_by_process(nodes, rank, world_size) + + +def get_edge_types() -> Optional[list[EdgeType]]: + """Get the edge types from the registered dataset. + + Returns: + The edge types. + """ + if _dataset is None: + raise _NO_DATASET_ERROR + if isinstance(_dataset.graph, dict): + return list(_dataset.graph.keys()) + else: + return None diff --git a/python/gigl/distributed/utils/neighborloader.py b/python/gigl/distributed/utils/neighborloader.py index 9aa587cfd..54b687c98 100644 --- a/python/gigl/distributed/utils/neighborloader.py +++ b/python/gigl/distributed/utils/neighborloader.py @@ -2,6 +2,7 @@ from collections import abc from copy import deepcopy from dataclasses import dataclass +from enum import Enum from typing import Literal, Optional, TypeVar, Union import torch @@ -16,14 +17,25 @@ _GraphType = TypeVar("_GraphType", Data, HeteroData) +class SamplingClusterSetup(Enum): + """ + The setup of the sampling cluster. + """ + + COLOCATED = "colocated" + GRAPH_STORE = "graph_store" + + @dataclass(frozen=True) -class DatasetMetadata: +class DatasetSchema: """ Shared metadata between the local and remote datasets. """ # If the dataset is labeled heterogeneous. E.g. one node type, one edge type, and "label" edges. is_labeled_heterogeneous: bool + # List of all edge types in the graph. + edge_types: Optional[list[EdgeType]] # Node feature info. node_feature_info: Optional[Union[FeatureInfo, dict[NodeType, FeatureInfo]]] # Edge feature info. 
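For reference, the small pieces introduced above can be exercised on their own. The snippet below is purely illustrative (placeholder values rather than a real dataset): it shows how a plain homogeneous graph would be described by the shared `DatasetSchema` container and how the cluster mode is tagged with `SamplingClusterSetup`.

```python
from gigl.distributed.utils.neighborloader import DatasetSchema, SamplingClusterSetup

# Mode tag the loader records once, instead of re-checking isinstance(dataset, ...) everywhere.
setup = SamplingClusterSetup.GRAPH_STORE

# A featureless homogeneous graph: no explicit edge types, no node/edge features, out-edge sampling.
schema = DatasetSchema(
    is_labeled_heterogeneous=False,
    edge_types=None,
    node_feature_info=None,
    edge_feature_info=None,
    edge_dir="out",
)
print(setup.value, schema.is_labeled_heterogeneous)
```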
diff --git a/python/gigl/distributed/utils/networking.py b/python/gigl/distributed/utils/networking.py index e9d33921c..e2e1f320e 100644 --- a/python/gigl/distributed/utils/networking.py +++ b/python/gigl/distributed/utils/networking.py @@ -168,11 +168,12 @@ def get_internal_ip_from_node( # Other nodes will receive the master's IP via broadcast ip_list = [None] - device = torch.device("cpu" if not torch.cuda.is_available() else "cuda") - + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") torch.distributed.broadcast_object_list(ip_list, src=node_rank, device=device) node_ip = ip_list[0] - logger.info(f"Rank {rank} received master node's internal IP: {node_ip}") + logger.info( + f"Rank {rank} received master node's internal IP: {node_ip} on device {device}" + ) assert node_ip is not None, "Could not retrieve master node's internal IP" return node_ip @@ -244,15 +245,21 @@ def get_graph_store_info() -> GraphStoreInfo: compute_cluster_master_ip = cluster_master_ip storage_cluster_master_ip = get_internal_ip_from_node(node_rank=num_compute_nodes) + # Cluster master is by convention rank 0. + cluster_master_rank = 0 ( cluster_master_port, compute_cluster_master_port, - ) = get_free_ports_from_node(num_ports=2, node_rank=0) + ) = get_free_ports_from_node(num_ports=2, node_rank=cluster_master_rank) + + # Since we structure the cluster as [compute0, ..., computeN, storage0, ..., storageN], the storage master is the first storage node. + # And it's rank is the number of compute nodes. + storage_master_rank = num_compute_nodes ( storage_cluster_master_port, storage_rpc_port, storage_rpc_wait_port, - ) = get_free_ports_from_node(num_ports=3, node_rank=num_compute_nodes) + ) = get_free_ports_from_node(num_ports=3, node_rank=storage_master_rank) num_processes_per_compute = int( os.environ.get(COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY, "1") From ce30f6be6bb4556e92fba6ca6ab89fc56a6d851e Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 14 Jan 2026 23:47:08 +0000 Subject: [PATCH 09/16] fix --- python/gigl/distributed/graph_store/storage_main.py | 1 - 1 file changed, 1 deletion(-) diff --git a/python/gigl/distributed/graph_store/storage_main.py b/python/gigl/distributed/graph_store/storage_main.py index 3dc8040a5..09a0d79cd 100644 --- a/python/gigl/distributed/graph_store/storage_main.py +++ b/python/gigl/distributed/graph_store/storage_main.py @@ -133,7 +133,6 @@ def storage_node_process( args = parser.parse_args() logger.info(f"Running storage node with arguments: {args}") - is_inference = args.is_inference torch.distributed.init_process_group(backend="gloo") cluster_info = get_graph_store_info() logger.info(f"Cluster info: {cluster_info}") From 76c486510287b6b5dc5b15e8145dc68d8f91a512 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 15 Jan 2026 18:33:08 +0000 Subject: [PATCH 10/16] update comments --- .../configs/example_resource_config.yaml | 58 +++++++++++++++ .../graph_store/homogeneous_inference.py | 70 +++++++++++++++++-- 2 files changed, 122 insertions(+), 6 deletions(-) create mode 100644 examples/link_prediction/graph_store/configs/example_resource_config.yaml diff --git a/examples/link_prediction/graph_store/configs/example_resource_config.yaml b/examples/link_prediction/graph_store/configs/example_resource_config.yaml new file mode 100644 index 000000000..4b9c3e629 --- /dev/null +++ b/examples/link_prediction/graph_store/configs/example_resource_config.yaml @@ -0,0 +1,58 @@ +shared_resource_config: + resource_labels: + # These are compute labels that we will try to 
attach to the resources created by GiGL components. + # More information: https://cloud.google.com/compute/docs/labeling-resources. + # These can be mostly used to get finer grained cost reporting through GCP billing on individual component + # and pipeline costs. + + # If COMPONENT is provided in cost_resource_group_tag, it will be automatically be replaced with one of + # {pre|sgs|spl|tra|inf|pos} standing for: {Preprocessor | Subgraph Sampler | Split Generator | Trainer + # | Inference | Post Processor} so we can get more accurate cost measurements of each component. + # See implementation: + # `python/gigl/src/common/types/pb_wrappers/gigl_resource_config.py#GiglResourceConfigWrapper.get_resource_labels` + + cost_resource_group_tag: dev_experiments_COMPONENT + cost_resource_group: gigl_platform + common_compute_config: + project: "USER_PROVIDED_PROJECT" + region: "us-central1" + # We recommend using the same bucket for temp_assets_bucket and temp_regional_assets_bucket + # These fields will get combined into one in the future. Note: Usually storage for regional buckets is cheaper, + # thus that is recommended. + temp_assets_bucket: "gs://USER_PROVIDED_TEMP_ASSETS_BUCKET" + temp_regional_assets_bucket: "gs://USER_PROVIDED_TEMP_ASSETS_BUCKET" + perm_assets_bucket: "gs://USER_PROVIDED_PERM_ASSETS_BUCKET" + temp_assets_bq_dataset_name: "gigl_temp_assets" + embedding_bq_dataset_name: "gigl_embeddings" + gcp_service_account_email: "USER_PROVIDED_SA@USER_PROVIDED_PROJECT.iam.gserviceaccount.com" + dataflow_runner: "DataflowRunner" +preprocessor_config: + edge_preprocessor_config: + num_workers: 1 + max_num_workers: 4 + machine_type: "n2-standard-16" + disk_size_gb: 300 + node_preprocessor_config: + num_workers: 1 + max_num_workers: 4 + machine_type: "n2-standard-16" + disk_size_gb: 300 +# TODO(kmonte): Update +trainer_resource_config: + vertex_ai_trainer_config: + machine_type: n1-standard-16 + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 +inferencer_resource_config: + vertex_ai_graph_store_inferencer_config: + graph_store_pool: + machine_type: n2-highmem-32 + gpu_type: ACCELERATOR_TYPE_UNSPECIFIED + gpu_limit: 0 + num_replicas: 2 + compute_pool: + machine_type: n1-standard-16 + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index b6c8dbc97..91cde8fc7 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -1,9 +1,64 @@ """ -This file contains an example for how to run homogeneous inference on pretrained torch.nn.Module in GiGL (or elsewhere) using new -GLT (GraphLearn-for-PyTorch) bindings that GiGL has. Note that example should be applied to use cases which already have -some pretrained `nn.Module` and are looking to utilize cost-savings with distributed inference. While `run_example_inference` is coupled with -GiGL orchestration, the `_inference_process` function is generic and can be used as references -for writing inference for pipelines not dependent on GiGL orchestration. +This file contains an example for how to run homogeneous inference in **graph store mode** using GiGL. + +Graph Store Mode vs Standard Mode: +---------------------------------- +Graph store mode uses a heterogeneous cluster architecture with two distinct sub-clusters: + 1. 
**Storage Cluster (graph_store_pool)**: Dedicated machines for storing and serving the graph + data. These are typically high-memory machines without GPUs (e.g., n2-highmem-32). + 2. **Compute Cluster (compute_pool)**: Dedicated machines for running model inference/training. + These typically have GPUs attached (e.g., n1-standard-16 with NVIDIA_TESLA_T4). + +This separation allows for: + - Independent scaling of storage and compute resources + - Better memory utilization (graph data stays on storage nodes) + - Cost optimization by using appropriate hardware for each role + +In contrast, the standard inference mode (see `examples/link_prediction/homogeneous_inference.py`) +uses a homogeneous cluster where each machine handles both graph storage and computation. + +Key Implementation Differences: +------------------------------- +This file (graph store mode): + - Uses `RemoteDistDataset` to connect to a remote graph store cluster + - Uses `init_compute_process` to initialize the compute node connection to storage + - Obtains cluster topology via `get_graph_store_info()` which returns `GraphStoreInfo` + - Uses `mp_sharing_dict` for efficient tensor sharing between local processes + +Standard mode (`homogeneous_inference.py`): + - Uses `DistDataset` with `build_dataset_from_task_config_uri` where each node loads its partition + - Manually manages distributed process groups with master IP and port + - Each machine stores its own partition of the graph data + +Resource Configuration: +----------------------- +Graph store mode requires a different resource config structure. Compare: + +**Graph Store Mode** (e2e_glt_gs_resource_config.yaml): +```yaml +inferencer_resource_config: + vertex_ai_graph_store_inferencer_config: + graph_store_pool: + machine_type: n2-highmem-32 # High memory for graph storage + gpu_type: ACCELERATOR_TYPE_UNSPECIFIED + gpu_limit: 0 + num_replicas: 2 + compute_pool: + machine_type: n1-standard-16 # Standard machines with GPUs + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 +``` + +**Standard Mode** (e2e_glt_resource_config.yaml): +```yaml +inferencer_resource_config: + vertex_ai_inferencer_config: + machine_type: n1-highmem-32 + gpu_type: NVIDIA_TESLA_T4 + gpu_limit: 2 + num_replicas: 2 +``` To run this file with GiGL orchestration, set the fields similar to below: @@ -12,10 +67,13 @@ # Example argument to inferencer log_every_n_batch: "50" inferenceBatchSize: 512 - command: python -m examples.link_prediction.homogeneous_inference + command: python -m examples.link_prediction.graph_store.homogeneous_inference featureFlags: should_run_glt_backend: 'True' +Note: Ensure you use a resource config with `vertex_ai_graph_store_inferencer_config` when +running in graph store mode. + You can run this example in a full pipeline with `make run_hom_cora_sup_test` from GiGL root. 
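
At a glance, the compute-side flow in graph store mode boils down to the sketch below
(imports and argument lists are trimmed and variable setup is omitted; see `_inference_process`
further down in this file for the real call sites):

```python
cluster_info = get_graph_store_info()           # topology of the compute + storage pools
init_compute_process(local_rank, cluster_info)  # connect this process to the storage cluster
dataset = RemoteDistDataset(cluster_info, local_rank)
input_nodes = dataset.get_node_ids()            # seed node ids, one tensor per storage server
data_loader = gigl.distributed.DistNeighborLoader(
    dataset=dataset,                            # a RemoteDistDataset selects Graph Store mode
    num_neighbors=num_neighbors,
    input_nodes=input_nodes,
)
for batch in data_loader:
    ...                                         # run the model exactly as in standard mode
```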
""" From fdd38a425b00af76ff59be9da3308d6824a14d56 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 15 Jan 2026 18:41:07 +0000 Subject: [PATCH 11/16] fixes --- .../link_prediction/homogeneous_training.py | 24 +-------------- testing/e2e_tests/e2e_tests.yaml | 30 +++++++++---------- 2 files changed, 16 insertions(+), 38 deletions(-) diff --git a/examples/link_prediction/homogeneous_training.py b/examples/link_prediction/homogeneous_training.py index 6c7887080..85c8c4e48 100644 --- a/examples/link_prediction/homogeneous_training.py +++ b/examples/link_prediction/homogeneous_training.py @@ -26,7 +26,6 @@ import statistics import time from collections.abc import Iterator -import sys from typing import Literal, Optional import torch @@ -320,10 +319,7 @@ def _training_process( logger.info(f"---Rank {rank} training process set device {device}") logger.info(f"---Rank {rank} training process group initialized") - logger.info(f"graph: {dataset.graph=}") - logger.info(f"Node IDs: {dataset.node_ids=}") - logger.info(f"Num neighbors: {num_neighbors=}") - sys.stdout.flush() + loss_fn = RetrievalLoss( loss=torch.nn.CrossEntropyLoss(reduction="mean"), temperature=0.07, @@ -713,24 +709,6 @@ def _run_example_training( num_val_batches={num_val_batches}, \ val_every_n_batch={val_every_n_batch}" ) - print( - f"Got training args local_world_size={local_world_size}, \ - num_neighbors={num_neighbors}, \ - sampling_workers_per_process={sampling_workers_per_process}, \ - main_batch_size={main_batch_size}, \ - random_batch_size={random_batch_size}, \ - hid_dim={hid_dim}, \ - out_dim={out_dim}, \ - sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, \ - process_start_gap_seconds={process_start_gap_seconds}, \ - log_every_n_batch={log_every_n_batch}, \ - learning_rate={learning_rate}, \ - weight_decay={weight_decay}, \ - num_max_train_batches={num_max_train_batches}, \ - num_val_batches={num_val_batches}, \ - val_every_n_batch={val_every_n_batch}" - ) - # This `init_process_group` is only called to get the master_ip_address, master port, and rank/world_size fields which help with partitioning, sampling, # and distributed training/testing. We can use `gloo` here since these fields we are extracting don't require GPU capabilities provided by `nccl`. 
diff --git a/testing/e2e_tests/e2e_tests.yaml b/testing/e2e_tests/e2e_tests.yaml index f65f8944d..b2d8517bf 100644 --- a/testing/e2e_tests/e2e_tests.yaml +++ b/testing/e2e_tests/e2e_tests.yaml @@ -1,24 +1,24 @@ # Combined e2e test configurations for GiGL # This file contains all the test specifications that can be run via the e2e test script tests: - # cora_nalp_test: - # task_config_uri: "python/gigl/src/mocking/configs/e2e_node_anchor_based_link_prediction_template_gbml_config.yaml" - # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - # cora_snc_test: - # task_config_uri: "python/gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml" - # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - # cora_udl_test: - # task_config_uri: "python/gigl/src/mocking/configs/e2e_udl_node_anchor_based_link_prediction_template_gbml_config.yaml" - # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" - # dblp_nalp_test: - # task_config_uri: "python/gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" - # resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + cora_nalp_test: + task_config_uri: "python/gigl/src/mocking/configs/e2e_node_anchor_based_link_prediction_template_gbml_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + cora_snc_test: + task_config_uri: "python/gigl/src/mocking/configs/e2e_supervised_node_classification_template_gbml_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + cora_udl_test: + task_config_uri: "python/gigl/src/mocking/configs/e2e_udl_node_anchor_based_link_prediction_template_gbml_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" + dblp_nalp_test: + task_config_uri: "python/gigl/src/mocking/configs/dblp_node_anchor_based_link_prediction_template_gbml_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_cicd_resource_config.yaml}" hom_cora_sup_test: task_config_uri: "examples/link_prediction/configs/e2e_hom_cora_sup_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" - # het_dblp_sup_test: - # task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" - # resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" + het_dblp_sup_test: + task_config_uri: "examples/link_prediction/configs/e2e_het_dblp_sup_task_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}" hom_cora_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" From d5edc61f7ef1acf81599c3710713d5fee30f9c73 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 20 Jan 2026 18:08:14 +0000 Subject: [PATCH 12/16] address comments --- Makefile | 8 
+++++ .../configs/e2e_glt_gs_resource_config.yaml | 2 ++ .../e2e_hom_cora_sup_gs_task_config.yaml | 5 ++- .../configs/example_resource_config.yaml | 9 +++++ .../graph_store/homogeneous_inference.py | 33 +++++++++++++++---- 5 files changed, 50 insertions(+), 7 deletions(-) diff --git a/Makefile b/Makefile index 423f3302a..e8914b4f8 100644 --- a/Makefile +++ b/Makefile @@ -237,6 +237,14 @@ run_het_dblp_sup_e2e_test: --test_spec_uri="testing/e2e_tests/e2e_tests.yaml" \ --test_names="het_dblp_sup_test" +run_hom_cora_sup_gs_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} +run_hom_cora_sup_gs_e2e_test: compile_gigl_kubeflow_pipeline +run_hom_cora_sup_gs_e2e_test: + uv run python testing/e2e_tests/e2e_test.py \ + --compiled_pipeline_path=$(compiled_pipeline_path) \ + --test_spec_uri="testing/e2e_tests/e2e_tests.yaml" \ + --test_names="hom_cora_sup_gs_test" + run_all_e2e_tests: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} run_all_e2e_tests: compile_gigl_kubeflow_pipeline run_all_e2e_tests: diff --git a/deployment/configs/e2e_glt_gs_resource_config.yaml b/deployment/configs/e2e_glt_gs_resource_config.yaml index 0bab9f54f..097cf60d9 100644 --- a/deployment/configs/e2e_glt_gs_resource_config.yaml +++ b/deployment/configs/e2e_glt_gs_resource_config.yaml @@ -1,3 +1,5 @@ +# Diffs from e2e_glt_resource_config.yaml +# - Swap vertex_ai_inferencer_config for vertex_ai_graph_store_inferencer_config shared_resource_config: resource_labels: cost_resource_group_tag: dev_experiments_COMPONENT diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml index 812d9e408..6cf4bdeea 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml @@ -1,4 +1,7 @@ -# This config is used to run homogeneous CORA supervised training and inference using in memory GiGL SGS. This can be run with `make run_hom_cora_sup_test`. +# This config is used to run homogeneous CORA supervised training and inference using in memory GiGL SGS using the Graph Store mode. +# This can be run with `make run_hom_cora_sup_gs_test`. 
+# Diffs from ../configs/e2e_hom_cora_sup_task_config.yaml: +# - None (currently) - we detect that "Graph Store" mode should be employed from the resource config graphMetadata: edgeTypes: - dstNodeType: paper diff --git a/examples/link_prediction/graph_store/configs/example_resource_config.yaml b/examples/link_prediction/graph_store/configs/example_resource_config.yaml index 4b9c3e629..c08b379df 100644 --- a/examples/link_prediction/graph_store/configs/example_resource_config.yaml +++ b/examples/link_prediction/graph_store/configs/example_resource_config.yaml @@ -1,3 +1,12 @@ +# Example resource config for graph store mode +# Diffs from ../configs/example_resource_config.yaml: +# - Swap vertex_ai_inferencer_config for vertex_ai_graph_store_inferencer_config +# You should swap out the following fields with your own resources: +# - project: "USER_PROVIDED_PROJECT" +# - temp_assets_bucket: "gs://USER_PROVIDED_TEMP_ASSETS_BUCKET" +# - temp_regional_assets_bucket: "gs://USER_PROVIDED_TEMP_ASSETS_BUCKET" +# - perm_assets_bucket: "gs://USER_PROVIDED_PERM_ASSETS_BUCKET" +# - gcp_service_account_email: "USER_PROVIDED_SA@USER_PROVIDED_PROJECT.iam.gserviceaccount.com" shared_resource_config: resource_labels: # These are compute labels that we will try to attach to the resources created by GiGL components. diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 91cde8fc7..f5ba6e846 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -144,7 +144,7 @@ def _inference_process( local_rank (int): Process number on the current machine local_world_size (int): Number of inference processes spawned by each machine distributed_context (DistributedContext): Distributed context containing information for master_ip_address, rank, and world size - embedding_gcs_path (GcsUri): GCS path to load embeddings from + embedding_gcs_path (GcsUri): GCS path to write embeddings to model_state_dict_uri (GcsUri): GCS path to load model from inference_batch_size (int): Batch size to use for inference hid_dim (int): Hidden dimension of the model @@ -188,6 +188,8 @@ def _inference_process( log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "50")) + # Note: This is a *critical* step in Graph Store mode. It initializes the connection to the storage cluster. + # If this is not done, the dataloader will not be able to sample from the graph store and will crash. init_compute_process(local_rank, cluster_info) dataset = RemoteDistDataset( cluster_info, local_rank, mp_sharing_dict=mp_sharing_dict @@ -204,6 +206,8 @@ def _inference_process( logger.info( f"Rank {rank} got input nodes of shapes: {[node.shape for node in input_nodes]}" ) + # We don't see logs for graph store mode for whatever reason. + # TOOD(#442): Revert this once the GCP issues are resolved. sys.stdout.flush() data_loader = gigl.distributed.DistNeighborLoader( @@ -260,9 +264,11 @@ def _inference_process( # to only flushing when flush_embeddings() is called explicitly or after exiting via a context manager. exporter = EmbeddingExporter(export_dir=gcs_base_uri) + # We don't see logs for graph store mode for whatever reason. + # TOOD(#442): Revert this once the GCP issues are resolved. + sys.stdout.flush() # We add a barrier here so that all machines and processes have initialized their dataloader at the start of the inference loop. Otherwise, on-the-fly subgraph # sampling may fail. 
- sys.stdout.flush() torch.distributed.barrier() t = time.time() @@ -305,6 +311,8 @@ def _inference_process( f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds " f"and model inference took {cumulative_inference_time:.2f} seconds." ) + # We don't see logs for graph store mode for whatever reason. + # TOOD(#442): Revert this once the GCP issues are resolved. sys.stdout.flush() t = time.time() cumulative_data_loading_time = 0 @@ -324,6 +332,8 @@ def _inference_process( logger.info( f"--- Rank {rank} wrote embeddings to GCS at {gcs_base_uri} over batches" ) + # We don't see logs for graph store mode for whatever reason. + # TOOD(#442): Revert this once the GCP issues are resolved. sys.stdout.flush() # We first call barrier to ensure that all machines and processes have finished inference. # Only once all machines have finished inference is it safe to shutdown the data loader. @@ -358,6 +368,8 @@ def _inference_process( table_id=bq_table_name, ) + # We don't see logs for graph store mode for whatever reason. + # TOOD(#442): Revert this once the GCP issues are resolved. sys.stdout.flush() @@ -401,9 +413,6 @@ def _run_example_inference( output_bq_table_path = InferenceAssets.get_enumerated_embedding_table_path( gbml_config_pb_wrapper, graph_metadata.homogeneous_node_type ) - bq_project_id, bq_dataset_id, bq_table_name = BqUtils.parse_bq_table_path( - bq_table_path=output_bq_table_path - ) # We write embeddings to a temporary GCS path during the inference loop, since writing directly to bigquery for each embedding is slow. # After inference has finished, we then load all embeddings to bigquery from GCS. embedding_output_gcs_folder = InferenceAssets.get_gcs_asset_write_path_prefix( @@ -423,7 +432,6 @@ def _run_example_inference( hid_dim = int(inferencer_args.get("hid_dim", "16")) out_dim = int(inferencer_args.get("out_dim", "16")) - local_world_size: int arg_local_world_size = inferencer_args.get("local_world_size") if arg_local_world_size is not None: local_world_size = int(arg_local_world_size) @@ -449,8 +457,21 @@ def _run_example_inference( local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE mp_sharing_dict = mp.Manager().dict() + if torch.distributed.get_rank() == 0: + gcs_utils = GcsUtils() + num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path( + embedding_output_gcs_folder + ) + if num_files_at_gcs_path > 0: + logger.warning( + f"{num_files_at_gcs_path} files already detected at base gcs path {embedding_output_gcs_folder}. Cleaning up files at path ... " + ) + gcs_utils.delete_files_in_bucket_dir(embedding_output_gcs_folder) + torch.distributed.barrier() inference_start_time = time.time() + # We don't see logs for graph store mode for whatever reason. + # TOOD(#442): Revert this once the GCP issues are resolved. sys.stdout.flush() # When using mp.spawn with `nprocs`, the first argument is implicitly set to be the process number on the current machine. 
mp.spawn( From e9f0cfaea6ef3de0077f68ebee94a76b0e5cbeb9 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 20 Jan 2026 19:40:29 +0000 Subject: [PATCH 13/16] use cluster info --- examples/link_prediction/graph_store/homogeneous_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index f5ba6e846..8f7ee58f1 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -457,7 +457,7 @@ def _run_example_inference( local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE mp_sharing_dict = mp.Manager().dict() - if torch.distributed.get_rank() == 0: + if cluster_info.compute_node_rank == 0: gcs_utils = GcsUtils() num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path( embedding_output_gcs_folder From db29e748dd094cc7486244eeb5229bca59b6c479 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 20 Jan 2026 21:13:35 +0000 Subject: [PATCH 14/16] more pg --- examples/link_prediction/graph_store/homogeneous_inference.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 8f7ee58f1..8542dd2a8 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -401,7 +401,6 @@ def _run_example_inference( cluster_info = get_graph_store_info() logger.info(f"Cluster info: {cluster_info}") - torch.distributed.destroy_process_group() # Read from GbmlConfig for preprocessed data metadata, GNN model uri, and bigquery embedding table path, and additional inference args gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( gbml_config_uri=UriFactory.create_uri(task_config_uri) @@ -468,6 +467,7 @@ def _run_example_inference( ) gcs_utils.delete_files_in_bucket_dir(embedding_output_gcs_folder) torch.distributed.barrier() + torch.distributed.destroy_process_group() inference_start_time = time.time() # We don't see logs for graph store mode for whatever reason. From 56d416fa82df9e7af883be2c798402369b9a61bf Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Tue, 20 Jan 2026 22:49:26 +0000 Subject: [PATCH 15/16] comments --- .../link_prediction/graph_store/homogeneous_inference.py | 7 ++----- 1 file changed, 2 insertions(+), 5 deletions(-) diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 8542dd2a8..2ca9d06c6 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -74,7 +74,7 @@ Note: Ensure you use a resource config with `vertex_ai_graph_store_inferencer_config` when running in graph store mode. -You can run this example in a full pipeline with `make run_hom_cora_sup_test` from GiGL root. +You can run this example in a full pipeline with `make run_hom_cora_sup_gs_test` from GiGL root. 
""" import argparse @@ -200,9 +200,6 @@ def _inference_process( f"Local rank {local_rank} in machine {cluster_info.compute_node_rank} has rank {rank}/{world_size} and using device {device} for inference" ) input_nodes = dataset.get_node_ids() - print( - f"input_nodes: {[(node.shape, f'{node[0],node[-1]}') for node in input_nodes]}" - ) logger.info( f"Rank {rank} got input nodes of shapes: {[node.shape for node in input_nodes]}" ) @@ -215,7 +212,7 @@ def _inference_process( num_neighbors=num_neighbors, local_process_rank=local_rank, local_process_world_size=local_world_size, - input_nodes=input_nodes, # Since homogeneous, `None` defaults to using all nodes for inference loop + input_nodes=input_nodes, # Since homogeneous, num_workers=sampling_workers_per_inference_process, batch_size=inference_batch_size, pin_memory_device=device, From 7173b470f27b4cf3a7e1cdfe712c5afb4f2c29b9 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Wed, 21 Jan 2026 00:30:57 +0000 Subject: [PATCH 16/16] don't need barrier --- examples/link_prediction/graph_store/homogeneous_inference.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/link_prediction/graph_store/homogeneous_inference.py b/examples/link_prediction/graph_store/homogeneous_inference.py index 2ca9d06c6..0e7c8a1e2 100644 --- a/examples/link_prediction/graph_store/homogeneous_inference.py +++ b/examples/link_prediction/graph_store/homogeneous_inference.py @@ -398,6 +398,8 @@ def _run_example_inference( cluster_info = get_graph_store_info() logger.info(f"Cluster info: {cluster_info}") + torch.distributed.destroy_process_group() + # Read from GbmlConfig for preprocessed data metadata, GNN model uri, and bigquery embedding table path, and additional inference args gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( gbml_config_uri=UriFactory.create_uri(task_config_uri) @@ -463,8 +465,6 @@ def _run_example_inference( f"{num_files_at_gcs_path} files already detected at base gcs path {embedding_output_gcs_folder}. Cleaning up files at path ... " ) gcs_utils.delete_files_in_bucket_dir(embedding_output_gcs_folder) - torch.distributed.barrier() - torch.distributed.destroy_process_group() inference_start_time = time.time() # We don't see logs for graph store mode for whatever reason.