diff --git a/gigl/distributed/dist_ablp_neighborloader.py b/gigl/distributed/dist_ablp_neighborloader.py index 52d4c782f..f20c2e14c 100644 --- a/gigl/distributed/dist_ablp_neighborloader.py +++ b/gigl/distributed/dist_ablp_neighborloader.py @@ -1,14 +1,17 @@ import ast +import concurrent.futures import time from collections import Counter, abc, defaultdict from typing import Optional, Union import torch -from graphlearn_torch.channel import SampleMessage, ShmChannel +from graphlearn_torch.channel import RemoteReceivingChannel, SampleMessage, ShmChannel from graphlearn_torch.distributed import ( DistLoader, MpDistSamplingWorkerOptions, + RemoteDistSamplingWorkerOptions, get_context, + request_server, ) from graphlearn_torch.sampler import SamplingConfig, SamplingType from graphlearn_torch.utils import reverse_edge_type @@ -21,7 +24,9 @@ from gigl.distributed.dist_context import DistributedContext from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.dist_sampling_producer import DistSamplingProducer +from gigl.distributed.dist_server import DistServer from gigl.distributed.distributed_neighborloader import DEFAULT_NUM_CPU_THREADS +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset from gigl.distributed.sampler import ( NEGATIVE_LABEL_METADATA_KEY, POSITIVE_LABEL_METADATA_KEY, @@ -29,6 +34,8 @@ metadata_key_with_prefix, ) from gigl.distributed.utils.neighborloader import ( + DatasetSchema, + SamplingClusterSetup, labeled_to_homogeneous, patch_fanout_for_sampling, set_missing_features, @@ -53,12 +60,20 @@ class DistABLPLoader(DistLoader): def __init__( self, - dataset: DistDataset, + dataset: Union[DistDataset, RemoteDistDataset], num_neighbors: Union[list[int], dict[EdgeType, list[int]]], input_nodes: Optional[ Union[ torch.Tensor, tuple[NodeType, torch.Tensor], + # Graph Store mode inputs + dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + tuple[ + NodeType, + dict[ + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ], + ], ] ] = None, supervision_edge_type: Optional[Union[EdgeType, list[EdgeType]]] = None, @@ -123,24 +138,29 @@ def __init__( - `y_negative`: {(a, to, b): {0: torch.tensor([3])}, (a, to, c): {0: torch.tensor([4])}} Args: - dataset (DistDataset): The dataset to sample from. + dataset (Union[DistDataset, RemoteDistDataset]): The dataset to sample from. + If this is a `RemoteDistDataset`, then we are in "Graph Store" mode. num_neighbors (list[int] or dict[tuple[str, str, str], list[int]]): The number of neighbors to sample for each node in each iteration. If an entry is set to `-1`, all neighbors will be included. In heterogeneous graphs, may also take in a dictionary denoting the amount of neighbors to sample for each individual edge type. - context (DistributedContext): Distributed context information of the current process. - input_nodes (Optional[torch.Tensor, tuple[NodeType, torch.Tensor]]): - Indices of seed nodes to start sampling from. - If set to `None` for homogeneous settings, all nodes will be considered. - In heterogeneous graphs, this flag must be passed in as a tuple that holds - the node type and node indices. (default: `None`) + input_nodes: Indices of seed nodes to start sampling from. + For Colocated mode: `torch.Tensor` or `tuple[NodeType, torch.Tensor]`. + If set to `None` for homogeneous settings, all nodes will be considered. + In heterogeneous graphs, this flag must be passed in as a tuple that holds + the node type and node indices. 
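+            For example, `("user", torch.tensor([0, 1, 2]))` for a hypothetical
+            `user` node type.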
+ For Graph Store mode: `dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]` + or `tuple[NodeType, dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]]`. + The dict maps server_rank to (anchor_nodes, positive_labels, negative_labels). + This is the return type of `RemoteDistDataset.get_ablp_input()`. supervision_edge_type (Optional[Union[EdgeType, list[EdgeType]]]): The edge type(s) to use for supervision. Must be None iff the dataset is labeled homogeneous. If set to a single EdgeType, the positive and negative labels will be stored in the `y_positive` and `y_negative` fields of the Data object. If set to a list of EdgeTypes, the positive and negative labels will be stored in the `y_positive` and `y_negative` fields of the Data object, with the key being the EdgeType. (default: `None`) + NOTE: Graph Store mode currently only supports a single supervision edge type. num_workers (int): How many workers to use (subprocesses to spwan) for distributed neighbor sampling of the current process. (default: ``1``). batch_size (int, optional): how many samples per batch to load @@ -187,6 +207,13 @@ def __init__( master_ip_address: str should_cleanup_distributed_context: bool = False + # Determine sampling cluster setup based on dataset type + if isinstance(dataset, RemoteDistDataset): + self._sampling_cluster_setup = SamplingClusterSetup.GRAPH_STORE + else: + self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED + logger.info(f"Sampling cluster setup: {self._sampling_cluster_setup.value}") + if supervision_edge_type is None: self._supervision_edge_types: list[EdgeType] = [ DEFAULT_HOMOGENEOUS_EDGE_TYPE @@ -199,6 +226,15 @@ def __init__( self._supervision_edge_types = supervision_edge_type else: self._supervision_edge_types = [supervision_edge_type] + + # TODO(kmonte): Support multiple supervision edge types in Graph Store mode + if self._sampling_cluster_setup == SamplingClusterSetup.GRAPH_STORE: + if len(self._supervision_edge_types) > 1: + raise ValueError( + "Graph Store mode currently only supports a single supervision edge type. " + f"Received {len(self._supervision_edge_types)} edge types: {self._supervision_edge_types}" + ) + del supervision_edge_type if context: @@ -262,29 +298,295 @@ def __init__( local_process_world_size, ) # delete deprecated vars so we don't accidentally use them. + device = ( + pin_memory_device + if pin_memory_device + else gigl.distributed.utils.get_available_device( + local_process_rank=local_rank + ) + ) + self.to_device = device + + # Call appropriate setup method based on sampling cluster setup + if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: + assert isinstance( + dataset, DistDataset + ), "When using colocated mode, dataset must be a DistDataset." 
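+            # In Colocated mode, this process holds a partition of the graph
+            # locally and spawns multiprocessing sampling workers against it.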
+ # Validate input_nodes type for colocated mode + if isinstance(input_nodes, abc.Mapping) or ( + isinstance(input_nodes, tuple) + and isinstance(input_nodes[1], abc.Mapping) + ): + raise ValueError( + f"When using Colocated mode, input_nodes must be of type " + f"(torch.Tensor | tuple[NodeType, torch.Tensor] | None), " + f"received Graph Store format: {type(input_nodes)}" + ) + ( + sampler_input, + worker_options, + dataset_metadata, + ) = self._setup_for_colocated( + input_nodes=input_nodes, + dataset=dataset, + local_rank=local_rank, + local_world_size=local_world_size, + device=device, + master_ip_address=master_ip_address, + node_rank=node_rank, + node_world_size=node_world_size, + num_workers=num_workers, + worker_concurrency=worker_concurrency, + channel_size=channel_size, + num_cpu_threads=num_cpu_threads, + ) + else: # Graph Store mode + assert isinstance( + dataset, RemoteDistDataset + ), "When using Graph Store mode, dataset must be a RemoteDistDataset." + # Validate input_nodes type for Graph Store mode + if ( + input_nodes is None + or isinstance(input_nodes, torch.Tensor) + or ( + isinstance(input_nodes, tuple) + and isinstance(input_nodes[1], torch.Tensor) + ) + ): + raise ValueError( + f"When using Graph Store mode, input_nodes must be of type " + f"(dict[int, tuple[...]] | tuple[NodeType, dict[int, tuple[...]]]), " + f"received Colocated format: {type(input_nodes)}" + ) + ( + sampler_input, + worker_options, + dataset_metadata, + ) = self._setup_for_graph_store( + input_nodes=input_nodes, + dataset=dataset, + supervision_edge_type=self._supervision_edge_types[0], + num_workers=num_workers, + ) + + self._is_input_heterogeneous = dataset_metadata.is_labeled_heterogeneous + self._node_feature_info = dataset_metadata.node_feature_info + self._edge_feature_info = dataset_metadata.edge_feature_info + + num_neighbors = patch_fanout_for_sampling( + dataset_metadata.edge_types, num_neighbors + ) + + if should_cleanup_distributed_context and torch.distributed.is_initialized(): + logger.info( + f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." + ) + torch.distributed.destroy_process_group() + + sampling_config = SamplingConfig( + sampling_type=SamplingType.NODE, + num_neighbors=num_neighbors, + batch_size=batch_size, + shuffle=shuffle, + drop_last=drop_last, + with_edge=True, + collect_features=True, + with_neg=False, + with_weight=False, + edge_dir=dataset_metadata.edge_dir, + seed=None, # it's actually optional - None means random. + ) + + if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED: + # Code below this point is taken from the GLT DistNeighborLoader.__init__() function + # (graphlearn_torch/python/distributed/dist_neighbor_loader.py). + # We do this so that we may override the DistSamplingProducer that is used with the GiGL implementation. 
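+            # At a high level, the copied block narrows types, mirrors the
+            # sampling config onto loader attributes, validates the distributed
+            # context, and finally spawns the sampling worker processes behind a
+            # shared-memory channel.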
+
+            # Type narrowing for colocated mode
+
+            self.data = dataset
+            self.input_data = sampler_input[0]
+            del dataset, sampler_input
+            assert isinstance(self.data, DistDataset)
+            assert isinstance(self.input_data, ABLPNodeSamplerInput)
+
+            self.sampling_type = sampling_config.sampling_type
+            self.num_neighbors = sampling_config.num_neighbors
+            self.batch_size = sampling_config.batch_size
+            self.shuffle = sampling_config.shuffle
+            self.drop_last = sampling_config.drop_last
+            self.with_edge = sampling_config.with_edge
+            self.with_weight = sampling_config.with_weight
+            self.collect_features = sampling_config.collect_features
+            self.edge_dir = sampling_config.edge_dir
+            self.sampling_config = sampling_config
+            self.worker_options = worker_options
+
+            # We can set shutdowned to false now
+            self._shutdowned = False
+
+            self._is_mp_worker = True
+            self._is_collocated_worker = False
+            self._is_remote_worker = False
+
+            self.num_data_partitions = self.data.num_partitions
+            self.data_partition_idx = self.data.partition_idx
+            self._set_ntypes_and_etypes(
+                self.data.get_node_types(), self.data.get_edge_types()
+            )
+
+            self._num_recv = 0
+            self._epoch = 0
+
+            current_ctx = get_context()
+
+            self._input_len = len(self.input_data)
+            self._input_type = self.input_data.input_type
+            self._num_expected = self._input_len // self.batch_size
+            if not self.drop_last and self._input_len % self.batch_size != 0:
+                self._num_expected += 1
+
+            if not current_ctx.is_worker():
+                raise RuntimeError(
+                    f"'{self.__class__.__name__}': only supports "
+                    f"launching multiprocessing sampling workers with "
+                    f"a non-server distribution mode, current role of "
+                    f"distributed context is {current_ctx.role}."
+                )
+            if self.data is None:
+                raise ValueError(
+                    f"'{self.__class__.__name__}': missing input dataset "
+                    f"when launching multiprocessing sampling workers."
+                )
+
+            # Launch multiprocessing sampling workers
+            self._with_channel = True
+            self.worker_options._set_worker_ranks(current_ctx)
+
+            self._channel = ShmChannel(
+                self.worker_options.channel_capacity, self.worker_options.channel_size
+            )
+            if self.worker_options.pin_memory:
+                self._channel.pin_memory()
+
+            self._mp_producer = DistSamplingProducer(
+                self.data,
+                self.input_data,
+                self.sampling_config,
+                self.worker_options,
+                self._channel,
+            )
+            # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s.
+            # The current hypothesis is that making connections across machines requires a lot of memory.
+            # If we start all data loaders in all processes simultaneously, the spikes of memory
+            # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group
+            # to smooth the memory usage. The definition of group is discussed in init_neighbor_loader_worker.
+            logger.info(
+                f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds"
+            )
+            time.sleep(process_start_gap_seconds * local_rank)
+            self._mp_producer.init()
+        else:
+            # Graph Store mode - re-implement remote worker setup
+            # Use sequential initialization per compute node to avoid race conditions
+            # when initializing the samplers on the storage nodes.
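+            # Each loop iteration below lets exactly one compute node create its
+            # remote sampling producers while all other nodes wait at the
+            # barrier; the trailing barrier releases everyone after the last
+            # node has initialized.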
+            node_rank = dataset.cluster_info.compute_node_rank
+            for target_node_rank in range(dataset.cluster_info.num_compute_nodes):
+                if node_rank == target_node_rank:
+                    self._init_remote_worker(
+                        dataset=dataset,
+                        sampler_input=sampler_input,
+                        sampling_config=sampling_config,
+                        worker_options=worker_options,
+                        dataset_metadata=dataset_metadata,
+                    )
+                    logger.info(f"node_rank {node_rank} initialized the dist loader")
+                torch.distributed.barrier()
+            torch.distributed.barrier()
+
+    def _setup_for_colocated(
+        self,
+        input_nodes: Optional[
+            Union[
+                torch.Tensor,
+                tuple[NodeType, torch.Tensor],
+                dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]],
+                tuple[
+                    NodeType,
+                    dict[
+                        int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]
+                    ],
+                ],
+            ]
+        ],
+        dataset: DistDataset,
+        local_rank: int,
+        local_world_size: int,
+        device: torch.device,
+        master_ip_address: str,
+        node_rank: int,
+        node_world_size: int,
+        num_workers: int,
+        worker_concurrency: int,
+        channel_size: str,
+        num_cpu_threads: Optional[int],
+    ) -> tuple[list[ABLPNodeSamplerInput], MpDistSamplingWorkerOptions, DatasetSchema]:
+        """
+        Setup method for colocated (non-Graph Store) mode.
+
+        Args:
+            input_nodes: Input nodes for sampling (tensor or tuple of node type and tensor).
+            dataset: The DistDataset to sample from.
+            local_rank: Local rank of the current process.
+            local_world_size: Total number of processes on this machine.
+            device: Target device for sampled data.
+            master_ip_address: IP address of the master node.
+            node_rank: Rank of the current machine.
+            node_world_size: Total number of machines.
+            num_workers: Number of sampling workers.
+            worker_concurrency: Max sampling concurrency per worker.
+            channel_size: Size of shared memory channel.
+            num_cpu_threads: Number of CPU threads for PyTorch.
+
+        Returns:
+            Tuple of (list[ABLPNodeSamplerInput], MpDistSamplingWorkerOptions, DatasetSchema).
+            The list contains a single sampler input for this process.
+        """
+        # Validate input format - should not be Graph Store format
+        if isinstance(input_nodes, abc.Mapping):
+            raise ValueError(
+                f"When using Colocated mode, input_nodes must be of type (torch.Tensor | tuple[NodeType, torch.Tensor]), "
+                f"received {type(input_nodes)}"
+            )
+        elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping):
+            raise ValueError(
+                f"When using Colocated mode, input_nodes must be of type (torch.Tensor | tuple[NodeType, torch.Tensor]), "
+                f"received tuple with second element of type {type(input_nodes[1])}"
+            )
+
         if not isinstance(dataset.graph, abc.Mapping):
             raise ValueError(
-                f"The dataset must be heterogeneous for ABLP. Recieved dataset with graph of type: {type(dataset.graph)}"
+                f"The dataset must be heterogeneous for ABLP. Received dataset with graph of type: {type(dataset.graph)}"
             )
-        self._is_input_heterogeneous: bool = False
+
+        is_labeled_heterogeneous: bool = False
         if isinstance(input_nodes, tuple):
             if self._supervision_edge_types == [DEFAULT_HOMOGENEOUS_EDGE_TYPE]:
                 raise ValueError(
                     "When using heterogeneous ABLP, you must provide supervision_edge_types."
) - self._is_input_heterogeneous = True + is_labeled_heterogeneous = True anchor_node_type, anchor_node_ids = input_nodes # TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if # this assumption is no longer valid and/or is too opinionated - for supervision_edge_type in self._supervision_edge_types: + for sup_edge_type in self._supervision_edge_types: assert ( - supervision_edge_type[0] == anchor_node_type + sup_edge_type[0] == anchor_node_type ), f"Label EdgeType are currently expected to be provided in outward edge direction as tuple (`anchor_node_type`,`relation`,`supervision_node_type`), \ - got supervision edge type {supervision_edge_type} with anchor node type {anchor_node_type}" + got supervision edge type {sup_edge_type} with anchor node type {anchor_node_type}" if dataset.edge_dir == "in": self._supervision_edge_types = [ - reverse_edge_type(supervision_edge_type) - for supervision_edge_type in self._supervision_edge_types + reverse_edge_type(sup_edge_type) + for sup_edge_type in self._supervision_edge_types ] elif isinstance(input_nodes, torch.Tensor): if self._supervision_edge_types != [DEFAULT_HOMOGENEOUS_EDGE_TYPE]: @@ -306,9 +608,10 @@ def __init__( raise ValueError( f"Expected supervision edge type to be None for homogeneous input nodes, got {self._supervision_edge_types}" ) - anchor_node_ids = dataset.node_ids anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE + else: + raise ValueError(f"Unexpected input_nodes type: {type(input_nodes)}") missing_edge_types = set(self._supervision_edge_types) - set( dataset.graph.keys() @@ -318,6 +621,9 @@ def __init__( f"Missing edge types in dataset: {missing_edge_types}. Edge types in dataset: {dataset.graph.keys()}" ) + # Type narrowing - anchor_node_ids is always a Tensor in colocated mode + assert isinstance(anchor_node_ids, torch.Tensor) + if len(anchor_node_ids.shape) != 1: raise ValueError( f"input_nodes must be a 1D tensor, got {anchor_node_ids.shape}." @@ -333,11 +639,11 @@ def __init__( self._negative_label_edge_types: list[EdgeType] = [] positive_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} negative_labels_by_label_edge_type: dict[EdgeType, torch.Tensor] = {} - for supervision_edge_type in self._supervision_edge_types: + for sup_edge_type in self._supervision_edge_types: ( positive_label_edge_type, negative_label_edge_type, - ) = select_label_edge_types(supervision_edge_type, dataset.graph.keys()) + ) = select_label_edge_types(sup_edge_type, dataset.graph.keys()) self._positive_label_edge_types.append(positive_label_edge_type) if negative_label_edge_type is not None: self._negative_label_edge_types.append(negative_label_edge_type) @@ -363,31 +669,17 @@ def __init__( negative_label_by_edge_types=negative_labels_by_label_edge_type, ) - self.to_device = ( - pin_memory_device - if pin_memory_device - else gigl.distributed.utils.get_available_device( - local_process_rank=local_rank - ) - ) - - num_neighbors = patch_fanout_for_sampling( - dataset.get_edge_types(), num_neighbors - ) - - self._node_feature_info = dataset.node_feature_info - self._edge_feature_info = dataset.edge_feature_info - - # Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize - # the memory overhead and CPU contention. + # Sets up processes and torch device for initializing the GLT DistNeighborLoader, + # setting up RPC and worker groups to minimize the memory overhead and CPU contention. 
neighbor_loader_ports = gigl.distributed.utils.get_free_ports_from_master_node( num_ports=local_world_size ) neighbor_loader_port_for_current_rank = neighbor_loader_ports[local_rank] logger.info( - f"Initializing neighbor loader worker in process: {local_rank}/{local_world_size} using device: {self.to_device} on port {neighbor_loader_port_for_current_rank}." + f"Initializing neighbor loader worker in process: {local_rank}/{local_world_size} " + f"using device: {device} on port {neighbor_loader_port_for_current_rank}." ) - should_use_cpu_workers = self.to_device.type == "cpu" + should_use_cpu_workers = device.type == "cpu" if should_use_cpu_workers and num_cpu_threads is None: logger.info( "Using CPU workers, but found num_cpu_threads to be None. " @@ -402,9 +694,8 @@ def __init__( rank=node_rank, world_size=node_world_size, master_worker_port=neighbor_loader_port_for_current_rank, - device=self.to_device, + device=device, should_use_cpu_workers=should_use_cpu_workers, - # Lever to explore tuning for CPU based inference num_cpu_threads=num_cpu_threads, ) logger.info( @@ -420,47 +711,218 @@ def __init__( num_workers=num_workers, worker_devices=[torch.device("cpu") for _ in range(num_workers)], worker_concurrency=worker_concurrency, - # Each worker will spawn several sampling workers, and all sampling workers spawned by workers in one group - # need to be connected. Thus, we need master ip address and master port to - # initate the connection. - # Note that different groups of workers are independent, and thus - # the sampling processes in different groups should be independent, and should - # use different master ports. master_addr=master_ip_address, master_port=dist_sampling_port_for_current_rank, - # Load testing show that when num_rpc_threads exceed 16, the performance - # will degrade. num_rpc_threads=min(dataset.num_partitions, 16), rpc_timeout=600, channel_size=channel_size, - pin_memory=self.to_device.type == "cuda", + pin_memory=device.type == "cuda", ) - if should_cleanup_distributed_context and torch.distributed.is_initialized(): + edge_types = list(dataset.graph.keys()) + + return ( + [sampler_input], + worker_options, + DatasetSchema( + is_labeled_heterogeneous=is_labeled_heterogeneous, + edge_types=edge_types, + node_feature_info=dataset.node_feature_info, + edge_feature_info=dataset.edge_feature_info, + edge_dir=dataset.edge_dir, + ), + ) + + def _setup_for_graph_store( + self, + input_nodes: Optional[ + Union[ + torch.Tensor, + tuple[NodeType, torch.Tensor], + dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], + tuple[ + NodeType, + dict[ + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ], + ], + ] + ], + dataset: RemoteDistDataset, + supervision_edge_type: EdgeType, + num_workers: int, + ) -> tuple[ + list[ABLPNodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema + ]: + """ + Setup method for Graph Store mode. + + Args: + input_nodes: ABLP input from RemoteDistDataset.get_ablp_input(). + Format: dict[server_rank, (anchors, positive_labels, negative_labels)] + or tuple[NodeType, dict[server_rank, (anchors, positive_labels, negative_labels)]]. + dataset: The RemoteDistDataset to sample from. + supervision_edge_type: The single supervision edge type to use. + num_workers: Number of sampling workers. + + Returns: + Tuple of (list[ABLPNodeSamplerInput], RemoteDistSamplingWorkerOptions, DatasetSchema). 
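+
+        Example `input_nodes` value (hypothetical shapes; one entry per storage
+        rank, as returned by `RemoteDistDataset.get_ablp_input()`):
+
+            {0: (tensor([0, 1]), tensor([[2], [3]]), None),
+             1: (tensor([4]), tensor([[5]]), None)}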
+ """ + # Validate input format - must be Graph Store format + if input_nodes is None: + raise ValueError( + f"When using Graph Store mode, input_nodes must be provided, received {input_nodes}" + ) + elif isinstance(input_nodes, torch.Tensor): + raise ValueError( + f"When using Graph Store mode, input_nodes must be of type " + f"(dict[int, tuple[Tensor, Tensor, Optional[Tensor]]] | " + f"tuple[NodeType, dict[int, tuple[Tensor, Tensor, Optional[Tensor]]]]), " + f"received {type(input_nodes)}" + ) + elif isinstance(input_nodes, tuple) and isinstance( + input_nodes[1], torch.Tensor + ): + raise ValueError( + f"When using Graph Store mode, input_nodes must be of type " + f"(dict[int, tuple[Tensor, Tensor, Optional[Tensor]]] | " + f"tuple[NodeType, dict[int, tuple[Tensor, Tensor, Optional[Tensor]]]]), " + f"received tuple with second element of type {type(input_nodes[1])}" + ) + + is_labeled_heterogeneous = False + node_feature_info = dataset.get_node_feature_info() + edge_feature_info = dataset.get_edge_feature_info() + edge_types = dataset.get_edge_types() + node_rank = dataset.cluster_info.compute_node_rank + + # Get sampling ports for compute-storage connections. + sampling_ports = dataset.get_free_ports_on_storage_cluster( + num_ports=dataset.cluster_info.num_processes_per_compute + ) + sampling_port = sampling_ports[node_rank] + + worker_options = RemoteDistSamplingWorkerOptions( + server_rank=list(range(dataset.cluster_info.num_storage_nodes)), + num_workers=num_workers, + worker_devices=[torch.device("cpu") for _ in range(num_workers)], + master_addr=dataset.cluster_info.storage_cluster_master_ip, + master_port=sampling_port, + worker_key=f"compute_rank_{node_rank}", + ) + logger.info( + f"Rank {torch.distributed.get_rank()}! init for sampling rpc: " + f"tcp://{dataset.cluster_info.storage_cluster_master_ip}:{sampling_port}" + ) + + # Determine input type based on input_nodes structure + if isinstance(input_nodes, abc.Mapping): + # Labeled homogeneous: dict[int, tuple[...]] + nodes_dict = input_nodes + input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE + elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping): + # Heterogeneous: (NodeType, dict[int, tuple[...]]) + input_type = input_nodes[0] + nodes_dict = input_nodes[1] + is_labeled_heterogeneous = True + else: + raise ValueError( + f"When using Graph Store mode, input_nodes must be of type " + f"(dict[int, tuple[...]] | tuple[NodeType, dict[int, tuple[...]]]), " + f"received {type(input_nodes)}" + ) + + # Validate server ranks + servers = nodes_dict.keys() + if len(servers) > 0: + if ( + max(servers) >= dataset.cluster_info.num_storage_nodes + or min(servers) < 0 + ): + raise ValueError( + f"When using Graph Store mode, the server ranks must be in range " + f"[0, {dataset.cluster_info.num_storage_nodes}), " + f"received inputs for servers: {list(servers)}" + ) + + # Get label edge types for building ABLPNodeSamplerInput + # TODO(kmonte): Support multiple supervision edge types in Graph Store mode + ( + positive_label_edge_type, + negative_label_edge_type, + ) = select_label_edge_types(supervision_edge_type, edge_types or []) + logger.info(f"Positive label edge type: {positive_label_edge_type}") + logger.info(f"Negative label edge type: {negative_label_edge_type}") + self._positive_label_edge_types = [positive_label_edge_type] + self._negative_label_edge_types = ( + [negative_label_edge_type] if negative_label_edge_type else [] + ) + + # Convert from dict format to list of ABLPNodeSamplerInput + 
input_data: list[ABLPNodeSamplerInput] = [] + for server_rank in range(dataset.cluster_info.num_storage_nodes): + if server_rank in nodes_dict: + anchors, positive_labels, negative_labels = nodes_dict[server_rank] + else: + # Empty input for servers with no data for this rank + anchors = torch.empty(0, dtype=torch.long) + positive_labels = torch.empty(0, 0, dtype=torch.long) + negative_labels = None + + # Build label dicts keyed by label edge type + positive_label_by_edge_types = {positive_label_edge_type: positive_labels} + negative_label_by_edge_types: dict[EdgeType, torch.Tensor] = {} + if negative_labels is not None and negative_label_edge_type is not None: + negative_label_by_edge_types[negative_label_edge_type] = negative_labels + logger.info( - f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__." + f"Rank: {torch.distributed.get_rank()}! Building ABLPNodeSamplerInput for server rank: {server_rank} with input type: {input_type}. anchors: {anchors.shape}, positive_labels: {positive_labels.shape}, negative_labels: {negative_labels.shape if negative_labels is not None else None}" ) - torch.distributed.destroy_process_group() + ablp_input = ABLPNodeSamplerInput( + node=anchors, + input_type=input_type, + positive_label_by_edge_types=positive_label_by_edge_types, + negative_label_by_edge_types=negative_label_by_edge_types, + ) + input_data.append(ablp_input) - sampling_config = SamplingConfig( - sampling_type=SamplingType.NODE, - num_neighbors=num_neighbors, - batch_size=batch_size, - shuffle=shuffle, - drop_last=drop_last, - with_edge=True, - collect_features=True, - with_neg=False, - with_weight=False, - edge_dir=dataset.edge_dir, - seed=None, # it's actually optional - None means random. + return ( + input_data, + worker_options, + DatasetSchema( + is_labeled_heterogeneous=is_labeled_heterogeneous, + edge_types=edge_types, + node_feature_info=node_feature_info, + edge_feature_info=edge_feature_info, + edge_dir=dataset.get_edge_dir(), + ), ) - # Code below this point is taken from the GLT DistNeighborLoader.__init__() function (graphlearn_torch/python/distributed/dist_neighbor_loader.py). - # We do this so that we may override the DistSamplingProducer that is used with the GiGL implementation. + def _init_remote_worker( + self, + dataset: RemoteDistDataset, + sampler_input: list[ABLPNodeSamplerInput], + sampling_config: SamplingConfig, + worker_options: RemoteDistSamplingWorkerOptions, + dataset_metadata: DatasetSchema, + ) -> None: + """ + Initialize the remote worker code path for Graph Store mode. + + This re-implements GLT's DistLoader remote worker setup but uses GiGL's DistServer. - self.data = dataset - self.input_data = sampler_input + Args: + dataset: The RemoteDistDataset to sample from. + sampler_input: List of ABLPNodeSamplerInput, one per server. + sampling_config: Configuration for sampling. + worker_options: Options for remote sampling workers. + dataset_metadata: Metadata about the dataset schema. + """ + # Set instance variables (like DistLoader does) + # Note: We assign to self.data and self.input_data which are also set in the colocated + # branch. For Graph Store mode, data is None and input_data is a list. 
+ object.__setattr__(self, "data", None) # No local data in Graph Store mode + object.__setattr__(self, "input_data", sampler_input) self.sampling_type = sampling_config.sampling_type self.num_neighbors = sampling_config.num_neighbors self.batch_size = sampling_config.batch_size @@ -473,70 +935,78 @@ def __init__( self.sampling_config = sampling_config self.worker_options = worker_options - # We can set shutdowned to false now self._shutdowned = False - self._is_mp_worker = True + # Set worker type flags + self._is_mp_worker = False self._is_collocated_worker = False - self._is_remote_worker = False + self._is_remote_worker = True - self.num_data_partitions = self.data.num_partitions - self.data_partition_idx = self.data.partition_idx - self._set_ntypes_and_etypes( - self.data.get_node_types(), self.data.get_edge_types() - ) + # For remote worker, end of epoch is determined by server + self._num_expected = float("inf") + self._with_channel = True self._num_recv = 0 self._epoch = 0 - current_ctx = get_context() - - self._input_len = len(self.input_data) - self._input_type = self.input_data.input_type - self._num_expected = self._input_len // self.batch_size - if not self.drop_last and self._input_len % self.batch_size != 0: - self._num_expected += 1 - - if not current_ctx.is_worker(): - raise RuntimeError( - f"'{self.__class__.__name__}': only supports " - f"launching multiprocessing sampling workers with " - f"a non-server distribution mode, current role of " - f"distributed context is {current_ctx.role}." - ) - if self.data is None: - raise ValueError( - f"'{self.__class__.__name__}': missing input dataset " - f"when launching multiprocessing sampling workers." + # Get server rank list from worker_options + self._server_rank_list = ( + worker_options.server_rank + if isinstance(worker_options.server_rank, list) + else [worker_options.server_rank] + ) + self._input_data_list = sampler_input # Already a list (one per server) + + # Get input type from first input + self._input_type = self._input_data_list[0].input_type + + # Get dataset metadata from cluster_info (not via RPC) + self.num_data_partitions = dataset.cluster_info.num_storage_nodes + self.data_partition_idx = dataset.cluster_info.compute_node_rank + + # Derive node types from edge types + # For labeled homogeneous: edge_types contains DEFAULT_HOMOGENEOUS_EDGE_TYPE + # For heterogeneous: extract unique src/dst types from edge types + edge_types = dataset_metadata.edge_types or [] + if edge_types: + node_types = list( + set([et[0] for et in edge_types] + [et[2] for et in edge_types]) ) + else: + node_types = [DEFAULT_HOMOGENEOUS_NODE_TYPE] + self._set_ntypes_and_etypes(node_types, edge_types) + + # Create sampling producers on each server (concurrently) + # Move input data to CPU before sending to server + for input_data in self._input_data_list: + input_data.to(torch.device("cpu")) + + self._producer_id_list = [] + with concurrent.futures.ThreadPoolExecutor() as executor: + futures = [ + executor.submit( + request_server, + server_rank, + DistServer.create_sampling_producer, + input_data, + self.sampling_config, + self.worker_options, + ) + for server_rank, input_data in zip( + self._server_rank_list, self._input_data_list + ) + ] - # Launch multiprocessing sampling workers - self._with_channel = True - self.worker_options._set_worker_ranks(current_ctx) + for future in futures: + producer_id = future.result() + self._producer_id_list.append(producer_id) - self._channel = ShmChannel( - self.worker_options.channel_capacity, 
self.worker_options.channel_size
-        )
-        if self.worker_options.pin_memory:
-            self._channel.pin_memory()
-
-        self._mp_producer = DistSamplingProducer(
-            self.data,
-            self.input_data,
-            self.sampling_config,
-            self.worker_options,
-            self._channel,
-        )
-        # When initiating data loader(s), there will be a spike of memory usage lasting for ~30s.
-        # The current hypothesis is making connections across machines require a lot of memory.
-        # If we start all data loaders in all processes simultaneously, the spike of memory
-        # usage will add up and cause CPU memory OOM. Hence, we initiate the data loaders group by group
-        # to smooth the memory usage. The definition of group is discussed in init_neighbor_loader_worker.
-        logger.info(
-            f"---Machine {rank} local process number {local_rank} preparing to sleep for {process_start_gap_seconds * local_rank} seconds"
+        # Create remote receiving channel for cross-machine message passing
+        self._channel = RemoteReceivingChannel(
+            self._server_rank_list,
+            self._producer_id_list,
+            self.worker_options.prefetch_size,
         )
-        time.sleep(process_start_gap_seconds * local_rank)
-        self._mp_producer.init()

     def _get_labels(
         self, msg: SampleMessage
diff --git a/gigl/distributed/dist_neighbor_sampler.py b/gigl/distributed/dist_neighbor_sampler.py
index ec9692611..c7c5b5b6a 100644
--- a/gigl/distributed/dist_neighbor_sampler.py
+++ b/gigl/distributed/dist_neighbor_sampler.py
@@ -15,6 +15,7 @@
 from graphlearn_torch.typing import EdgeType, NodeType
 from graphlearn_torch.utils import count_dict, merge_dict, reverse_edge_type

+from gigl.common.logger import Logger
 from gigl.distributed.sampler import (
     NEGATIVE_LABEL_METADATA_KEY,
     POSITIVE_LABEL_METADATA_KEY,
@@ -22,6 +23,8 @@
 )
 from gigl.utils.data_splitters import PADDING_NODE

+logger = Logger()
+

 # TODO (mkolodner-sc): Investigate upstreaming this change back to GLT

diff --git a/gigl/distributed/dist_sampling_producer.py b/gigl/distributed/dist_sampling_producer.py
index 249ae772d..e743fb933 100644
--- a/gigl/distributed/dist_sampling_producer.py
+++ b/gigl/distributed/dist_sampling_producer.py
@@ -32,8 +32,11 @@
 from torch.utils.data.dataloader import DataLoader
 from torch.utils.data.dataset import Dataset

+from gigl.common.logger import Logger
 from gigl.distributed.dist_neighbor_sampler import DistABLPNeighborSampler

+logger = Logger()
+

 def _sampling_worker_loop(
     rank: int,
@@ -84,6 +87,7 @@
     if sampling_config.seed is not None:
         seed_everything(sampling_config.seed)

+    logger.info(f"Sampling config: {sampling_config}")
     dist_sampler = DistABLPNeighborSampler(
         data,
         sampling_config.num_neighbors,
@@ -167,6 +171,7 @@

 class DistSamplingProducer(DistMpSamplingProducer):
     def init(self):
         r"""Create the subprocess pool. Init samplers and rpc server."""
+        logger.info("Initializing GiGL DistSamplingProducer")
         if self.sampling_config.seed is not None:
             seed_everything(self.sampling_config.seed)
diff --git a/gigl/distributed/dist_server.py b/gigl/distributed/dist_server.py
new file mode 100644
index 000000000..75b00b22c
--- /dev/null
+++ b/gigl/distributed/dist_server.py
@@ -0,0 +1,334 @@
+# Copyright 2022 Alibaba Group Holding Limited. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+import logging
+import threading
+import time
+import warnings
+from typing import Dict, Optional, Union
+
+import graphlearn_torch.distributed.dist_server as glt_dist_server
+from graphlearn_torch.channel import QueueTimeoutError, ShmChannel
+from graphlearn_torch.distributed import (
+    RemoteDistSamplingWorkerOptions,
+    barrier,
+    init_rpc,
+    shutdown_rpc,
+)
+from graphlearn_torch.partition import PartitionBook
+from graphlearn_torch.sampler import (
+    EdgeSamplerInput,
+    NodeSamplerInput,
+    RemoteSamplerInput,
+    SamplingConfig,
+)
+
+from gigl.distributed.dist_dataset import DistDataset
+from gigl.distributed.dist_sampling_producer import DistSamplingProducer
+
+SERVER_EXIT_STATUS_CHECK_INTERVAL = 5.0
+r""" Interval (in seconds) to check exit status of server.
+"""
+
+
+class DistServer(object):
+    r"""A server that supports launching remote sampling workers for
+    training clients.
+
+    Note that this server is enabled only when the distribution mode is a
+    server-client framework, and the graph and feature store will be partitioned
+    and managed by all server nodes.
+
+    Args:
+        dataset (DistDataset): The ``DistDataset`` object of a partition of graph
+            data and feature data, along with distributed partition books.
+    """
+
+    def __init__(self, dataset: DistDataset):
+        self.dataset = dataset
+        self._lock = threading.RLock()
+        self._exit = False
+        self._cur_producer_idx = 0  # auto incremental index (same as producer count)
+        # The mapping from the key in worker options (such as 'train', 'test')
+        # to producer id
+        self._worker_key2producer_id: Dict[str, int] = {}
+        self._producer_pool: Dict[int, DistSamplingProducer] = {}
+        self._msg_buffer_pool: Dict[int, ShmChannel] = {}
+        self._epoch: Dict[int, int] = {}  # last epoch for the producer
+
+    def shutdown(self):
+        for producer_id in list(self._producer_pool.keys()):
+            self.destroy_sampling_producer(producer_id)
+        assert len(self._producer_pool) == 0
+        assert len(self._msg_buffer_pool) == 0
+
+    def wait_for_exit(self):
+        r"""Block until the exit flag has been set to ``True``."""
+        while not self._exit:
+            time.sleep(SERVER_EXIT_STATUS_CHECK_INTERVAL)
+
+    def exit(self):
+        r"""Set the exit flag to ``True``."""
+        self._exit = True
+        return self._exit
+
+    def get_dataset_meta(self):
+        r"""Get the meta info of the distributed dataset managed by the current
+        server, including partition info and graph types.
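+
+        Returns:
+            A tuple of ``(num_partitions, partition_idx, node_types, edge_types)``.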
+        """
+        return (
+            self.dataset.num_partitions,
+            self.dataset.partition_idx,
+            self.dataset.get_node_types(),
+            self.dataset.get_edge_types(),
+        )
+
+    def get_node_partition_id(self, node_type, index):
+        if isinstance(self.dataset.node_pb, PartitionBook):
+            partition_id = self.dataset.node_pb[index]
+            return partition_id
+        elif isinstance(self.dataset.node_pb, Dict):
+            partition_id = self.dataset.node_pb[node_type][index]
+            return partition_id
+        return None
+
+    def get_node_feature(self, node_type, index):
+        feature = self.dataset.get_node_feature(node_type)
+        return feature[index].cpu()
+
+    def get_tensor_size(self, node_type):
+        feature = self.dataset.get_node_feature(node_type)
+        return feature.shape
+
+    def get_node_label(self, node_type, index):
+        label = self.dataset.get_node_label(node_type)
+        return label[index]
+
+    def get_edge_index(self, edge_type, layout):
+        graph = self.dataset.get_graph(edge_type)
+        row = None
+        col = None
+        result = None
+        if layout == "coo":
+            row, col, _, _ = graph.topo.to_coo()
+            result = (row, col)
+        else:
+            raise ValueError(f"Invalid layout {layout}")
+        return result
+
+    def get_edge_size(self, edge_type, layout):
+        graph = self.dataset.get_graph(edge_type)
+        if layout == "coo":
+            row_count = graph.row_count
+            col_count = graph.col_count
+        else:
+            raise ValueError(f"Invalid layout {layout}")
+        return (row_count, col_count)
+
+    def create_sampling_producer(
+        self,
+        sampler_input: Union[NodeSamplerInput, EdgeSamplerInput, RemoteSamplerInput],
+        sampling_config: SamplingConfig,
+        worker_options: RemoteDistSamplingWorkerOptions,
+    ) -> int:
+        r"""Create and initialize an instance of ``DistSamplingProducer`` with
+        a group of subprocesses for distributed sampling.
+
+        Args:
+            sampler_input (NodeSamplerInput, EdgeSamplerInput or RemoteSamplerInput):
+                The input data for sampling.
+            sampling_config (SamplingConfig): Configuration of sampling meta info.
+            worker_options (RemoteDistSamplingWorkerOptions): Options for launching
+                remote sampling workers by this server.
+
+        Returns:
+            A unique id of created sampling producer on this server.
+        """
+        if isinstance(sampler_input, RemoteSamplerInput):
+            sampler_input = sampler_input.to_local_sampler_input(dataset=self.dataset)
+
+        with self._lock:
+            producer_id = self._worker_key2producer_id.get(worker_options.worker_key)
+            if producer_id is None:
+                producer_id = self._cur_producer_idx
+                self._worker_key2producer_id[worker_options.worker_key] = producer_id
+                self._cur_producer_idx += 1
+                buffer = ShmChannel(
+                    worker_options.buffer_capacity, worker_options.buffer_size
+                )
+                logging.info(
+                    f"Creating DistSamplingProducer ({DistSamplingProducer}) for worker key: {worker_options.worker_key} with producer id: {producer_id}"
+                )
+                producer = DistSamplingProducer(
+                    self.dataset, sampler_input, sampling_config, worker_options, buffer
+                )
+                producer.init()
+                self._producer_pool[producer_id] = producer
+                self._msg_buffer_pool[producer_id] = buffer
+                self._epoch[producer_id] = -1
+        return producer_id
+
+    def destroy_sampling_producer(self, producer_id: int):
+        r"""Shutdown and destroy a sampling producer managed by this server with
+        its producer id.
+        """
+        with self._lock:
+            producer = self._producer_pool.get(producer_id, None)
+            if producer is not None:
+                producer.shutdown()
+                self._producer_pool.pop(producer_id)
+                self._msg_buffer_pool.pop(producer_id)
+                self._epoch.pop(producer_id)
+
+    def start_new_epoch_sampling(self, producer_id: int, epoch: int):
+        r"""Start a new epoch of sampling tasks for a specific sampling producer
+        with its producer id.
+        """
+        with self._lock:
+            cur_epoch = self._epoch[producer_id]
+            if cur_epoch < epoch:
+                self._epoch[producer_id] = epoch
+                producer = self._producer_pool.get(producer_id, None)
+                if producer is not None:
+                    producer.produce_all()
+
+    def fetch_one_sampled_message(self, producer_id: int):
+        r"""Fetch a sampled message from the buffer of a specific sampling
+        producer with its producer id.
+        """
+        producer = self._producer_pool.get(producer_id, None)
+        if producer is None:
+            warnings.warn(f"invalid producer_id {producer_id}")
+            return None, False
+        if producer.is_all_sampling_completed_and_consumed():
+            return None, True
+        buffer = self._msg_buffer_pool.get(producer_id, None)
+        while True:
+            try:
+                msg = buffer.recv(timeout_ms=500)
+                return msg, False
+            except QueueTimeoutError:
+                if producer.is_all_sampling_completed():
+                    return None, True
+
+
+_dist_server: Optional[DistServer] = None
+r""" ``DistServer`` instance of the current process.
+"""
+
+
+def get_server() -> Optional[DistServer]:
+    r"""Get the ``DistServer`` instance on the current process."""
+    return _dist_server
+
+
+def init_server(
+    num_servers: int,
+    server_rank: int,
+    dataset: DistDataset,
+    master_addr: str,
+    master_port: int,
+    num_clients: int = 0,
+    num_rpc_threads: int = 16,
+    request_timeout: int = 180,
+    server_group_name: Optional[str] = None,
+    is_dynamic: bool = False,
+):
+    r"""Initialize the current process as a server and establish connections
+    with all other servers and clients. Note that this method should be called
+    only in the server-client distribution mode.
+
+    Args:
+        num_servers (int): Number of processes participating in the server group.
+        server_rank (int): Rank of the current process within the server group (it
+            should be a number between 0 and ``num_servers``-1).
+        dataset (DistDataset): The ``DistDataset`` object of a partition of graph
+            data and feature data, along with distributed partition book info.
+        master_addr (str): The master TCP address for RPC connection between all
+            servers and clients, the value of this parameter should be same for all
+            servers and clients.
+        master_port (int): The master TCP port for RPC connection between all
+            servers and clients, the value of this parameter should be same for all
+            servers and clients.
+        num_clients (int): Number of processes participating in the client group.
+            If ``is_dynamic`` is ``True``, this parameter will be ignored.
+        num_rpc_threads (int): The number of RPC worker threads used for the
+            current server to respond remote requests. (Default: ``16``).
+        request_timeout (int): The max timeout seconds for remote requests,
+            otherwise an exception will be raised. (Default: ``180``).
+        server_group_name (str): A unique name of the server group that current
+            process belongs to. If set to ``None``, a default name will be used.
+            (Default: ``None``).
+        is_dynamic (bool): Whether the world size is dynamic. (Default: ``False``).
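+
+    Example (a minimal sketch; the address, port, and counts are hypothetical):
+
+        >>> init_server(
+        ...     num_servers=2,
+        ...     server_rank=0,
+        ...     dataset=dataset,
+        ...     master_addr="10.0.0.1",
+        ...     master_port=29500,
+        ...     num_clients=4,
+        ... )
+        >>> wait_and_shutdown_server()  # block until clients finish, then clean up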
+    """
+    if server_group_name:
+        server_group_name = server_group_name.replace("-", "_")
+    glt_dist_server._set_server_context(
+        num_servers, server_rank, server_group_name, num_clients
+    )
+    global _dist_server
+    _dist_server = DistServer(dataset=dataset)
+    # Also set GLT's _dist_server so that GLT's RPC mechanism routes to GiGL's server
+    glt_dist_server._dist_server = _dist_server
+    init_rpc(
+        master_addr,
+        master_port,
+        num_rpc_threads,
+        request_timeout,
+        is_dynamic=is_dynamic,
+    )
+
+
+def wait_and_shutdown_server():
+    r"""Block until all clients have been shut down, then shut down the
+    server on the current process and destroy all RPC connections.
+    """
+    current_context = glt_dist_server.get_context()
+    if current_context is None:
+        logging.warning(
+            "'wait_and_shutdown_server': try to shutdown server when "
+            "the current process has not been initialized as a server."
+        )
+        return
+    if not current_context.is_server():
+        raise RuntimeError(
+            f"'wait_and_shutdown_server': role type of "
+            f"the current process context is not a server, "
+            f"got {current_context.role}."
+        )
+    global _dist_server
+    if _dist_server is not None:
+        _dist_server.wait_for_exit()
+        _dist_server.shutdown()
+        _dist_server = None
+    # Also clear GLT's _dist_server
+    glt_dist_server._dist_server = None
+    barrier()
+    shutdown_rpc()
+
+
+def _call_func_on_server(func, *args, **kwargs):
+    r"""A callee entry for remote requests on the server side."""
+    if not callable(func):
+        logging.warning(
+            f"'_call_func_on_server': received a non-callable " f"function target {func}"
+        )
+        return None
+
+    server = get_server()
+    if hasattr(server, func.__name__):
+        return func(server, *args, **kwargs)
+
+    return func(*args, **kwargs)
diff --git a/gigl/distributed/graph_store/remote_dist_dataset.py b/gigl/distributed/graph_store/remote_dist_dataset.py
index 19e60afe2..c5ee47c0c 100644
--- a/gigl/distributed/graph_store/remote_dist_dataset.py
+++ b/gigl/distributed/graph_store/remote_dist_dataset.py
@@ -8,16 +8,21 @@

 from gigl.common.logger import Logger
 from gigl.distributed.graph_store.storage_utils import (
+    get_ablp_input,
     get_edge_dir,
     get_edge_feature_info,
     get_edge_types,
     get_node_feature_info,
-    get_node_ids_for_rank,
+    get_node_ids,
 )
 from gigl.distributed.utils.networking import get_free_ports
 from gigl.env.distributed import GraphStoreInfo
 from gigl.src.common.types.graph_data import EdgeType, NodeType
-from gigl.types.graph import FeatureInfo
+from gigl.types.graph import (
+    DEFAULT_HOMOGENEOUS_EDGE_TYPE,
+    DEFAULT_HOMOGENEOUS_NODE_TYPE,
+    FeatureInfo,
+)

 logger = Logger()

@@ -109,23 +114,26 @@ def get_edge_dir(self) -> Union[str, Literal["in", "out"]]:
         )

     def _get_node_ids(
-        self, node_type: Optional[NodeType] = None
+        self,
+        rank: Optional[int] = None,
+        world_size: Optional[int] = None,
+        node_type: Optional[NodeType] = None,
+        split: Optional[Literal["train", "val", "test"]] = None,
     ) -> dict[int, torch.Tensor]:
         """Fetches node ids from the storage nodes for the current compute node (machine)."""
         futures: list[torch.futures.Future[torch.Tensor]] = []
-        rank = self.cluster_info.compute_node_rank
-        world_size = self.cluster_info.num_storage_nodes
         logger.info(
-            f"Getting node ids for rank {rank} / {world_size} with node type {node_type}"
+            f"Getting node ids for rank {rank} / {world_size} with node type {node_type} and split {split}"
         )
         for server_rank in range(self.cluster_info.num_storage_nodes):
             futures.append(
                 async_request_server(
                     server_rank,
-                    get_node_ids_for_rank,
+                    get_node_ids,
                     rank=rank,
world_size=world_size, + split=split, node_type=node_type, ) ) @@ -134,35 +142,79 @@ def _get_node_ids( def get_node_ids( self, + rank: Optional[int] = None, + world_size: Optional[int] = None, + split: Optional[Literal["train", "val", "test"]] = None, node_type: Optional[NodeType] = None, ) -> dict[int, torch.Tensor]: """ Fetches node ids from the storage nodes for the current compute node (machine). - The returned list are the node ids for the current compute node, by storage rank. + The returned dict maps storage rank to the node ids stored on that storage node, + filtered and sharded according to the provided arguments. + + Args: + rank (Optional[int]): The rank of the process requesting node ids. Must be provided if world_size is provided. + world_size (Optional[int]): The total number of processes in the distributed setup. Must be provided if rank is provided. + split (Optional[Literal["train", "val", "test"]]): + The split of the dataset to get node ids from. + If provided, the dataset must have `train_node_ids`, `val_node_ids`, and `test_node_ids` properties. + node_type (Optional[NodeType]): The type of nodes to get. + Must be provided for heterogeneous datasets. + + Returns: + dict[int, torch.Tensor]: A dict mapping storage rank to node ids. - For example, if there are two storage ranks, and two compute ranks, and 16 total nodes, - In this scenario, the node ids are sharded as follows: - Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7] - Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15] + Examples: + Suppose we have 2 storage nodes and 2 compute nodes, with 16 total nodes. + Nodes are partitioned across storage nodes, with splits defined as: - NOTE: The GLT sampling enginer expects that all processes on a given compute machine - to have the same sampling input (node ids). - As such, the input tensors may be duplicated across all processes on a given compute machine. - In order to save on cpu memory, pass in `mp_sharing_dict` to the `RemoteDistDataset` constructor. + Storage rank 0: [0, 1, 2, 3, 4, 5, 6, 7] + train=[0, 1, 2, 3], val=[4, 5], test=[6, 7] + Storage rank 1: [8, 9, 10, 11, 12, 13, 14, 15] + train=[8, 9, 10, 11], val=[12, 13], test=[14, 15] - Then, for compute rank 0 (node 0, process 0), the returned dict will be: + Get all nodes (no split filtering, no sharding): + + >>> dataset.get_node_ids() { - 0: [0, 1, 3, 4], # From storage rank 0 - 1: [8, 9, 10, 11] # From storage rank 1 + 0: tensor([0, 1, 2, 3, 4, 5, 6, 7]), # All 8 nodes from storage rank 0 + 1: tensor([8, 9, 10, 11, 12, 13, 14, 15]) # All 8 nodes from storage rank 1 } - Args: - node_type (Optional[NodeType]): The type of nodes to get. - Must be provided for heterogeneous datasets. + Shard all nodes across 2 compute nodes (compute rank 0 gets first half from each storage): - Returns: - dict[int, torch.Tensor]: A dict storage rank to node ids. 
+ >>> dataset.get_node_ids(rank=0, world_size=2) + { + 0: tensor([0, 1, 2, 3]), # First 4 of all 8 nodes from storage rank 0 + 1: tensor([8, 9, 10, 11]) # First 4 of all 8 nodes from storage rank 1 + } + + Get only training nodes (no sharding): + + >>> dataset.get_node_ids(split="train") + { + 0: tensor([0, 1, 2, 3]), # 4 training nodes from storage rank 0 + 1: tensor([8, 9, 10, 11]) # 4 training nodes from storage rank 1 + } + + Combine split and sharding (training nodes, sharded for compute rank 0): + + >>> dataset.get_node_ids(rank=0, world_size=2, split="train") + { + 0: tensor([0, 1]), # First 2 of 4 training nodes from storage rank 0 + 1: tensor([8, 9]) # First 2 of 4 training nodes from storage rank 1 + } + + Note: + When `split=None`, all nodes are queryable. This means nodes from any split + (train, val, or test) may be returned. This is useful when you need to sample + neighbors during inference, as neighbor nodes may belong to any split. + + The GLT sampling engine expects all processes on a given compute machine to have + the same sampling input (node ids). As such, the input tensors may be duplicated + across all processes on a given compute machine. To save on CPU memory, pass + `mp_sharing_dict` to the `RemoteDistDataset` constructor. """ def server_key(server_rank: int) -> str: @@ -174,7 +226,7 @@ def server_key(server_rank: int) -> str: logger.info( f"Compute rank {torch.distributed.get_rank()} is getting node ids from storage nodes" ) - node_ids = self._get_node_ids(node_type) + node_ids = self._get_node_ids(rank, world_size, node_type, split) for server_rank, node_id in node_ids.items(): node_id.share_memory_() self._mp_sharing_dict[server_key(server_rank)] = node_id @@ -188,7 +240,7 @@ def server_key(server_rank: int) -> str: } return node_ids else: - return self._get_node_ids(node_type) + return self._get_node_ids(rank, world_size, node_type, split) def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]: """ @@ -228,6 +280,174 @@ def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]: logger.info(f"Compute rank {compute_cluster_rank} received free ports: {ports}") return ports + def _get_ablp_input( + self, + split: Literal["train", "val", "test"], + rank: Optional[int] = None, + world_size: Optional[int] = None, + node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type: EdgeType = DEFAULT_HOMOGENEOUS_EDGE_TYPE, + ) -> dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: + """Fetches ABLP input from the storage nodes for the current compute node (machine).""" + futures: list[ + torch.futures.Future[ + tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ] + ] = [] + logger.info( + f"Getting ABLP input for rank {rank} / {world_size} with node type {node_type}, " + f"split {split}, and supervision edge type {supervision_edge_type}" + ) + + for server_rank in range(self.cluster_info.num_storage_nodes): + futures.append( + async_request_server( + server_rank, + get_ablp_input, + split=split, + rank=rank, + world_size=world_size, + node_type=node_type, + supervision_edge_type=supervision_edge_type, + ) + ) + ablp_inputs = torch.futures.wait_all(futures) + return { + server_rank: ablp_input + for server_rank, ablp_input in enumerate(ablp_inputs) + } + + def get_ablp_input( + self, + split: Literal["train", "val", "test"], + rank: Optional[int] = None, + world_size: Optional[int] = None, + node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE, + supervision_edge_type: EdgeType = 
DEFAULT_HOMOGENEOUS_EDGE_TYPE, + ) -> dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: + """ + Fetches ABLP (Anchor Based Link Prediction) input from the storage nodes. + + The returned dict maps storage rank to a tuple of (anchor_nodes, positive_labels, negative_labels) + for that storage node, filtered and sharded according to the provided arguments. + + Args: + split (Literal["train", "val", "test"]): The split to get the input for. + rank (Optional[int]): The rank of the process requesting the input. + Must be provided if world_size is provided. + world_size (Optional[int]): The total number of processes in the distributed setup. + Must be provided if rank is provided. + node_type (NodeType): The type of nodes to retrieve. + Defaults to DEFAULT_HOMOGENEOUS_NODE_TYPE. + supervision_edge_type (EdgeType): The edge type for supervision. + Defaults to DEFAULT_HOMOGENEOUS_EDGE_TYPE. + + Returns: + dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]]: + A dict mapping storage rank to a tuple of: + - anchor_nodes: The anchor node ids for the split + - positive_labels: Positive label node ids of shape [N, M] where N is the number + of anchor nodes and M is the number of positive labels per anchor + - negative_labels: Negative label node ids of shape [N, M], or None if unavailable + + Examples: + Suppose we have 1 storage node with users [0, 1, 2, 3, 4] where: + train=[0, 1, 2], val=[3], test=[4] + And positive/negative labels defined for link prediction. + + Get training ABLP input: + + >>> dataset.get_ablp_input(split="train", node_type=USER) + { + 0: ( + tensor([0, 1, 2]), # anchor nodes + tensor([[0, 1], [1, 2], [2, 3]]), # positive labels + tensor([[2], [3], [4]]) # negative labels + ) + } + + With sharding across 2 compute nodes (rank 0 gets first portion): + + >>> dataset.get_ablp_input(split="train", rank=0, world_size=2, node_type=USER) + { + 0: ( + tensor([0]), # first anchor node + tensor([[0, 1]]), # its positive labels + tensor([[2]]) # its negative labels + ) + } + + Note: + The GLT sampling engine expects all processes on a given compute machine to have + the same sampling input (node ids). As such, the input tensors may be duplicated + across all processes on a given compute machine. To save on CPU memory, pass + `mp_sharing_dict` to the `RemoteDistDataset` constructor. 
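+
+            The returned mapping is the `input_nodes` format expected by
+            `DistABLPLoader` in Graph Store mode. A minimal sketch (the `USER`
+            node type and supervision edge type are hypothetical):
+
+            >>> ablp_input = dataset.get_ablp_input(split="train", node_type=USER)
+            >>> loader = DistABLPLoader(
+            ...     dataset=dataset,
+            ...     num_neighbors=[10, 10],
+            ...     input_nodes=(USER, ablp_input),
+            ...     supervision_edge_type=(USER, "to", STORY),
+            ... )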
+ """ + + def anchors_key(server_rank: int) -> str: + return f"ablp_server_{server_rank}_anchors" + + def positive_labels_key(server_rank: int) -> str: + return f"ablp_server_{server_rank}_positive_labels" + + def negative_labels_key(server_rank: int) -> str: + return f"ablp_server_{server_rank}_negative_labels" + + if self._mp_sharing_dict is not None: + if self._local_rank == 0: + start_time = time.time() + logger.info( + f"Compute rank {torch.distributed.get_rank()} is getting ABLP input from storage nodes" + ) + ablp_inputs = self._get_ablp_input( + split, rank, world_size, node_type, supervision_edge_type + ) + for server_rank, ( + anchors, + positive_labels, + negative_labels, + ) in ablp_inputs.items(): + anchors.share_memory_() + positive_labels.share_memory_() + self._mp_sharing_dict[anchors_key(server_rank)] = anchors + self._mp_sharing_dict[ + positive_labels_key(server_rank) + ] = positive_labels + if negative_labels is not None: + negative_labels.share_memory_() + self._mp_sharing_dict[ + negative_labels_key(server_rank) + ] = negative_labels + logger.info( + f"Compute rank {torch.distributed.get_rank()} got ABLP input from storage nodes " + f"in {time.time() - start_time:.2f} seconds" + ) + torch.distributed.barrier() + returned_ablp_inputs: dict[ + int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]] + ] = {} + for server_rank in range(self.cluster_info.num_storage_nodes): + anchors = self._mp_sharing_dict[anchors_key(server_rank)] + positive_labels = self._mp_sharing_dict[ + positive_labels_key(server_rank) + ] + neg_key = negative_labels_key(server_rank) + negative_labels = ( + self._mp_sharing_dict[neg_key] + if neg_key in self._mp_sharing_dict + else None + ) + returned_ablp_inputs[server_rank] = ( + anchors, + positive_labels, + negative_labels, + ) + return returned_ablp_inputs + else: + return self._get_ablp_input( + split, rank, world_size, node_type, supervision_edge_type + ) + def get_edge_types(self) -> Optional[list[EdgeType]]: """Get the edge types from the registered dataset. diff --git a/gigl/distributed/graph_store/storage_utils.py b/gigl/distributed/graph_store/storage_utils.py index e90209e70..dbe4f3a44 100644 --- a/gigl/distributed/graph_store/storage_utils.py +++ b/gigl/distributed/graph_store/storage_utils.py @@ -18,6 +18,7 @@ [1]: https://github.com/alibaba/graphlearn-for-pytorch/blob/main/graphlearn_torch/python/distributed/dist_server.py#L38 """ +from collections import abc from typing import Literal, Optional, Union import torch @@ -105,49 +106,93 @@ def get_edge_dir() -> Literal["in", "out"]: return _dataset.edge_dir -# TODO(kmonte): Migrate this to be `get_node_ids(split?, shard?)` -def get_node_ids_for_rank( - rank: int, - world_size: int, - node_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE, +def get_node_ids( + rank: Optional[int] = None, + world_size: Optional[int] = None, + split: Optional[Union[Literal["train", "val", "test"], str]] = None, + node_type: Optional[NodeType] = None, ) -> torch.Tensor: - """Get the node IDs assigned to a specific rank in distributed processing. - - Shards the node IDs across processes based on the rank and world size. + """ + Get the node ids from the registered dataset. Args: - rank: The rank of the process requesting node IDs. - world_size: The total number of processes in the distributed setup. - node_type: The type of nodes to retrieve. Defaults to the default homogeneous node type. + rank (Optional[int]): The rank of the process requesting node ids. 
Must be provided if world_size is provided. + world_size (Optional[int]): The total number of processes in the distributed setup. Must be provided if rank is provided. + split (Optional[Literal["train", "val", "test"]]): The split of the dataset to get node ids from. If provided, the dataset must have `train_node_ids`, `val_node_ids`, and `test_node_ids` properties. + node_type (Optional[NodeType]): The type of nodes to get node ids for. Must be provided if the dataset is heterogeneous. Returns: - A tensor containing the node IDs assigned to the specified rank. + The node ids. Raises: - ValueError: If no dataset has been registered or if node_ids format is invalid. + ValueError: + * If no dataset has been registered + * If the rank and world_size are not provided together + * If the split is invalid + * If the node ids are not a torch.Tensor or a dict[NodeType, torch.Tensor] + * If the node type is provided for a homogeneous dataset + * If the node ids are not a dict[NodeType, torch.Tensor] when no node type is provided + + Examples: + Suppose the dataset has 100 nodes total: train=[0..59], val=[60..79], test=[80..99]. + + Get all node ids (no split filtering): + + >>> get_node_ids() + tensor([0, 1, 2, ..., 99]) # All 100 nodes + + Get only training nodes: + + >>> get_node_ids(split="train") + tensor([0, 1, 2, ..., 59]) # 60 training nodes + + Shard all nodes across 4 processes (each gets ~25 nodes): + + >>> get_node_ids(rank=0, world_size=4) + tensor([0, 1, 2, ..., 24]) # First 25 of all 100 nodes + + Shard training nodes across 4 processes (each gets ~15 nodes): + + >>> get_node_ids(rank=0, world_size=4, split="train") + tensor([0, 1, 2, ..., 14]) # First 15 of the 60 training nodes + + Note: When `split=None`, all nodes are queryable. This means nodes from any + split (train, val, or test) may be returned. This is useful when you need + to sample neighbors during inference, as neighbor nodes may belong to any split. """ - logger.info( - f"Getting node ids for rank {rank} / {world_size} with node type {node_type}" - ) if _dataset is None: raise _NO_DATASET_ERROR - if isinstance(_dataset.node_ids, torch.Tensor): - if node_type is not None: - raise ValueError( - f"node_type must be None for a homogeneous dataset. Got {node_type}. In GiGL, we usually do not have a truly homogeneous dataset, this is an odd error!" - ) + if (rank is None) ^ (world_size is None): + raise ValueError( + f"rank and world_size must be provided together. Received rank: {rank}, world_size: {world_size}" + ) + if split == "train": + nodes = _dataset.train_node_ids + elif split == "val": + nodes = _dataset.val_node_ids + elif split == "test": + nodes = _dataset.test_node_ids + elif split is None: nodes = _dataset.node_ids - elif isinstance(_dataset.node_ids, dict): - if node_type is None: + else: + raise ValueError( + f"Invalid split: {split}. Must be one of 'train', 'val', 'test', or None." + ) + + if node_type is not None: + if not isinstance(nodes, abc.Mapping): raise ValueError( - f"node_type must be not None for a heterogeneous dataset. Got {node_type}." + f"node_type was provided as {node_type}, so node ids must be a dict[NodeType, torch.Tensor] (e.g. a heterogeneous dataset), got {type(nodes)}" ) - nodes = _dataset.node_ids[node_type] - else: + nodes = nodes[node_type] + elif not isinstance(nodes, torch.Tensor): raise ValueError( - f"Node ids must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(_dataset.node_ids)}" + f"node_type was not provided, so node ids must be a torch.Tensor (e.g. 
a homogeneous dataset), got {type(nodes)}." ) - return shard_nodes_by_process(nodes, rank, world_size) + + if rank is not None and world_size is not None: + return shard_nodes_by_process(nodes, rank, world_size) + return nodes def get_edge_types() -> Optional[list[EdgeType]]: @@ -166,8 +211,8 @@ def get_edge_types() -> Optional[list[EdgeType]]: def get_ablp_input( split: Union[Literal["train", "val", "test"], str], - rank: int, - world_size: int, + rank: Optional[int] = None, + world_size: Optional[int] = None, node_type: NodeType = DEFAULT_HOMOGENEOUS_NODE_TYPE, supervision_edge_type: EdgeType = DEFAULT_HOMOGENEOUS_EDGE_TYPE, ) -> tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: @@ -178,8 +223,8 @@ def get_ablp_input( Args: split: The split to get the training input for. - rank: The rank of the process requesting the training input. - world_size: The total number of processes in the distributed setup. + rank: The rank of the process requesting the training input. Defaults to None, in which case all nodes are returned. Must be provided if world_size is provided. + world_size: The total number of processes in the distributed setup. Defaults to None, in which case all nodes are returned. Must be provided if rank is provided. node_type: The type of nodes to retrieve. Defaults to the default homogeneous node type. supervision_edge_type: The edge type to use for the supervision. Defaults to the default homogeneous edge type. Returns: @@ -195,31 +240,13 @@ def get_ablp_input( if _dataset is None: raise _NO_DATASET_ERROR - if split == "train": - anchors = _dataset.train_node_ids - elif split == "val": - anchors = _dataset.val_node_ids - elif split == "test": - anchors = _dataset.test_node_ids - else: - raise ValueError(f"Invalid split: {split}") - - if isinstance(anchors, torch.Tensor): - raise ValueError( - f"dataset.node_ids should be a dict[NodeType, torch.Tensor] for getting training input for datasets, got a torch.Tensor for split {split}" - ) - elif isinstance(anchors, dict): - anchor_nodes = anchors[node_type] - else: - raise ValueError( - f"Anchor nodes must be a torch.Tensor or a dict[NodeType, torch.Tensor], got {type(anchors)}" - ) - - anchors_for_rank = shard_nodes_by_process(anchor_nodes, rank, world_size) + anchors = get_node_ids( + split=split, rank=rank, world_size=world_size, node_type=node_type + ) positive_label_edge_type, negative_label_edge_type = select_label_edge_types( supervision_edge_type, _dataset.get_edge_types() ) positive_labels, negative_labels = get_labels_for_anchor_nodes( - _dataset, anchors_for_rank, positive_label_edge_type, negative_label_edge_type + _dataset, anchors, positive_label_edge_type, negative_label_edge_type ) - return anchors_for_rank, positive_labels, negative_labels + return anchors, positive_labels, negative_labels diff --git a/tests/integration/distributed/graph_store/graph_store_integration_test.py b/tests/integration/distributed/graph_store/graph_store_integration_test.py index 23ecc787b..01054d705 100644 --- a/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -11,6 +11,7 @@ from gigl.common import Uri from gigl.common.logger import Logger +from gigl.distributed.dist_ablp_neighborloader import DistABLPLoader from gigl.distributed.distributed_neighborloader import DistNeighborLoader from gigl.distributed.graph_store.compute import ( init_compute_process, @@ -31,6 +32,10 @@ CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO, 
DBLP_GRAPH_NODE_ANCHOR_MOCKED_DATASET_INFO, ) +from gigl.types.graph import ( + DEFAULT_HOMOGENEOUS_EDGE_TYPE, + DEFAULT_HOMOGENEOUS_NODE_TYPE, +) from tests.test_assets.distributed.utils import assert_tensor_equality logger = Logger() @@ -59,6 +64,240 @@ def _assert_sampler_input( torch.distributed.barrier() +def _assert_ablp_input( + cluster_info: GraphStoreInfo, + ablp_result: dict[int, tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]], +) -> None: + """Assert ABLP input structure and verify consistency across ranks on same compute node.""" + for i in range(cluster_info.compute_cluster_world_size): + if i == torch.distributed.get_rank(): + logger.info( + f"Verifying ABLP input for rank {i} / {cluster_info.compute_cluster_world_size}" + ) + logger.info(f"--------------------------------") + + # Verify structure: dict mapping server_rank to (anchors, positive_labels, negative_labels) + assert isinstance( + ablp_result, dict + ), f"Expected dict, got {type(ablp_result)}" + assert ( + len(ablp_result) == cluster_info.num_storage_nodes + ), f"Expected {cluster_info.num_storage_nodes} storage nodes in result, got {len(ablp_result)}" + + for server_rank, ( + anchors, + positive_labels, + negative_labels, + ) in ablp_result.items(): + # Verify anchors shape (1D tensor) + assert isinstance( + anchors, torch.Tensor + ), f"Anchors should be a tensor, got {type(anchors)}" + assert anchors.dim() == 1, f"Anchors should be 1D, got {anchors.dim()}D" + assert len(anchors) > 0, "Anchors should not be empty" + + # Verify positive_labels shape (2D tensor: [num_anchors, num_positive_labels]) + assert isinstance( + positive_labels, torch.Tensor + ), f"Positive labels should be a tensor, got {type(positive_labels)}" + assert ( + positive_labels.dim() == 2 + ), f"Positive labels should be 2D, got {positive_labels.dim()}D" + assert positive_labels.shape[0] == len( + anchors + ), f"Positive labels first dim should match anchors length, got {positive_labels.shape[0]} vs {len(anchors)}" + + # Verify negative_labels is None or has correct shape + if negative_labels is not None: + assert isinstance( + negative_labels, torch.Tensor + ), f"Negative labels should be a tensor, got {type(negative_labels)}" + assert ( + negative_labels.dim() == 2 + ), f"Negative labels should be 2D, got {negative_labels.dim()}D" + assert negative_labels.shape[0] == len( + anchors + ), f"Negative labels first dim should match anchors length" + + logger.info( + f"Server rank {server_rank}: anchors shape={anchors.shape}, " + f"positive_labels shape={positive_labels.shape}, " + f"negative_labels shape={negative_labels.shape if negative_labels is not None else None}" + ) + + logger.info( + f"{i} / {cluster_info.compute_cluster_world_size} compute node rank ABLP input verified" + ) + torch.distributed.barrier() + + torch.distributed.barrier() + + # Gather ABLP data from all ranks and verify processes on same compute_node_rank have identical data + local_anchors, local_positive, local_negative = ablp_result[0] + local_data = ( + cluster_info.compute_node_rank, + local_anchors.clone(), + local_positive.clone(), + local_negative.clone() if local_negative is not None else None, + ) + gathered_data: list[tuple[int, torch.Tensor, torch.Tensor, Optional[torch.Tensor]]] = [None] * cluster_info.compute_cluster_world_size # type: ignore[list-item] + torch.distributed.all_gather_object(gathered_data, local_data) + + # Group by compute_node_rank and verify all processes in same group have identical ABLP input + my_compute_node_rank = 
cluster_info.compute_node_rank
+    for (
+        other_compute_node_rank,
+        other_anchors,
+        other_positive,
+        other_negative,
+    ) in gathered_data:
+        if other_compute_node_rank == my_compute_node_rank:
+            assert_tensor_equality(local_anchors, other_anchors)
+            assert_tensor_equality(local_positive, other_positive)
+            if local_negative is not None and other_negative is not None:
+                assert_tensor_equality(local_negative, other_negative)
+            else:
+                assert local_negative is None and other_negative is None, (
+                    f"Negative labels mismatch: local={local_negative is not None}, "
+                    f"other={other_negative is not None}"
+                )
+
+    torch.distributed.barrier()
+    logger.info(
+        f"Rank {torch.distributed.get_rank()} verified processes on same compute_node_rank "
+        f"({my_compute_node_rank}) have identical ABLP input"
+    )
+
+
+def _run_compute_train_tests(
+    client_rank: int,
+    cluster_info: GraphStoreInfo,
+    mp_sharing_dict: dict[str, torch.Tensor],
+    node_type: Optional[NodeType],
+) -> None:
+    """
+    Compute test for training mode that verifies ABLP input and DistABLPLoader.
+    """
+    init_compute_process(client_rank, cluster_info, compute_world_backend="gloo")
+
+    remote_dist_dataset = RemoteDistDataset(
+        cluster_info=cluster_info,
+        local_rank=client_rank,
+        mp_sharing_dict=mp_sharing_dict,
+    )
+
+    # Use default types for labeled homogeneous graph
+    test_node_type = (
+        node_type if node_type is not None else DEFAULT_HOMOGENEOUS_NODE_TYPE
+    )
+    supervision_edge_type = DEFAULT_HOMOGENEOUS_EDGE_TYPE
+
+    # Test 1: Verify get_ablp_input for train split
+    ablp_result = remote_dist_dataset.get_ablp_input(
+        split="train",
+        rank=cluster_info.compute_node_rank,
+        world_size=cluster_info.num_compute_nodes,
+        node_type=test_node_type,
+        supervision_edge_type=supervision_edge_type,
+    )
+
+    _assert_ablp_input(cluster_info, ablp_result)
+
+    # Test 2: Test DistABLPLoader with Graph Store mode
+    logger.info(f"ablp_result: {ablp_result}")
+    for server_rank, (anchors, positive_labels, negative_labels) in ablp_result.items():
+        if negative_labels is not None:
+            logger.info(
+                f"server_rank: {server_rank}, anchors: {anchors.shape}, "
+                f"positive_labels: {positive_labels.shape}, negative_labels: {negative_labels.shape}"
+            )
+        else:
+            logger.info(
+                f"server_rank: {server_rank}, anchors: {anchors.shape}, "
+                f"positive_labels: {positive_labels.shape}"
+            )
+    torch.distributed.barrier()
+
+    # For labeled homogeneous, pass the dict directly (not as tuple)
+    input_nodes = ablp_result
+
+    loader = DistABLPLoader(
+        dataset=remote_dist_dataset,
+        num_neighbors=[2, 2],
+        input_nodes=input_nodes,
+        supervision_edge_type=supervision_edge_type,
+        pin_memory_device=torch.device("cpu"),
+        num_workers=2,
+        worker_concurrency=2,
+    )
+
+    count = 0
+    for batch in loader:
+        # Verify batch structure
+        assert hasattr(batch, "y_positive"), "Batch should have y_positive labels"
+        # y_positive should be dict mapping local anchor idx -> local label indices
+        assert isinstance(
+            batch.y_positive, dict
+        ), f"y_positive should be dict, got {type(batch.y_positive)}"
+        count += 1
+
+    torch.distributed.barrier()
+    logger.info(f"Rank {torch.distributed.get_rank()} loaded {count} ABLP batches")
+
+    # Verify total count across all ranks
+    count_tensor = torch.tensor(count, dtype=torch.int64)
+    torch.distributed.all_reduce(count_tensor, op=torch.distributed.ReduceOp.SUM)
+
+    # Calculate expected total anchors by summing across all compute nodes
+    # Each process on the same compute node has the same anchor count, so we sum
+    # across all processes and divide by num_processes_per_compute to get the true total
+    local_total_anchors 
= sum( + ablp_result[server_rank][0].shape[0] for server_rank in ablp_result + ) + expected_anchors_tensor = torch.tensor(local_total_anchors, dtype=torch.int64) + torch.distributed.all_reduce( + expected_anchors_tensor, op=torch.distributed.ReduceOp.SUM + ) + expected_batches = ( + expected_anchors_tensor.item() // cluster_info.num_processes_per_compute + ) + assert ( + count_tensor.item() == expected_batches + ), f"Expected {expected_batches} total batches, got {count_tensor.item()}" + + shutdown_compute_proccess() + + +def _client_train_process( + client_rank: int, + cluster_info: GraphStoreInfo, + node_type: Optional[NodeType], +) -> None: + """Client process for training mode that spawns compute train tests.""" + logger.info( + f"Initializing train client node {client_rank} / {cluster_info.num_compute_nodes}. " + f"OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']}" + ) + + mp_context = torch.multiprocessing.get_context("spawn") + mp_sharing_dict = torch.multiprocessing.Manager().dict() + client_processes = [] + logger.info("Starting train client processes") + for i in range(cluster_info.num_processes_per_compute): + client_process = mp_context.Process( + target=_run_compute_train_tests, + args=[ + i, # client_rank + cluster_info, # cluster_info + mp_sharing_dict, # mp_sharing_dict + node_type, # node_type + ], + ) + client_processes.append(client_process) + for client_process in client_processes: + client_process.start() + for client_process in client_processes: + client_process.join() + + def _run_compute_tests( client_rank: int, cluster_info: GraphStoreInfo, @@ -78,6 +317,8 @@ def _run_compute_tests( local_rank=client_rank, mp_sharing_dict=mp_sharing_dict, ) + rank = torch.distributed.get_rank() + world_size = torch.distributed.get_world_size() assert ( remote_dist_dataset.get_edge_dir() == "in" ), f"Edge direction must be 'in' for the test dataset. 
Got {remote_dist_dataset.get_edge_dir()}"
@@ -89,14 +330,14 @@ def _run_compute_tests(
     ), "Node feature info must not be None for the test dataset"
     ports = remote_dist_dataset.get_free_ports_on_storage_cluster(num_ports=2)
     assert len(ports) == 2, "Expected 2 free ports"
-    if torch.distributed.get_rank() == 0:
+    if rank == 0:
         all_ports = [None] * torch.distributed.get_world_size()
     else:
         all_ports = None
     torch.distributed.gather_object(ports, all_ports)
     logger.info(f"All ports: {all_ports}")
-    if torch.distributed.get_rank() == 0:
+    if rank == 0:
         assert isinstance(all_ports, list)
         for i, received_ports in enumerate(all_ports):
             assert (
@@ -106,7 +347,11 @@
     torch.distributed.barrier()
     logger.info("Verified that all ranks received the same free ports")

-    sampler_input = remote_dist_dataset.get_node_ids(node_type=node_type)
+    sampler_input = remote_dist_dataset.get_node_ids(
+        node_type=node_type,
+        rank=cluster_info.compute_node_rank,
+        world_size=cluster_info.num_compute_nodes,
+    )
     _assert_sampler_input(cluster_info, sampler_input, expected_sampler_input)

     # test "simple" case where we don't have mp sharing dict too
@@ -114,7 +359,11 @@
     cluster_info=cluster_info,
         local_rank=client_rank,
         mp_sharing_dict=None,
-    ).get_node_ids(node_type=node_type)
+    ).get_node_ids(
+        node_type=node_type,
+        rank=cluster_info.compute_node_rank,
+        world_size=cluster_info.num_compute_nodes,
+    )
     _assert_sampler_input(cluster_info, simple_sampler_input, expected_sampler_input)

     assert (
@@ -257,7 +506,7 @@ def _get_expected_input_nodes_by_rank(


 class GraphStoreIntegrationTest(unittest.TestCase):
-    def test_graph_store_homogeneous(self):
+    def _test_graph_store_homogeneous(self):
         # Simulating two server machines, two compute machines.
         # Each machine has one process.
        cora_supervised_info = get_mocked_dataset_artifact_metadata()[
@@ -352,6 +601,93 @@
         for server_process in server_processes:
             server_process.join()

+    def test_homogeneous_training(self):
+        """Test graph store with training mode (is_inference=False) to verify ABLP input."""
+        cora_supervised_info = get_mocked_dataset_artifact_metadata()[
+            CORA_USER_DEFINED_NODE_ANCHOR_MOCKED_DATASET_INFO.name
+        ]
+        task_config_uri = cora_supervised_info.frozen_gbml_config_uri
+        (
+            cluster_master_port,
+            storage_cluster_master_port,
+            compute_cluster_master_port,
+            master_port,
+            rpc_master_port,
+            rpc_wait_port,
+        ) = get_free_ports(num_ports=6)
+        host_ip = socket.gethostbyname(socket.gethostname())
+        cluster_info = GraphStoreInfo(
+            num_storage_nodes=2,
+            num_compute_nodes=2,
+            num_processes_per_compute=2,
+            cluster_master_ip=host_ip,
+            storage_cluster_master_ip=host_ip,
+            compute_cluster_master_ip=host_ip,
+            cluster_master_port=cluster_master_port,
+            storage_cluster_master_port=storage_cluster_master_port,
+            compute_cluster_master_port=compute_cluster_master_port,
+            rpc_master_port=rpc_master_port,
+            rpc_wait_port=rpc_wait_port,
+        )
+
+        ctx = mp.get_context("spawn")
+        client_processes: list = []
+        for i in range(cluster_info.num_compute_nodes):
+            with mock.patch.dict(
+                os.environ,
+                {
+                    "MASTER_ADDR": host_ip,
+                    "MASTER_PORT": str(master_port),
+                    "RANK": str(i),
+                    "WORLD_SIZE": str(cluster_info.num_cluster_nodes),
+                    COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY: str(
+                        cluster_info.num_processes_per_compute
+                    ),
+                },
+                clear=False,
+            ):
+                client_process = ctx.Process(
+                    target=_client_train_process,
+                    args=[
+                        i,  # client_rank
+                        cluster_info,  # cluster_info
+                        None,  # node_type - None for homogeneous dataset
+                    ],
+                )
+                client_process.start()
+                client_processes.append(client_process)
+        # Start server processes
+        server_processes = []
+        for i in range(cluster_info.num_storage_nodes):
+            with mock.patch.dict(
+                os.environ,
+                {
+                    "MASTER_ADDR": host_ip,
+                    "MASTER_PORT": str(master_port),
+                    "RANK": str(i + cluster_info.num_compute_nodes),
+                    "WORLD_SIZE": str(cluster_info.num_cluster_nodes),
+                    COMPUTE_CLUSTER_LOCAL_WORLD_SIZE_ENV_KEY: str(
+                        cluster_info.num_processes_per_compute
+                    ),
+                },
+                clear=False,
+            ):
+                server_process = ctx.Process(
+                    target=_run_server_processes,
+                    args=[
+                        cluster_info,  # cluster_info
+                        task_config_uri,  # task_config_uri
+                        False,  # is_inference - False for training mode
+                    ],
+                )
+                server_process.start()
+                server_processes.append(server_process)
+
+        for client_process in client_processes:
+            client_process.join()
+        for server_process in server_processes:
+            server_process.join()
+
     # TODO: (mkolodner-sc) - Figure out why this test is failing on Google Cloud Build
     @unittest.skip("Failing on Google Cloud Build - skipping for now")
     def test_graph_store_heterogeneous(self):
diff --git a/tests/test_assets/distributed/test_dataset.py b/tests/test_assets/distributed/test_dataset.py
new file mode 100644
index 000000000..a138eb374
--- /dev/null
+++ b/tests/test_assets/distributed/test_dataset.py
@@ -0,0 +1,425 @@
+"""Factory functions for creating test DistDataset instances.
+
+This module provides utility functions to create DistDataset instances for unit testing.
+The functions support both homogeneous and heterogeneous graphs with configurable features,
+edge indices, and label splits.
+ +Example usage: + from tests.test_assets.distributed.test_dataset import ( + create_homogeneous_dataset, + create_heterogeneous_dataset, + create_heterogeneous_dataset_with_labels, + DEFAULT_HOMOGENEOUS_EDGE_INDEX, + DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ) + + # Create a simple homogeneous dataset with default edge index + dataset = create_homogeneous_dataset(edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX) + + # Create with custom edge index + custom_edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) + dataset = create_homogeneous_dataset(edge_index=custom_edge_index) +""" + +from typing import Final, Literal, Optional + +import torch + +from gigl.distributed.dist_dataset import DistDataset +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import ( + FeaturePartitionData, + GraphPartitionData, + PartitionOutput, + message_passing_to_negative_label, + message_passing_to_positive_label, +) +from gigl.utils.data_splitters import DistNodeAnchorLinkSplitter + +# ============================================================================= +# Default Node and Edge Types +# ============================================================================= + +USER: Final[NodeType] = NodeType("user") +STORY: Final[NodeType] = NodeType("story") +USER_TO_STORY: Final[EdgeType] = EdgeType(USER, Relation("to"), STORY) +STORY_TO_USER: Final[EdgeType] = EdgeType(STORY, Relation("to"), USER) + +# ============================================================================= +# Default Edge Indices +# ============================================================================= + +# Homogeneous: 10-node ring graph where node i connects to node (i+1) % 10 +DEFAULT_HOMOGENEOUS_EDGE_INDEX: Final[torch.Tensor] = torch.tensor( + [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]] +) + +# Heterogeneous: 5 users, 5 stories with identity mapping (user i <-> story i) +DEFAULT_HETEROGENEOUS_EDGE_INDICES: Final[dict[EdgeType, torch.Tensor]] = { + USER_TO_STORY: torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), + STORY_TO_USER: torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), +} + +# ============================================================================= +# Default Feature Dimensions +# ============================================================================= + +DEFAULT_HOMOGENEOUS_NODE_FEATURE_DIM: Final[int] = 3 +DEFAULT_HETEROGENEOUS_NODE_FEATURE_DIM: Final[int] = 2 + + +def create_homogeneous_dataset( + edge_index: torch.Tensor, + node_features: Optional[torch.Tensor] = None, + node_feature_dim: int = DEFAULT_HOMOGENEOUS_NODE_FEATURE_DIM, + rank: int = 0, + world_size: int = 1, + edge_dir: Literal["in", "out"] = "out", +) -> DistDataset: + """Create a homogeneous test dataset. + + Creates a single-partition DistDataset with the specified edge index and node features. + + Args: + edge_index: COO format edge index [2, num_edges]. + node_features: Node feature tensor [num_nodes, feature_dim]. If None, creates zero + features with the specified dimension. + node_feature_dim: Dimension of node features if node_features is None. + rank: Rank of the current process. Defaults to 0. + world_size: Total number of processes. Defaults to 1. + edge_dir: Edge direction ("in" or "out"). Defaults to "out". + + Returns: + A DistDataset instance with the specified configuration. 
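+
+    A sketch (the feature values here are arbitrary) of supplying custom node
+    features in place of the zero-feature default:
+
+        >>> feats = torch.rand(10, 3)
+        >>> dataset = create_homogeneous_dataset(
+        ...     edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX,
+        ...     node_features=feats,
+        ... )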
+ + Example: + >>> dataset = create_homogeneous_dataset(edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX) + >>> dataset.node_ids.shape + torch.Size([10]) + + >>> custom_edge_index = torch.tensor([[0, 1], [1, 0]]) + >>> dataset = create_homogeneous_dataset(edge_index=custom_edge_index) + >>> dataset.node_ids.shape + torch.Size([2]) + """ + + # Derive counts from edge index + num_nodes = int(edge_index.max().item() + 1) + num_edges = int(edge_index.shape[1]) + + # Create default features if not provided + if node_features is None: + node_features = torch.zeros(num_nodes, node_feature_dim) + + partition_output = PartitionOutput( + # Partition books filled with zeros assign all nodes/edges to partition 0 + node_partition_book=torch.zeros(num_nodes, dtype=torch.int64), + edge_partition_book=torch.zeros(num_edges, dtype=torch.int64), + partitioned_edge_index=GraphPartitionData( + edge_index=edge_index, + edge_ids=None, + ), + partitioned_node_features=FeaturePartitionData( + feats=node_features, ids=torch.arange(num_nodes) + ), + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + dataset = DistDataset(rank=rank, world_size=world_size, edge_dir=edge_dir) + dataset.build(partition_output=partition_output) + return dataset + + +def create_heterogeneous_dataset( + edge_indices: dict[EdgeType, torch.Tensor], + node_features: Optional[dict[NodeType, torch.Tensor]] = None, + node_labels: Optional[dict[NodeType, torch.Tensor]] = None, + node_feature_dim: int = DEFAULT_HETEROGENEOUS_NODE_FEATURE_DIM, + rank: int = 0, + world_size: int = 1, + edge_dir: Literal["in", "out"] = "out", +) -> DistDataset: + """Create a heterogeneous test dataset. + + Creates a single-partition DistDataset with the specified edge indices and node features. + + Args: + edge_indices: Mapping of EdgeType -> COO format edge index [2, num_edges]. + node_features: Mapping of NodeType -> feature tensor [num_nodes, feature_dim]. + If None, creates zero features with the specified dimension. + node_labels: Mapping of NodeType -> label tensor [num_nodes, 1]. + If None, creates labels equal to node indices. + node_feature_dim: Dimension of node features if node_features is None. + rank: Rank of the current process. Defaults to 0. + world_size: Total number of processes. Defaults to 1. + edge_dir: Edge direction ("in" or "out"). Defaults to "out". + + Returns: + A DistDataset instance with the specified configuration. 
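+
+    A sketch of the label default noted in Args: with `node_labels=None`, each
+    node's label is its own index, i.e. for 5 nodes the generated tensor is
+    equivalent to:
+
+        >>> torch.arange(5).unsqueeze(1)
+        tensor([[0],
+                [1],
+                [2],
+                [3],
+                [4]])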
+ + Example: + >>> dataset = create_heterogeneous_dataset(edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES) + >>> dataset.node_ids[USER].shape + torch.Size([5]) + + >>> custom_edges = {USER_TO_STORY: torch.tensor([[0, 1], [0, 1]])} + >>> dataset = create_heterogeneous_dataset(edge_indices=custom_edges) + """ + + # Derive node counts from edge indices by collecting max node ID per node type + node_counts: dict[NodeType, int] = {} + for edge_type, edge_index in edge_indices.items(): + src_type, _, dst_type = edge_type + src_max = edge_index[0].max().item() + 1 + dst_max = edge_index[1].max().item() + 1 + node_counts[src_type] = int(max(node_counts.get(src_type, 0), src_max)) + node_counts[dst_type] = int(max(node_counts.get(dst_type, 0), dst_max)) + + # Partition books filled with zeros assign all nodes/edges to partition 0 + node_partition_book = { + node_type: torch.zeros(count, dtype=torch.int64) + for node_type, count in node_counts.items() + } + edge_partition_book = { + edge_type: torch.zeros(edge_index.shape[1], dtype=torch.int64) + for edge_type, edge_index in edge_indices.items() + } + partitioned_edge_index = { + edge_type: GraphPartitionData(edge_index=edge_index, edge_ids=None) + for edge_type, edge_index in edge_indices.items() + } + + # Create default features if not provided + if node_features is None: + node_features = { + node_type: torch.zeros(count, node_feature_dim) + for node_type, count in node_counts.items() + } + + partitioned_node_features = { + node_type: FeaturePartitionData(feats=feats, ids=torch.arange(feats.shape[0])) + for node_type, feats in node_features.items() + } + + # Create default labels if not provided + if node_labels is None: + node_labels = { + node_type: torch.arange(count).unsqueeze(1) + for node_type, count in node_counts.items() + } + + partitioned_node_labels = { + node_type: FeaturePartitionData(feats=labels, ids=torch.arange(labels.shape[0])) + for node_type, labels in node_labels.items() + } + + partition_output = PartitionOutput( + node_partition_book=node_partition_book, + edge_partition_book=edge_partition_book, + partitioned_edge_index=partitioned_edge_index, + partitioned_node_features=partitioned_node_features, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=partitioned_node_labels, + ) + dataset = DistDataset(rank=rank, world_size=world_size, edge_dir=edge_dir) + dataset.build(partition_output=partition_output) + return dataset + + +def create_heterogeneous_dataset_with_labels( + positive_labels: dict[int, list[int]], + train_node_ids: list[int], + val_node_ids: list[int], + test_node_ids: list[int], + edge_indices: dict[EdgeType, torch.Tensor], + negative_labels: Optional[dict[int, list[int]]] = None, + node_features: Optional[dict[NodeType, torch.Tensor]] = None, + node_feature_dim: int = DEFAULT_HETEROGENEOUS_NODE_FEATURE_DIM, + src_node_type: NodeType = USER, + dst_node_type: NodeType = STORY, + supervision_edge_type: Optional[EdgeType] = None, + rank: int = 0, + world_size: int = 1, + edge_dir: Literal["in", "out"] = "out", +) -> DistDataset: + """Create a heterogeneous test dataset with label edges and train/val/test splits. 
+ + Creates a dataset with: + - Source and destination nodes (default: USER and STORY) + - Message passing edges from edge_indices + - Positive label edges: src_node_type -[to_gigl_positive]-> dst_node_type + - Negative label edges (optional): src_node_type -[to_gigl_negative]-> dst_node_type + - Train/val/test splits for source nodes + + The splits are achieved using DistNodeAnchorLinkSplitter with an identity-like hash + function (hash(x) = x + 1). This produces deterministic splits where: + - Nodes with lower IDs go to train + - Nodes with middle IDs go to val + - Nodes with higher IDs go to test + + Args: + positive_labels: Mapping of src_node_id -> list of positive dst_node_ids. + train_node_ids: List of source node IDs in the train split (must be the lowest IDs). + val_node_ids: List of source node IDs in the val split (must be middle IDs). + test_node_ids: List of source node IDs in the test split (must be the highest IDs). + edge_indices: Mapping of EdgeType -> COO format edge index [2, num_edges]. + negative_labels: Mapping of src_node_id -> list of negative dst_node_ids, or None. + node_features: Mapping of NodeType -> feature tensor [num_nodes, feature_dim]. + node_feature_dim: Dimension of node features if node_features is None. + src_node_type: The source node type for labels. Defaults to USER. + dst_node_type: The destination node type for labels. Defaults to STORY. + supervision_edge_type: The edge type for supervision. If None, defaults to + EdgeType(src_node_type, Relation("to"), dst_node_type). + rank: Rank of the current process. Defaults to 0. + world_size: Total number of processes. Defaults to 1. + edge_dir: Edge direction ("in" or "out"). Defaults to "out". + + Returns: + A DistDataset instance with the specified configuration and splits. + + Raises: + ValueError: If any node ID in train/val/test is not in positive_labels. + + Example: + >>> positive_labels = {0: [0, 1], 1: [1, 2], 2: [2, 3]} + >>> dataset = create_heterogeneous_dataset_with_labels( + ... positive_labels=positive_labels, + ... train_node_ids=[0, 1], + ... val_node_ids=[2], + ... test_node_ids=[], + ... edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ... 
) + """ + # Set default supervision edge type + if supervision_edge_type is None: + supervision_edge_type = EdgeType(src_node_type, Relation("to"), dst_node_type) + + # Validate that all split node IDs have positive labels + all_split_node_ids = set(train_node_ids) | set(val_node_ids) | set(test_node_ids) + missing_nodes = all_split_node_ids - set(positive_labels.keys()) + if missing_nodes: + raise ValueError( + f"Node IDs {missing_nodes} are in train/val/test splits but not in positive_labels" + ) + + positive_label_edge_type = message_passing_to_positive_label(supervision_edge_type) + negative_label_edge_type = message_passing_to_negative_label(supervision_edge_type) + + # Convert positive_labels dict to COO edge index + pos_src, pos_dst = [], [] + for node_id, dst_ids in positive_labels.items(): + for dst_id in dst_ids: + pos_src.append(node_id) + pos_dst.append(dst_id) + positive_label_edge_index = torch.tensor([pos_src, pos_dst]) + + # Derive node counts from edge indices by collecting max node ID per node type + node_counts: dict[NodeType, int] = {} + for edge_type, edge_index in edge_indices.items(): + src_type, _, dst_type = edge_type + src_max = edge_index[0].max().item() + 1 + dst_max = edge_index[1].max().item() + 1 + node_counts[src_type] = int(max(node_counts.get(src_type, 0), src_max)) + node_counts[dst_type] = int(max(node_counts.get(dst_type, 0), dst_max)) + + # Also account for nodes in positive labels + node_counts[src_node_type] = max( + node_counts.get(src_node_type, 0), max(positive_labels.keys()) + 1 + ) + node_counts[dst_node_type] = max( + node_counts.get(dst_node_type, 0), + max(max(stories) for stories in positive_labels.values()) + 1, + ) + + # Set up edge partition books and edge indices + edge_partition_book = { + edge_type: torch.zeros(edge_index.shape[1], dtype=torch.int64) + for edge_type, edge_index in edge_indices.items() + } + edge_partition_book[positive_label_edge_type] = torch.zeros( + len(pos_src), dtype=torch.int64 + ) + + partitioned_edge_index = { + edge_type: GraphPartitionData(edge_index=edge_index, edge_ids=None) + for edge_type, edge_index in edge_indices.items() + } + partitioned_edge_index[positive_label_edge_type] = GraphPartitionData( + edge_index=positive_label_edge_index, + edge_ids=None, + ) + + if negative_labels is not None: + # Convert negative_labels dict to COO edge index + neg_src, neg_dst = [], [] + for node_id, dst_ids in negative_labels.items(): + for dst_id in dst_ids: + neg_src.append(node_id) + neg_dst.append(dst_id) + negative_label_edge_index = torch.tensor([neg_src, neg_dst]) + edge_partition_book[negative_label_edge_type] = torch.zeros( + len(neg_src), dtype=torch.int64 + ) + partitioned_edge_index[negative_label_edge_type] = GraphPartitionData( + edge_index=negative_label_edge_index, + edge_ids=None, + ) + + # Partition books filled with zeros assign all nodes to partition 0 + node_partition_book = { + node_type: torch.zeros(count, dtype=torch.int64) + for node_type, count in node_counts.items() + } + + # Create default features if not provided + if node_features is None: + node_features = { + node_type: torch.zeros(count, node_feature_dim) + for node_type, count in node_counts.items() + } + + partitioned_node_features = { + node_type: FeaturePartitionData(feats=feats, ids=torch.arange(feats.shape[0])) + for node_type, feats in node_features.items() + } + + partition_output = PartitionOutput( + node_partition_book=node_partition_book, + edge_partition_book=edge_partition_book, + 
partitioned_edge_index=partitioned_edge_index, + partitioned_node_features=partitioned_node_features, + partitioned_edge_features=None, + partitioned_positive_labels=None, + partitioned_negative_labels=None, + partitioned_node_labels=None, + ) + + # Calculate split ratios based on provided node IDs. + # With identity hash (x + 1), nodes are split by their ID values: + # - Lower IDs -> train, middle IDs -> val, higher IDs -> test + total_nodes = len(positive_labels) + num_val = len(val_node_ids) / total_nodes + num_test = len(test_node_ids) / total_nodes + + # Identity-like hash function for deterministic splits based on node ID ordering. + # Adding 1 ensures hash(0) != 0 and creates proper normalization boundaries. + def _identity_hash(x: torch.Tensor) -> torch.Tensor: + return x.clone().to(torch.int64) + 1 + + # Create splitter that will produce splits based on node ID ordering + splitter = DistNodeAnchorLinkSplitter( + sampling_direction=edge_dir, + num_val=num_val, + num_test=num_test, + hash_function=_identity_hash, + supervision_edge_types=[supervision_edge_type], + should_convert_labels_to_edges=True, + ) + + dataset = DistDataset(rank=rank, world_size=world_size, edge_dir=edge_dir) + dataset.build(partition_output=partition_output, splitter=splitter) + return dataset diff --git a/tests/unit/distributed/graph_store/remote_dist_dataset_test.py b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py new file mode 100644 index 000000000..fb4bf0b21 --- /dev/null +++ b/tests/unit/distributed/graph_store/remote_dist_dataset_test.py @@ -0,0 +1,491 @@ +import unittest +from unittest.mock import patch + +import torch +import torch.distributed as dist +import torch.multiprocessing as mp + +from gigl.distributed.graph_store import storage_utils +from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset +from gigl.env.distributed import GraphStoreInfo +from gigl.types.graph import FeatureInfo +from tests.test_assets.distributed.test_dataset import ( + DEFAULT_HETEROGENEOUS_EDGE_INDICES, + DEFAULT_HOMOGENEOUS_EDGE_INDEX, + STORY, + STORY_TO_USER, + USER, + USER_TO_STORY, + create_heterogeneous_dataset, + create_heterogeneous_dataset_with_labels, + create_homogeneous_dataset, +) +from tests.test_assets.distributed.utils import ( + MockGraphStoreInfo, + assert_tensor_equality, + create_test_process_group, + get_process_group_init_method, +) + + +def _mock_request_server(server_rank, func, *args, **kwargs): + """Mock request_server that directly calls the function.""" + return func(*args, **kwargs) + + +def _mock_async_request_server(server_rank, func, *args, **kwargs): + """Mock async_request_server that returns a completed future with the function result.""" + future: torch.futures.Future = torch.futures.Future() + future.set_result(func(*args, **kwargs)) + return future + + +def _create_mock_graph_store_info( + num_storage_nodes: int = 1, + num_compute_nodes: int = 1, + compute_node_rank: int = 0, + num_processes_per_compute: int = 1, +) -> GraphStoreInfo: + """Create a mock GraphStoreInfo with placeholder values.""" + real_info = GraphStoreInfo( + num_storage_nodes=num_storage_nodes, + num_compute_nodes=num_compute_nodes, + cluster_master_ip="127.0.0.1", + storage_cluster_master_ip="127.0.0.1", + compute_cluster_master_ip="127.0.0.1", + cluster_master_port=12345, + storage_cluster_master_port=12346, + compute_cluster_master_port=12347, + num_processes_per_compute=num_processes_per_compute, + rpc_master_port=12348, + rpc_wait_port=12349, + ) + return 
MockGraphStoreInfo(real_info, compute_node_rank) + + +class TestRemoteDistDataset(unittest.TestCase): + def setUp(self) -> None: + storage_utils._dataset = None + storage_utils.register_dataset( + create_homogeneous_dataset(edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX) + ) + + def tearDown(self) -> None: + storage_utils._dataset = None + if dist.is_initialized(): + dist.destroy_process_group() + + @patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=_mock_request_server, + ) + def test_graph_metadata_getters_homogeneous(self, mock_request): + """Test get_node_feature_info, get_edge_feature_info, get_edge_dir, get_edge_types for homogeneous graphs.""" + cluster_info = _create_mock_graph_store_info() + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + self.assertEqual( + remote_dataset.get_node_feature_info(), + FeatureInfo(dim=3, dtype=torch.float32), + ) + self.assertIsNone(remote_dataset.get_edge_feature_info()) + self.assertEqual(remote_dataset.get_edge_dir(), "out") + self.assertIsNone(remote_dataset.get_edge_types()) + + def test_init_rejects_non_dict_proxy_for_mp_sharing_dict(self): + cluster_info = _create_mock_graph_store_info() + + with self.assertRaises(ValueError): + RemoteDistDataset( + cluster_info=cluster_info, + local_rank=0, + mp_sharing_dict=dict(), # Regular dict should fail + ) + + def test_cluster_info_property(self): + cluster_info = _create_mock_graph_store_info( + num_storage_nodes=3, num_compute_nodes=2 + ) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + self.assertEqual(remote_dataset.cluster_info.num_storage_nodes, 3) + self.assertEqual(remote_dataset.cluster_info.num_compute_nodes, 2) + + @patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=_mock_async_request_server, + ) + def test_get_node_ids(self, mock_async_request): + """Test get_node_ids returns node ids, with optional sharding via rank/world_size.""" + cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Basic: all nodes + result = remote_dataset.get_node_ids() + self.assertIn(0, result) + assert_tensor_equality(result[0], torch.arange(10)) + + # With sharding: first half (rank 0 of 2) + result = remote_dataset.get_node_ids(rank=0, world_size=2) + assert_tensor_equality(result[0], torch.arange(5)) + + # With sharding: second half (rank 1 of 2) + result = remote_dataset.get_node_ids(rank=1, world_size=2) + assert_tensor_equality(result[0], torch.arange(5, 10)) + + +class TestRemoteDistDatasetHeterogeneous(unittest.TestCase): + def setUp(self) -> None: + storage_utils._dataset = None + storage_utils.register_dataset( + create_heterogeneous_dataset( + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES + ) + ) + + def tearDown(self) -> None: + storage_utils._dataset = None + if dist.is_initialized(): + dist.destroy_process_group() + + @patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + side_effect=_mock_request_server, + ) + def test_graph_metadata_getters_heterogeneous(self, mock_request): + """Test get_node_feature_info, get_edge_dir, get_edge_types for heterogeneous graphs.""" + cluster_info = _create_mock_graph_store_info() + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + self.assertEqual( + remote_dataset.get_node_feature_info(), + { + USER: FeatureInfo(dim=2, dtype=torch.float32), + STORY: FeatureInfo(dim=2, 
dtype=torch.float32), + }, + ) + self.assertEqual(remote_dataset.get_edge_dir(), "out") + self.assertEqual( + remote_dataset.get_edge_types(), [USER_TO_STORY, STORY_TO_USER] + ) + + @patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=_mock_async_request_server, + ) + def test_get_node_ids_with_node_type(self, mock_async_request): + """Test get_node_ids with node_type for heterogeneous graphs, with optional sharding.""" + cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Get user nodes + result = remote_dataset.get_node_ids(node_type=USER) + assert_tensor_equality(result[0], torch.arange(5)) + + # Get story nodes + result = remote_dataset.get_node_ids(node_type=STORY) + assert_tensor_equality(result[0], torch.arange(5)) + + # With sharding: first half of user nodes (rank 0 of 2) + result = remote_dataset.get_node_ids(rank=0, world_size=2, node_type=USER) + assert_tensor_equality(result[0], torch.arange(2)) + + # With sharding: second half of user nodes (rank 1 of 2) + result = remote_dataset.get_node_ids(rank=1, world_size=2, node_type=USER) + assert_tensor_equality(result[0], torch.arange(2, 5)) + + +class TestRemoteDistDatasetWithSplits(unittest.TestCase): + """Tests for get_node_ids with train/val/test splits.""" + + def setUp(self) -> None: + storage_utils._dataset = None + + def tearDown(self) -> None: + storage_utils._dataset = None + if dist.is_initialized(): + dist.destroy_process_group() + + def _create_and_register_dataset_with_splits(self) -> None: + """Create and register a dataset with train/val/test splits.""" + create_test_process_group() + + positive_labels = { + 0: [0, 1], + 1: [1, 2], + 2: [2, 3], + 3: [3, 4], + 4: [4, 0], + } + negative_labels = { + 0: [2], + 1: [3], + 2: [4], + 3: [0], + 4: [1], + } + + dataset = create_heterogeneous_dataset_with_labels( + positive_labels=positive_labels, + negative_labels=negative_labels, + train_node_ids=[0, 1, 2], + val_node_ids=[3], + test_node_ids=[4], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ) + storage_utils.register_dataset(dataset) + + @patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=_mock_async_request_server, + ) + def test_get_node_ids_with_splits(self, mock_async_request): + """Test get_node_ids with train/val/test splits and optional sharding.""" + self._create_and_register_dataset_with_splits() + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Test each split returns correct nodes + assert_tensor_equality( + remote_dataset.get_node_ids(node_type=USER, split="train")[0], + torch.tensor([0, 1, 2]), + ) + assert_tensor_equality( + remote_dataset.get_node_ids(node_type=USER, split="val")[0], + torch.tensor([3]), + ) + assert_tensor_equality( + remote_dataset.get_node_ids(node_type=USER, split="test")[0], + torch.tensor([4]), + ) + + # No split returns all nodes + assert_tensor_equality( + remote_dataset.get_node_ids(node_type=USER, split=None)[0], + torch.arange(5), + ) + + # With sharding: train split [0, 1, 2] across 2 ranks + assert_tensor_equality( + remote_dataset.get_node_ids( + rank=0, world_size=2, node_type=USER, split="train" + )[0], + torch.tensor([0]), + ) + assert_tensor_equality( + remote_dataset.get_node_ids( + rank=1, world_size=2, node_type=USER, split="train" + )[0], + torch.tensor([1, 2]), + ) + + @patch( 
+ "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=_mock_async_request_server, + ) + def test_get_ablp_input(self, mock_async_request): + """Test get_ablp_input with train/val/test splits.""" + self._create_and_register_dataset_with_splits() + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # Train split: nodes [0, 1, 2] + result = remote_dataset.get_ablp_input( + split="train", node_type=USER, supervision_edge_type=USER_TO_STORY + ) + self.assertIn(0, result) + anchors, positive_labels, negative_labels = result[0] + assert_tensor_equality(anchors, torch.tensor([0, 1, 2])) + assert_tensor_equality(positive_labels, torch.tensor([[0, 1], [1, 2], [2, 3]])) + assert negative_labels is not None + assert_tensor_equality(negative_labels, torch.tensor([[2], [3], [4]])) + + # Val split: node [3] + result = remote_dataset.get_ablp_input( + split="val", node_type=USER, supervision_edge_type=USER_TO_STORY + ) + anchors, positive_labels, negative_labels = result[0] + assert_tensor_equality(anchors, torch.tensor([3])) + assert_tensor_equality(positive_labels, torch.tensor([[3, 4]])) + assert negative_labels is not None + assert_tensor_equality(negative_labels, torch.tensor([[0]])) + + # Test split: node [4] + # Note: Labels are stored in CSR format which sorts by destination indices, + # so [4, 0] from the input becomes [0, 4] in the stored format. + result = remote_dataset.get_ablp_input( + split="test", node_type=USER, supervision_edge_type=USER_TO_STORY + ) + anchors, positive_labels, negative_labels = result[0] + assert_tensor_equality(anchors, torch.tensor([4])) + assert_tensor_equality(positive_labels, torch.tensor([[0, 4]])) + assert negative_labels is not None + assert_tensor_equality(negative_labels, torch.tensor([[1]])) + + @patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=_mock_async_request_server, + ) + def test_get_ablp_input_with_sharding(self, mock_async_request): + """Test get_ablp_input with sharding across compute nodes.""" + self._create_and_register_dataset_with_splits() + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + # With sharding: train split [0, 1, 2] across 2 ranks + result_rank0 = remote_dataset.get_ablp_input( + split="train", + rank=0, + world_size=2, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) + anchors_0, positive_labels_0, negative_labels_0 = result_rank0[0] + + # Rank 0 should get node 0 + assert_tensor_equality(anchors_0, torch.tensor([0])) + assert_tensor_equality(positive_labels_0, torch.tensor([[0, 1]])) + assert negative_labels_0 is not None + assert_tensor_equality(negative_labels_0, torch.tensor([[2]])) + + result_rank1 = remote_dataset.get_ablp_input( + split="train", + rank=1, + world_size=2, + node_type=USER, + supervision_edge_type=USER_TO_STORY, + ) + anchors_1, positive_labels_1, negative_labels_1 = result_rank1[0] + + # Rank 1 should get nodes 1, 2 + assert_tensor_equality(anchors_1, torch.tensor([1, 2])) + assert_tensor_equality(positive_labels_1, torch.tensor([[1, 2], [2, 3]])) + assert negative_labels_1 is not None + assert_tensor_equality(negative_labels_1, torch.tensor([[3], [4]])) + + @patch( + "gigl.distributed.graph_store.remote_dist_dataset.async_request_server", + side_effect=_mock_async_request_server, + ) + def 
test_get_ablp_input_with_mp_sharing_dict(self, mock_async_request): + """Test get_ablp_input with mp_sharing_dict for shared memory across processes.""" + self._create_and_register_dataset_with_splits() + + cluster_info = _create_mock_graph_store_info(num_storage_nodes=1) + mp_sharing_dict = mp.Manager().dict() + remote_dataset = RemoteDistDataset( + cluster_info=cluster_info, local_rank=0, mp_sharing_dict=mp_sharing_dict + ) + + # First call - should fetch and store in shared dict + result = remote_dataset.get_ablp_input( + split="train", node_type=USER, supervision_edge_type=USER_TO_STORY + ) + anchors, positive_labels, negative_labels = result[0] + + # Verify results are correct + assert_tensor_equality(anchors, torch.tensor([0, 1, 2])) + assert_tensor_equality(positive_labels, torch.tensor([[0, 1], [1, 2], [2, 3]])) + assert negative_labels is not None + assert_tensor_equality(negative_labels, torch.tensor([[2], [3], [4]])) + + # Verify data was stored in shared dict + self.assertIn("ablp_server_0_anchors", mp_sharing_dict) + self.assertIn("ablp_server_0_positive_labels", mp_sharing_dict) + self.assertIn("ablp_server_0_negative_labels", mp_sharing_dict) + + # Verify stored tensors match the returned tensors + assert_tensor_equality( + mp_sharing_dict["ablp_server_0_anchors"], torch.tensor([0, 1, 2]) + ) + assert_tensor_equality( + mp_sharing_dict["ablp_server_0_positive_labels"], + torch.tensor([[0, 1], [1, 2], [2, 3]]), + ) + assert_tensor_equality( + mp_sharing_dict["ablp_server_0_negative_labels"], + torch.tensor([[2], [3], [4]]), + ) + + +def _test_get_free_ports_on_storage_cluster( + rank: int, + world_size: int, + init_process_group_init_method: str, + num_ports: int, + mock_ports: list[int], +): + """Test function to run in spawned processes.""" + dist.init_process_group( + backend="gloo", + init_method=init_process_group_init_method, + world_size=world_size, + rank=rank, + ) + try: + cluster_info = _create_mock_graph_store_info( + num_compute_nodes=world_size, + num_processes_per_compute=1, + compute_node_rank=rank, + ) + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + with patch( + "gigl.distributed.graph_store.remote_dist_dataset.request_server", + return_value=mock_ports, + ): + ports = remote_dataset.get_free_ports_on_storage_cluster(num_ports) + + assert len(ports) == num_ports, f"Expected {num_ports} ports, got {len(ports)}" + + # Verify all ranks get the same ports via all_gather + gathered_ports = [None] * world_size + dist.all_gather_object(gathered_ports, ports) + + for i, rank_ports in enumerate(gathered_ports): + assert ( + rank_ports == mock_ports + ), f"Rank {i} got {rank_ports}, expected {mock_ports}" + finally: + dist.destroy_process_group() + + +class TestGetFreePortsOnStorageCluster(unittest.TestCase): + def setUp(self) -> None: + storage_utils._dataset = None + storage_utils.register_dataset( + create_homogeneous_dataset(edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX) + ) + + def tearDown(self) -> None: + storage_utils._dataset = None + if dist.is_initialized(): + dist.destroy_process_group() + + def test_get_free_ports_on_storage_cluster_distributed(self): + """Test that free ports are correctly broadcast across all ranks.""" + init_method = get_process_group_init_method() + world_size = 2 + num_ports = 3 + mock_ports = [10000, 10001, 10002] + + mp.spawn( + fn=_test_get_free_ports_on_storage_cluster, + args=(world_size, init_method, num_ports, mock_ports), + nprocs=world_size, + ) + + def 
test_get_free_ports_fails_without_process_group(self): + """Test that get_free_ports_on_storage_cluster raises when dist not initialized.""" + cluster_info = _create_mock_graph_store_info() + remote_dataset = RemoteDistDataset(cluster_info=cluster_info, local_rank=0) + + with self.assertRaises(ValueError): + remote_dataset.get_free_ports_on_storage_cluster(num_ports=1) + + +if __name__ == "__main__": + unittest.main() diff --git a/tests/unit/distributed/graph_store/storage_utils_test.py b/tests/unit/distributed/graph_store/storage_utils_test.py index bc34d6c19..9c7df462a 100644 --- a/tests/unit/distributed/graph_store/storage_utils_test.py +++ b/tests/unit/distributed/graph_store/storage_utils_test.py @@ -1,41 +1,25 @@ import unittest -from typing import Final, Optional import torch -from gigl.distributed.dist_dataset import DistDataset from gigl.distributed.graph_store import storage_utils -from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation -from gigl.types.graph import ( - FeatureInfo, - FeaturePartitionData, - GraphPartitionData, - PartitionOutput, - message_passing_to_negative_label, - message_passing_to_positive_label, +from gigl.src.common.types.graph_data import Relation +from gigl.types.graph import FeatureInfo +from tests.test_assets.distributed.test_dataset import ( + DEFAULT_HETEROGENEOUS_EDGE_INDICES, + DEFAULT_HOMOGENEOUS_EDGE_INDEX, + STORY, + USER, + USER_TO_STORY, + create_heterogeneous_dataset, + create_heterogeneous_dataset_with_labels, + create_homogeneous_dataset, ) -from gigl.utils.data_splitters import DistNodeAnchorLinkSplitter from tests.test_assets.distributed.utils import ( assert_tensor_equality, create_test_process_group, ) -_USER = NodeType("user") -_STORY = NodeType("story") -_USER_TO_STORY = EdgeType(_USER, Relation("to"), _STORY) -_STORY_TO_USER = EdgeType(_STORY, Relation("to"), _USER) - -# Default edge indices for test graphs (COO format: [2, num_edges]) -# Homogeneous: 10-node ring graph where node i connects to node (i+1) % 10 -_DEFAULT_HOMOGENEOUS_EDGE_INDEX: Final[torch.Tensor] = torch.tensor( - [[0, 1, 2, 3, 4, 5, 6, 7, 8, 9], [1, 2, 3, 4, 5, 6, 7, 8, 9, 0]] -) -# Heterogeneous: 5 users, 5 stories with identity mapping (user i <-> story i) -_DEFAULT_HETEROGENEOUS_EDGE_INDICES: Final[dict[EdgeType, torch.Tensor]] = { - _USER_TO_STORY: torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), - _STORY_TO_USER: torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), -} - class TestRemoteDataset(unittest.TestCase): def setUp(self) -> None: @@ -48,260 +32,10 @@ def tearDown(self) -> None: if torch.distributed.is_initialized(): torch.distributed.destroy_process_group() - def _create_heterogeneous_dataset( - self, - edge_indices: dict[EdgeType, torch.Tensor], - ) -> DistDataset: - """Helper method to create a heterogeneous test dataset. - - Args: - edge_indices: Mapping of EdgeType -> COO format edge index [2, num_edges]. 
- """ - # Derive node counts from edge indices by collecting max node ID per node type - node_counts: dict[NodeType, int] = {} - for edge_type, edge_index in edge_indices.items(): - src_type, _, dst_type = edge_type - src_max = edge_index[0].max().item() + 1 - dst_max = edge_index[1].max().item() + 1 - node_counts[src_type] = int(max(node_counts.get(src_type, 0), src_max)) - node_counts[dst_type] = int(max(node_counts.get(dst_type, 0), dst_max)) - - # Partition books filled with zeros assign all nodes/edges to partition 0 (rank 0 - we only have 1 rank in the test) - node_partition_book = { - node_type: torch.zeros(count, dtype=torch.int64) - for node_type, count in node_counts.items() - } - edge_partition_book = { - edge_type: torch.zeros(edge_index.shape[1], dtype=torch.int64) - for edge_type, edge_index in edge_indices.items() - } - partitioned_edge_index = { - edge_type: GraphPartitionData(edge_index=edge_index, edge_ids=None) - for edge_type, edge_index in edge_indices.items() - } - partitioned_node_features = { - node_type: FeaturePartitionData( - feats=torch.zeros(count, 2), ids=torch.arange(count) - ) - for node_type, count in node_counts.items() - } - partitioned_node_labels = { - node_type: FeaturePartitionData( - feats=torch.arange(count).unsqueeze(1), ids=torch.arange(count) - ) - for node_type, count in node_counts.items() - } - - partition_output = PartitionOutput( - node_partition_book=node_partition_book, - edge_partition_book=edge_partition_book, - partitioned_edge_index=partitioned_edge_index, - partitioned_node_features=partitioned_node_features, - partitioned_edge_features=None, - partitioned_positive_labels=None, - partitioned_negative_labels=None, - partitioned_node_labels=partitioned_node_labels, - ) - dataset = DistDataset(rank=0, world_size=1, edge_dir="out") - dataset.build(partition_output=partition_output) - return dataset - - def _create_homogeneous_dataset( - self, - edge_index: torch.Tensor, - ) -> DistDataset: - """Helper method to create a homogeneous test dataset. - - Args: - edge_index: COO format edge index [2, num_edges]. - """ - - # Derive counts from edge index - num_nodes = int(edge_index.max().item() + 1) - num_edges = int(edge_index.shape[1]) - - partition_output = PartitionOutput( - # Partition books filled with zeros assign all nodes/edges to partition 0 (rank 0 - we only have 1 rank in the test) - node_partition_book=torch.zeros(num_nodes, dtype=torch.int64), - edge_partition_book=torch.zeros(num_edges, dtype=torch.int64), - partitioned_edge_index=GraphPartitionData( - edge_index=edge_index, - edge_ids=None, - ), - partitioned_node_features=FeaturePartitionData( - feats=torch.zeros(num_nodes, 3), ids=torch.arange(num_nodes) - ), - partitioned_edge_features=None, - partitioned_positive_labels=None, - partitioned_negative_labels=None, - partitioned_node_labels=None, - ) - dataset = DistDataset(rank=0, world_size=1, edge_dir="out") - dataset.build(partition_output=partition_output) - return dataset - - def _create_heterogeneous_dataset_with_labels( - self, - positive_labels: dict[int, list[int]], - negative_labels: Optional[dict[int, list[int]]], - train_user_ids: list[int], - val_user_ids: list[int], - test_user_ids: list[int], - edge_indices: dict[EdgeType, torch.Tensor], - ) -> DistDataset: - """Helper method to create a heterogeneous test dataset with label edges and splits. 
- - Creates a dataset with: - - USER nodes (count derived from edge indices) - - STORY nodes (count derived from edge indices) - - Message passing edges from edge_indices - - Positive label edges: USER -[to_gigl_positive]-> STORY (from positive_labels) - - Negative label edges (optional): USER -[to_gigl_negative]-> STORY (from negative_labels) - - Train/val/test splits for USER nodes - - The splits are achieved using DistNodeAnchorLinkSplitter with an identity-like hash - function (hash(x) = x + 1). This produces deterministic splits where: - - Nodes with lower IDs go to train - - Nodes with middle IDs go to val - - Nodes with higher IDs go to test - - Args: - positive_labels: Mapping of user_id -> list of positive story_ids. - negative_labels: Mapping of user_id -> list of negative story_ids, or None. - train_user_ids: List of user IDs in the train split (must be the lowest IDs). - val_user_ids: List of user IDs in the val split (must be middle IDs). - test_user_ids: List of user IDs in the test split (must be the highest IDs). - edge_indices: Mapping of EdgeType -> COO format edge index [2, num_edges]. - - Raises: - ValueError: If any user ID in train/val/test is not in positive_labels. - """ - # Validate that all split user IDs have positive labels - all_split_user_ids = ( - set(train_user_ids) | set(val_user_ids) | set(test_user_ids) - ) - missing_users = all_split_user_ids - set(positive_labels.keys()) - if missing_users: - raise ValueError( - f"User IDs {missing_users} are in train/val/test splits but not in positive_labels" - ) - - positive_label_edge_type = message_passing_to_positive_label(_USER_TO_STORY) - negative_label_edge_type = message_passing_to_negative_label(_USER_TO_STORY) - - # Convert positive_labels dict to COO edge index - pos_src, pos_dst = [], [] - for user_id, story_ids in positive_labels.items(): - for story_id in story_ids: - pos_src.append(user_id) - pos_dst.append(story_id) - positive_label_edge_index = torch.tensor([pos_src, pos_dst]) - - # Derive node counts from edge indices by collecting max node ID per node type - node_counts: dict[NodeType, int] = {} - for edge_type, edge_index in edge_indices.items(): - src_type, _, dst_type = edge_type - src_max = edge_index[0].max().item() + 1 - dst_max = edge_index[1].max().item() + 1 - node_counts[src_type] = int(max(node_counts.get(src_type, 0), src_max)) - node_counts[dst_type] = int(max(node_counts.get(dst_type, 0), dst_max)) - # Also account for nodes in positive labels - node_counts[_USER] = max( - node_counts.get(_USER, 0), max(positive_labels.keys()) + 1 - ) - node_counts[_STORY] = max( - node_counts.get(_STORY, 0), - max(max(stories) for stories in positive_labels.values()) + 1, - ) - - # Set up edge partition books and edge indices - # Partition books filled with zeros assign all edges to partition 0 (single machine) - edge_partition_book = { - edge_type: torch.zeros(edge_index.shape[1], dtype=torch.int64) - for edge_type, edge_index in edge_indices.items() - } - edge_partition_book[positive_label_edge_type] = torch.zeros( - len(pos_src), dtype=torch.int64 - ) - - partitioned_edge_index = { - edge_type: GraphPartitionData(edge_index=edge_index, edge_ids=None) - for edge_type, edge_index in edge_indices.items() - } - partitioned_edge_index[positive_label_edge_type] = GraphPartitionData( - edge_index=positive_label_edge_index, - edge_ids=None, - ) - - if negative_labels is not None: - # Convert negative_labels dict to COO edge index - neg_src, neg_dst = [], [] - for user_id, story_ids in 
negative_labels.items(): - for story_id in story_ids: - neg_src.append(user_id) - neg_dst.append(story_id) - negative_label_edge_index = torch.tensor([neg_src, neg_dst]) - edge_partition_book[negative_label_edge_type] = torch.zeros( - len(neg_src), dtype=torch.int64 - ) - partitioned_edge_index[negative_label_edge_type] = GraphPartitionData( - edge_index=negative_label_edge_index, - edge_ids=None, - ) - - # Partition books filled with zeros assign all nodes to partition 0 (single machine) - node_partition_book = { - node_type: torch.zeros(count, dtype=torch.int64) - for node_type, count in node_counts.items() - } - partitioned_node_features = { - node_type: FeaturePartitionData( - feats=torch.zeros(count, 2), ids=torch.arange(count) - ) - for node_type, count in node_counts.items() - } - - partition_output = PartitionOutput( - node_partition_book=node_partition_book, - edge_partition_book=edge_partition_book, - partitioned_edge_index=partitioned_edge_index, - partitioned_node_features=partitioned_node_features, - partitioned_edge_features=None, - partitioned_positive_labels=None, - partitioned_negative_labels=None, - partitioned_node_labels=None, - ) - - # Calculate split ratios based on provided user IDs. - # With identity hash (x + 1), nodes are split by their ID values: - # - Lower IDs -> train, middle IDs -> val, higher IDs -> test - total_users = len(positive_labels) - num_val = len(val_user_ids) / total_users - num_test = len(test_user_ids) / total_users - - # Identity-like hash function for deterministic splits based on node ID ordering. - # Adding 1 ensures hash(0) != 0 and creates proper normalization boundaries. - def _identity_hash(x: torch.Tensor) -> torch.Tensor: - return x.clone().to(torch.int64) + 1 - - # Create splitter that will produce splits based on node ID ordering - splitter = DistNodeAnchorLinkSplitter( - sampling_direction="out", - num_val=num_val, - num_test=num_test, - hash_function=_identity_hash, - supervision_edge_types=[_USER_TO_STORY], - should_convert_labels_to_edges=True, # Derives positive/negative edge types from supervision edge type - ) - - dataset = DistDataset(rank=0, world_size=1, edge_dir="out") - dataset.build(partition_output=partition_output, splitter=splitter) - return dataset - def test_register_dataset(self) -> None: """Test that register_dataset correctly sets the global dataset.""" - dataset = self._create_heterogeneous_dataset( - edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES, + dataset = create_heterogeneous_dataset( + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) storage_utils.register_dataset(dataset) @@ -311,8 +45,8 @@ def test_register_dataset(self) -> None: def test_reregister_dataset_raises_error(self) -> None: """Test that reregistering a dataset raises an error.""" - dataset = self._create_heterogeneous_dataset( - edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES, + dataset = create_heterogeneous_dataset( + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) storage_utils.register_dataset(dataset) with self.assertRaises(ValueError) as context: @@ -321,8 +55,8 @@ def test_reregister_dataset_raises_error(self) -> None: def test_get_node_feature_info_with_heterogeneous_dataset(self) -> None: """Test get_node_feature_info with a registered heterogeneous dataset.""" - dataset = self._create_heterogeneous_dataset( - edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES, + dataset = create_heterogeneous_dataset( + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) storage_utils.register_dataset(dataset) @@ -330,15 +64,15 @@ def 
test_get_node_feature_info_with_heterogeneous_dataset(self) -> None:
 
         # Verify it returns the correct feature info
         expected = {
-            _USER: FeatureInfo(dim=2, dtype=torch.float32),
-            _STORY: FeatureInfo(dim=2, dtype=torch.float32),
+            USER: FeatureInfo(dim=2, dtype=torch.float32),
+            STORY: FeatureInfo(dim=2, dtype=torch.float32),
         }
         self.assertEqual(node_feature_info, expected)
 
     def test_get_node_feature_info_with_homogeneous_dataset(self) -> None:
         """Test get_node_feature_info with a registered homogeneous dataset."""
-        dataset = self._create_homogeneous_dataset(
-            edge_index=_DEFAULT_HOMOGENEOUS_EDGE_INDEX,
+        dataset = create_homogeneous_dataset(
+            edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX,
         )
         storage_utils.register_dataset(dataset)
 
@@ -358,8 +92,8 @@ def test_get_node_feature_info_without_registered_dataset(self) -> None:
 
     def test_get_edge_feature_info_with_heterogeneous_dataset(self) -> None:
         """Test get_edge_feature_info with a registered heterogeneous dataset."""
-        dataset = self._create_heterogeneous_dataset(
-            edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES,
+        dataset = create_heterogeneous_dataset(
+            edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES,
         )
         storage_utils.register_dataset(dataset)
 
@@ -370,8 +104,8 @@ def test_get_edge_feature_info_with_heterogeneous_dataset(self) -> None:
 
     def test_get_edge_feature_info_with_homogeneous_dataset(self) -> None:
         """Test get_edge_feature_info with a registered homogeneous dataset."""
-        dataset = self._create_homogeneous_dataset(
-            edge_index=_DEFAULT_HOMOGENEOUS_EDGE_INDEX,
+        dataset = create_homogeneous_dataset(
+            edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX,
         )
         storage_utils.register_dataset(dataset)
 
@@ -388,39 +122,57 @@ def test_get_edge_feature_info_without_registered_dataset(self) -> None:
         self.assertIn("Dataset not registered", str(context.exception))
         self.assertIn("register_dataset", str(context.exception))
 
+    def test_get_node_ids(self) -> None:
+        """Test get_node_ids with a registered dataset."""
+        dataset = create_homogeneous_dataset(
+            edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX,
+        )
+        storage_utils.register_dataset(dataset)
+        node_ids = storage_utils.get_node_ids()
+        self.assertIsInstance(node_ids, torch.Tensor)
+        self.assertEqual(node_ids.shape[0], 10)
+        assert_tensor_equality(node_ids, torch.arange(10))
+
+    def test_get_node_ids_heterogeneous(self) -> None:
+        """Test get_node_ids with a registered heterogeneous dataset."""
+        dataset = create_heterogeneous_dataset(
+            edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES,
+        )
+        storage_utils.register_dataset(dataset)
+        node_ids = storage_utils.get_node_ids(node_type=USER)
+        self.assertIsInstance(node_ids, torch.Tensor)
+        self.assertEqual(node_ids.shape[0], 5)
+        assert_tensor_equality(node_ids, torch.arange(5))
+
     def test_get_node_ids_for_rank_with_homogeneous_dataset(self) -> None:
         """Test get_node_ids_for_rank with a homogeneous dataset."""
-        dataset = self._create_homogeneous_dataset(
-            edge_index=_DEFAULT_HOMOGENEOUS_EDGE_INDEX,
+        dataset = create_homogeneous_dataset(
+            edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX,
         )
         storage_utils.register_dataset(dataset)
 
         # Test with world_size=1, rank=0 (should get all nodes)
-        node_ids = storage_utils.get_node_ids_for_rank(
-            rank=0, world_size=1, node_type=None
-        )
+        node_ids = storage_utils.get_node_ids(rank=0, world_size=1, node_type=None)
         self.assertIsInstance(node_ids, torch.Tensor)
         self.assertEqual(node_ids.shape[0], 10)
         assert_tensor_equality(node_ids, torch.arange(10))
 
     def test_get_node_ids_for_rank_with_heterogeneous_dataset(self) -> None:
         """Test 
get_node_ids_for_rank with a heterogeneous dataset.""" - dataset = self._create_heterogeneous_dataset( - edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES, + dataset = create_heterogeneous_dataset( + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) storage_utils.register_dataset(dataset) # Test with USER node type - user_node_ids = storage_utils.get_node_ids_for_rank( - rank=0, world_size=1, node_type=_USER - ) + user_node_ids = storage_utils.get_node_ids(rank=0, world_size=1, node_type=USER) self.assertIsInstance(user_node_ids, torch.Tensor) self.assertEqual(user_node_ids.shape[0], 5) assert_tensor_equality(user_node_ids, torch.arange(5)) # Test with STORY node type - story_node_ids = storage_utils.get_node_ids_for_rank( - rank=0, world_size=1, node_type=_STORY + story_node_ids = storage_utils.get_node_ids( + rank=0, world_size=1, node_type=STORY ) self.assertIsInstance(story_node_ids, torch.Tensor) self.assertEqual(story_node_ids.shape[0], 5) @@ -428,33 +180,23 @@ def test_get_node_ids_for_rank_with_heterogeneous_dataset(self) -> None: def test_get_node_ids_for_rank_with_multiple_ranks(self) -> None: """Test get_node_ids_for_rank with multiple ranks to verify sharding.""" - dataset = self._create_homogeneous_dataset( - edge_index=_DEFAULT_HOMOGENEOUS_EDGE_INDEX, + dataset = create_homogeneous_dataset( + edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, ) storage_utils.register_dataset(dataset) # Test with world_size=2 - rank_0_nodes = storage_utils.get_node_ids_for_rank( - rank=0, world_size=2, node_type=None - ) - rank_1_nodes = storage_utils.get_node_ids_for_rank( - rank=1, world_size=2, node_type=None - ) + rank_0_nodes = storage_utils.get_node_ids(rank=0, world_size=2, node_type=None) + rank_1_nodes = storage_utils.get_node_ids(rank=1, world_size=2, node_type=None) # Verify each rank gets different nodes assert_tensor_equality(rank_0_nodes, torch.arange(5)) assert_tensor_equality(rank_1_nodes, torch.arange(5, 10)) # Test with world_size=3 (uneven split) - rank_0_nodes = storage_utils.get_node_ids_for_rank( - rank=0, world_size=3, node_type=None - ) - rank_1_nodes = storage_utils.get_node_ids_for_rank( - rank=1, world_size=3, node_type=None - ) - rank_2_nodes = storage_utils.get_node_ids_for_rank( - rank=2, world_size=3, node_type=None - ) + rank_0_nodes = storage_utils.get_node_ids(rank=0, world_size=3, node_type=None) + rank_1_nodes = storage_utils.get_node_ids(rank=1, world_size=3, node_type=None) + rank_2_nodes = storage_utils.get_node_ids(rank=2, world_size=3, node_type=None) assert_tensor_equality(rank_0_nodes, torch.arange(3)) assert_tensor_equality(rank_1_nodes, torch.arange(3, 6)) @@ -463,43 +205,111 @@ def test_get_node_ids_for_rank_with_multiple_ranks(self) -> None: def test_get_node_ids_for_rank_without_registered_dataset(self) -> None: """Test get_node_ids_for_rank raises ValueError when no dataset is registered.""" with self.assertRaises(ValueError) as context: - storage_utils.get_node_ids_for_rank(rank=0, world_size=1) + storage_utils.get_node_ids(rank=0, world_size=1) self.assertIn("Dataset not registered", str(context.exception)) self.assertIn("register_dataset", str(context.exception)) def test_get_node_ids_for_rank_with_homogeneous_dataset_and_node_type(self) -> None: """Test get_node_ids_for_rank with a homogeneous dataset and a node type.""" - dataset = self._create_homogeneous_dataset( - edge_index=_DEFAULT_HOMOGENEOUS_EDGE_INDEX, + dataset = create_homogeneous_dataset( + edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, ) storage_utils.register_dataset(dataset) - with 
self.assertRaises(ValueError) as context: - storage_utils.get_node_ids_for_rank(rank=0, world_size=1, node_type=_USER) - self.assertIn( - "node_type must be None for a homogeneous dataset. Got user.", - str(context.exception), - ) + with self.assertRaises(ValueError): + storage_utils.get_node_ids(rank=0, world_size=1, node_type=USER) def test_get_node_ids_for_rank_with_heterogeneous_dataset_and_no_node_type( self, ) -> None: """Test get_node_ids_for_rank with a heterogeneous dataset and no node type.""" - dataset = self._create_heterogeneous_dataset( - edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES, + dataset = create_heterogeneous_dataset( + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) storage_utils.register_dataset(dataset) - with self.assertRaises(ValueError) as context: - storage_utils.get_node_ids_for_rank(rank=0, world_size=1, node_type=None) - self.assertIn( - "node_type must be not None for a heterogeneous dataset. Got None.", - str(context.exception), + with self.assertRaises(ValueError): + storage_utils.get_node_ids(rank=0, world_size=1, node_type=None) + + def test_get_node_ids_with_train_split(self) -> None: + """Test get_node_ids returns only training nodes when split='train'.""" + create_test_process_group() + + positive_labels = {0: [0], 1: [1], 2: [2], 3: [3], 4: [4]} + dataset = create_heterogeneous_dataset_with_labels( + positive_labels=positive_labels, + train_node_ids=[0, 1, 2], + val_node_ids=[3], + test_node_ids=[4], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ) + storage_utils.register_dataset(dataset) + + train_nodes = storage_utils.get_node_ids(node_type=USER, split="train") + assert_tensor_equality(train_nodes, torch.tensor([0, 1, 2])) + + def test_get_node_ids_with_val_split(self) -> None: + """Test get_node_ids returns only validation nodes when split='val'.""" + create_test_process_group() + + positive_labels = {0: [0], 1: [1], 2: [2], 3: [3], 4: [4]} + dataset = create_heterogeneous_dataset_with_labels( + positive_labels=positive_labels, + train_node_ids=[0, 1, 2], + val_node_ids=[3], + test_node_ids=[4], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) + storage_utils.register_dataset(dataset) + + val_nodes = storage_utils.get_node_ids(node_type=USER, split="val") + assert_tensor_equality(val_nodes, torch.tensor([3])) + + def test_get_node_ids_with_test_split(self) -> None: + """Test get_node_ids returns only test nodes when split='test'.""" + create_test_process_group() + + positive_labels = {0: [0], 1: [1], 2: [2], 3: [3], 4: [4]} + dataset = create_heterogeneous_dataset_with_labels( + positive_labels=positive_labels, + train_node_ids=[0, 1, 2], + val_node_ids=[3], + test_node_ids=[4], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ) + storage_utils.register_dataset(dataset) + + test_nodes = storage_utils.get_node_ids(node_type=USER, split="test") + assert_tensor_equality(test_nodes, torch.tensor([4])) + + def test_get_node_ids_with_split_and_sharding(self) -> None: + """Test get_node_ids with split and rank/world_size for sharding.""" + create_test_process_group() + + positive_labels = {0: [0], 1: [1], 2: [2], 3: [3], 4: [4]} + dataset = create_heterogeneous_dataset_with_labels( + positive_labels=positive_labels, + train_node_ids=[0, 1, 2], + val_node_ids=[3], + test_node_ids=[4], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ) + storage_utils.register_dataset(dataset) + + # Train split has [0, 1, 2], shard across 2 ranks + rank_0_nodes = storage_utils.get_node_ids( + rank=0, world_size=2, node_type=USER, split="train" + ) 
+ rank_1_nodes = storage_utils.get_node_ids( + rank=1, world_size=2, node_type=USER, split="train" + ) + + assert_tensor_equality(rank_0_nodes, torch.tensor([0])) + assert_tensor_equality(rank_1_nodes, torch.tensor([1, 2])) def test_get_edge_dir(self) -> None: """Test get_edge_dir with a registered dataset.""" - dataset = self._create_homogeneous_dataset( - edge_index=_DEFAULT_HOMOGENEOUS_EDGE_INDEX, + dataset = create_homogeneous_dataset( + edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, ) storage_utils.register_dataset(dataset) edge_dir = storage_utils.get_edge_dir() @@ -507,8 +317,8 @@ def test_get_edge_dir(self) -> None: def test_get_node_feature_info(self) -> None: """Test get_node_feature_info with a registered dataset.""" - dataset = self._create_homogeneous_dataset( - edge_index=_DEFAULT_HOMOGENEOUS_EDGE_INDEX, + dataset = create_homogeneous_dataset( + edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, ) storage_utils.register_dataset(dataset) node_feature_info = storage_utils.get_node_feature_info() @@ -516,8 +326,8 @@ def test_get_node_feature_info(self) -> None: def test_get_edge_feature_info(self) -> None: """Test get_edge_feature_info with a registered dataset.""" - dataset = self._create_homogeneous_dataset( - edge_index=_DEFAULT_HOMOGENEOUS_EDGE_INDEX, + dataset = create_homogeneous_dataset( + edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, ) storage_utils.register_dataset(dataset) edge_feature_info = storage_utils.get_edge_feature_info() @@ -525,8 +335,8 @@ def test_get_edge_feature_info(self) -> None: def test_get_edge_types_homogeneous(self) -> None: """Test get_edge_types with a homogeneous dataset.""" - dataset = self._create_homogeneous_dataset( - edge_index=_DEFAULT_HOMOGENEOUS_EDGE_INDEX, + dataset = create_homogeneous_dataset( + edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX, ) storage_utils.register_dataset(dataset) edge_types = storage_utils.get_edge_types() @@ -534,14 +344,14 @@ def test_get_edge_types_homogeneous(self) -> None: def test_get_edge_types_heterogeneous(self) -> None: """Test get_edge_types with a heterogeneous dataset.""" - dataset = self._create_heterogeneous_dataset( - edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES, + dataset = create_heterogeneous_dataset( + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) storage_utils.register_dataset(dataset) edge_types = storage_utils.get_edge_types() self.assertEqual( edge_types, - [(_USER, Relation("to"), _STORY), (_STORY, Relation("to"), _USER)], + [(USER, Relation("to"), STORY), (STORY, Relation("to"), USER)], ) def test_get_ablp_input(self) -> None: @@ -570,13 +380,13 @@ def test_get_ablp_input(self) -> None: "test": [4], } - dataset = self._create_heterogeneous_dataset_with_labels( + dataset = create_heterogeneous_dataset_with_labels( positive_labels=positive_labels, negative_labels=negative_labels, - train_user_ids=split_to_user_ids["train"], - val_user_ids=split_to_user_ids["val"], - test_user_ids=split_to_user_ids["test"], - edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES, + train_node_ids=split_to_user_ids["train"], + val_node_ids=split_to_user_ids["val"], + test_node_ids=split_to_user_ids["test"], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) storage_utils.register_dataset(dataset) @@ -586,8 +396,8 @@ def test_get_ablp_input(self) -> None: split=split, rank=0, world_size=1, - node_type=_USER, - supervision_edge_type=_USER_TO_STORY, + node_type=USER, + supervision_edge_type=USER_TO_STORY, ) # Verify anchor nodes match expected users @@ -623,13 +433,13 @@ def test_get_ablp_input_multiple_ranks(self) -> None: } 
train_user_ids = [0, 1, 2, 3] - dataset = self._create_heterogeneous_dataset_with_labels( + dataset = create_heterogeneous_dataset_with_labels( positive_labels=positive_labels, negative_labels=negative_labels, - train_user_ids=train_user_ids, - val_user_ids=[4], - test_user_ids=[], - edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES, + train_node_ids=train_user_ids, + val_node_ids=[4], + test_node_ids=[], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) storage_utils.register_dataset(dataset) @@ -641,8 +451,8 @@ def test_get_ablp_input_multiple_ranks(self) -> None: split="train", rank=0, world_size=2, - node_type=_USER, - supervision_edge_type=_USER_TO_STORY, + node_type=USER, + supervision_edge_type=USER_TO_STORY, ) # Get training input for rank 1 of 2 @@ -650,8 +460,8 @@ def test_get_ablp_input_multiple_ranks(self) -> None: split="train", rank=1, world_size=2, - node_type=_USER, - supervision_edge_type=_USER_TO_STORY, + node_type=USER, + supervision_edge_type=USER_TO_STORY, ) # Train nodes [0, 1, 2, 3] should be split across ranks @@ -681,8 +491,8 @@ def test_get_training_input_without_registered_dataset(self) -> None: split="train", rank=0, world_size=1, - node_type=_USER, - supervision_edge_type=_USER_TO_STORY, + node_type=USER, + supervision_edge_type=USER_TO_STORY, ) def test_get_ablp_input_invalid_split(self) -> None: @@ -691,13 +501,13 @@ def test_get_ablp_input_invalid_split(self) -> None: positive_labels = {0: [0], 1: [1], 2: [2], 3: [3], 4: [4]} negative_labels = {0: [1], 1: [2], 2: [3], 3: [4], 4: [0]} - dataset = self._create_heterogeneous_dataset_with_labels( + dataset = create_heterogeneous_dataset_with_labels( positive_labels=positive_labels, negative_labels=negative_labels, - train_user_ids=[0, 1, 2], - val_user_ids=[3], - test_user_ids=[4], - edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES, + train_node_ids=[0, 1, 2], + val_node_ids=[3], + test_node_ids=[4], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) storage_utils.register_dataset(dataset) @@ -706,8 +516,8 @@ def test_get_ablp_input_invalid_split(self) -> None: split="invalid", rank=0, world_size=1, - node_type=_USER, - supervision_edge_type=_USER_TO_STORY, + node_type=USER, + supervision_edge_type=USER_TO_STORY, ) def test_get_training_input_without_negative_labels(self) -> None: @@ -723,13 +533,13 @@ def test_get_training_input_without_negative_labels(self) -> None: } train_user_ids = [0, 1, 2] - dataset = self._create_heterogeneous_dataset_with_labels( + dataset = create_heterogeneous_dataset_with_labels( positive_labels=positive_labels, negative_labels=None, # No negative labels - train_user_ids=train_user_ids, - val_user_ids=[3], - test_user_ids=[4], - edge_indices=_DEFAULT_HETEROGENEOUS_EDGE_INDICES, + train_node_ids=train_user_ids, + val_node_ids=[3], + test_node_ids=[4], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, ) storage_utils.register_dataset(dataset) @@ -737,8 +547,8 @@ def test_get_training_input_without_negative_labels(self) -> None: split="train", rank=0, world_size=1, - node_type=_USER, - supervision_edge_type=_USER_TO_STORY, + node_type=USER, + supervision_edge_type=USER_TO_STORY, ) # Verify train split returns the expected users diff --git a/tests/unit/test_assets/__init__.py b/tests/unit/test_assets/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/test_assets/test_dataset_test.py b/tests/unit/test_assets/test_dataset_test.py new file mode 100644 index 000000000..12028e6a9 --- /dev/null +++ b/tests/unit/test_assets/test_dataset_test.py @@ -0,0 
+1,269 @@ +"""Unit tests for test_dataset factory functions.""" + +import unittest + +import torch + +from gigl.src.common.types.graph_data import EdgeType, NodeType, Relation +from gigl.types.graph import FeatureInfo +from tests.test_assets.distributed.test_dataset import ( + DEFAULT_HETEROGENEOUS_EDGE_INDICES, + DEFAULT_HETEROGENEOUS_NODE_FEATURE_DIM, + DEFAULT_HOMOGENEOUS_EDGE_INDEX, + DEFAULT_HOMOGENEOUS_NODE_FEATURE_DIM, + STORY, + STORY_TO_USER, + USER, + USER_TO_STORY, + create_heterogeneous_dataset, + create_heterogeneous_dataset_with_labels, + create_homogeneous_dataset, +) +from tests.test_assets.distributed.utils import ( + assert_tensor_equality, + create_test_process_group, +) + + +class TestCreateHomogeneousDataset(unittest.TestCase): + """Tests for create_homogeneous_dataset function.""" + + def test_with_default_edge_index(self) -> None: + """Test creating a dataset with default edge index constant.""" + dataset = create_homogeneous_dataset(edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX) + + # Verify node count (10 nodes in default ring graph) + node_ids = dataset.node_ids + assert isinstance(node_ids, torch.Tensor) + self.assertEqual(node_ids.shape[0], 10) + + # Verify feature info uses default dimension + expected_feature_info = FeatureInfo( + dim=DEFAULT_HOMOGENEOUS_NODE_FEATURE_DIM, dtype=torch.float32 + ) + self.assertEqual(dataset.node_feature_info, expected_feature_info) + + # Verify default edge direction + self.assertEqual(dataset.edge_dir, "out") + + def test_with_custom_parameters(self) -> None: + """Test creating a dataset with custom edge index, features, and edge direction.""" + # Custom 3-node graph + custom_edge_index = torch.tensor([[0, 1, 2], [1, 2, 0]]) + custom_features = torch.ones(3, 5) + + dataset = create_homogeneous_dataset( + edge_index=custom_edge_index, + node_features=custom_features, + edge_dir="in", + ) + + # Verify node count from custom edge index + node_ids = dataset.node_ids + assert isinstance(node_ids, torch.Tensor) + self.assertEqual(node_ids.shape[0], 3) + assert_tensor_equality(node_ids, torch.arange(3)) + + # Verify feature dimension from custom features + self.assertEqual( + dataset.node_feature_info, FeatureInfo(dim=5, dtype=torch.float32) + ) + + # Verify custom edge direction + self.assertEqual(dataset.edge_dir, "in") + + def test_default_edge_index_not_modified(self) -> None: + """Test that creating a dataset doesn't modify the default edge index.""" + original = DEFAULT_HOMOGENEOUS_EDGE_INDEX.clone() + _ = create_homogeneous_dataset(edge_index=DEFAULT_HOMOGENEOUS_EDGE_INDEX) + + assert_tensor_equality(DEFAULT_HOMOGENEOUS_EDGE_INDEX, original) + + +class TestCreateHeterogeneousDataset(unittest.TestCase): + """Tests for create_heterogeneous_dataset function.""" + + def test_with_default_edge_indices(self) -> None: + """Test creating a dataset with default edge indices constant.""" + dataset = create_heterogeneous_dataset( + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES + ) + + # Verify node counts (5 users, 5 stories in default graph) + node_ids = dataset.node_ids + assert isinstance(node_ids, dict) + self.assertEqual(node_ids[USER].shape[0], 5) + self.assertEqual(node_ids[STORY].shape[0], 5) + + # Verify feature info uses default dimension + expected_feature_info = { + USER: FeatureInfo( + dim=DEFAULT_HETEROGENEOUS_NODE_FEATURE_DIM, dtype=torch.float32 + ), + STORY: FeatureInfo( + dim=DEFAULT_HETEROGENEOUS_NODE_FEATURE_DIM, dtype=torch.float32 + ), + } + self.assertEqual(dataset.node_feature_info, expected_feature_info) + + # 
Verify default edge direction + self.assertEqual(dataset.edge_dir, "out") + + def test_with_custom_parameters(self) -> None: + """Test creating a dataset with custom edge indices, features, labels, and edge direction.""" + # Custom 3-node graph per type + custom_edges = { + USER_TO_STORY: torch.tensor([[0, 1, 2], [0, 1, 2]]), + STORY_TO_USER: torch.tensor([[0, 1, 2], [0, 1, 2]]), + } + custom_features = { + USER: torch.ones(3, 4), + STORY: torch.ones(3, 4), + } + custom_labels = { + USER: torch.tensor([[10], [20], [30]]), + STORY: torch.tensor([[100], [200], [300]]), + } + + dataset = create_heterogeneous_dataset( + edge_indices=custom_edges, + node_features=custom_features, + node_labels=custom_labels, + edge_dir="in", + ) + + # Verify node counts from custom edge indices + node_ids = dataset.node_ids + assert isinstance(node_ids, dict) + self.assertEqual(node_ids[USER].shape[0], 3) + self.assertEqual(node_ids[STORY].shape[0], 3) + + # Verify feature dimension from custom features + expected_feature_info = { + USER: FeatureInfo(dim=4, dtype=torch.float32), + STORY: FeatureInfo(dim=4, dtype=torch.float32), + } + self.assertEqual(dataset.node_feature_info, expected_feature_info) + + # Verify node labels are set + node_labels = dataset.node_labels + assert isinstance(node_labels, dict) + self.assertIn(USER, node_labels) + self.assertIn(STORY, node_labels) + + # Verify custom edge direction + self.assertEqual(dataset.edge_dir, "in") + + def test_default_edge_indices_not_modified(self) -> None: + """Test that creating a dataset doesn't modify the default edge indices.""" + original = { + edge_type: edge_index.clone() + for edge_type, edge_index in DEFAULT_HETEROGENEOUS_EDGE_INDICES.items() + } + _ = create_heterogeneous_dataset( + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES + ) + + for edge_type, edge_index in DEFAULT_HETEROGENEOUS_EDGE_INDICES.items(): + assert_tensor_equality(edge_index, original[edge_type]) + + +class TestCreateHeterogeneousDatasetWithLabels(unittest.TestCase): + """Tests for create_heterogeneous_dataset_with_labels function.""" + + def tearDown(self) -> None: + """Clean up after each test.""" + if torch.distributed.is_initialized(): + torch.distributed.destroy_process_group() + + def test_basic_dataset_with_splits(self) -> None: + """Test creating a dataset with train/val/test splits and labels.""" + create_test_process_group() + + positive_labels = {0: [0, 1], 1: [1, 2], 2: [2, 3], 3: [3, 4], 4: [4, 0]} + negative_labels = {0: [2], 1: [3], 2: [4], 3: [0], 4: [1]} + + dataset = create_heterogeneous_dataset_with_labels( + positive_labels=positive_labels, + train_node_ids=[0, 1, 2], + val_node_ids=[3], + test_node_ids=[4], + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + negative_labels=negative_labels, + ) + + # Verify train/val/test node IDs are set + train_node_ids = dataset.train_node_ids + val_node_ids = dataset.val_node_ids + test_node_ids = dataset.test_node_ids + assert isinstance(train_node_ids, dict) + assert isinstance(val_node_ids, dict) + assert isinstance(test_node_ids, dict) + + # Verify split sizes + self.assertEqual(train_node_ids[USER].shape[0], 3) + self.assertEqual(val_node_ids[USER].shape[0], 1) + self.assertEqual(test_node_ids[USER].shape[0], 1) + + def test_missing_positive_labels_raises_error(self) -> None: + """Test that missing positive labels for split nodes raises an error.""" + positive_labels = {0: [0], 1: [1], 2: [2], 3: [3], 4: [4]} + + with self.assertRaises(ValueError) as context: + create_heterogeneous_dataset_with_labels( + 
positive_labels=positive_labels, + train_node_ids=[0, 1, 2], + val_node_ids=[3], + test_node_ids=[5], # Node 5 not in positive_labels + edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES, + ) + + self.assertIn("5", str(context.exception)) + self.assertIn("positive_labels", str(context.exception)) + + def test_with_custom_node_types_and_feature_dim(self) -> None: + """Test creating a dataset with custom node types and feature dimension.""" + create_test_process_group() + + positive_labels = {0: [0, 1], 1: [1, 2], 2: [2, 3], 3: [3, 4], 4: [4, 0]} + + custom_src_type = NodeType("author") + custom_dst_type = NodeType("article") + custom_edge_type = EdgeType(custom_src_type, Relation("wrote"), custom_dst_type) + reverse_edge_type = EdgeType( + custom_dst_type, Relation("written_by"), custom_src_type + ) + + custom_edges = { + custom_edge_type: torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), + reverse_edge_type: torch.tensor([[0, 1, 2, 3, 4], [0, 1, 2, 3, 4]]), + } + + dataset = create_heterogeneous_dataset_with_labels( + positive_labels=positive_labels, + train_node_ids=[0, 1, 2], + val_node_ids=[3], + test_node_ids=[4], + edge_indices=custom_edges, + src_node_type=custom_src_type, + dst_node_type=custom_dst_type, + supervision_edge_type=custom_edge_type, + node_feature_dim=8, + ) + + # Verify custom node types + node_ids = dataset.node_ids + assert isinstance(node_ids, dict) + self.assertIn(custom_src_type, node_ids) + self.assertIn(custom_dst_type, node_ids) + + # Verify custom feature dimension + expected_feature_info = { + custom_src_type: FeatureInfo(dim=8, dtype=torch.float32), + custom_dst_type: FeatureInfo(dim=8, dtype=torch.float32), + } + self.assertEqual(dataset.node_feature_info, expected_feature_info) + + +if __name__ == "__main__": + unittest.main()
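
Usage note (illustrative, not part of the diff): the tests above pin down how the shared test-dataset factories compose with the unified storage_utils.get_node_ids API. The sketch below restates that flow end to end under the signatures exercised in storage_utils_test.py; the helper names, keyword arguments, and sharding behavior (remainder nodes land on the last rank) are taken from the assertions above, not from any additional source.

# Illustrative sketch only, based on the tests in this diff; not authoritative.
import torch

from gigl.distributed.graph_store import storage_utils
from tests.test_assets.distributed.test_dataset import (
    DEFAULT_HETEROGENEOUS_EDGE_INDICES,
    USER,
    create_heterogeneous_dataset_with_labels,
)
from tests.test_assets.distributed.utils import create_test_process_group

# The splitter used by the factory requires an initialized process group.
create_test_process_group()

# Five users, each with one positive story label; splits are chosen by user ID.
dataset = create_heterogeneous_dataset_with_labels(
    positive_labels={0: [0], 1: [1], 2: [2], 3: [3], 4: [4]},
    train_node_ids=[0, 1, 2],
    val_node_ids=[3],
    test_node_ids=[4],
    edge_indices=DEFAULT_HETEROGENEOUS_EDGE_INDICES,
)
storage_utils.register_dataset(dataset)

# Unsharded: the whole train split.
train_ids = storage_utils.get_node_ids(node_type=USER, split="train")
assert torch.equal(train_ids, torch.tensor([0, 1, 2]))

# Sharded: three train nodes over two ranks. Rank 0 takes floor(3 / 2) = 1
# node and the remainder lands on the last rank, matching
# test_get_node_ids_with_split_and_sharding above.
rank_0 = storage_utils.get_node_ids(rank=0, world_size=2, node_type=USER, split="train")
rank_1 = storage_utils.get_node_ids(rank=1, world_size=2, node_type=USER, split="train")
assert torch.equal(rank_0, torch.tensor([0]))
assert torch.equal(rank_1, torch.tensor([1, 2]))

torch.distributed.destroy_process_group()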