206 changes: 161 additions & 45 deletions gigl/distributed/dist_ablp_neighborloader.py
@@ -29,6 +29,8 @@
metadata_key_with_prefix,
)
from gigl.distributed.utils.neighborloader import (
DatasetSchema,
SamplingClusterSetup,
labeled_to_homogeneous,
patch_fanout_for_sampling,
set_missing_features,
@@ -201,6 +203,8 @@ def __init__(
self._supervision_edge_types = [supervision_edge_type]
del supervision_edge_type

self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED

if context:
assert (
local_process_world_size is not None
@@ -262,17 +266,135 @@ def __init__(
local_process_world_size,
) # delete deprecated vars so we don't accidentally use them.

self.to_device = (
pin_memory_device
if pin_memory_device
else gigl.distributed.utils.get_available_device(
local_process_rank=local_rank
)
)

(
sampler_input,
worker_options,
dataset_metadata,
) = self._setup_for_colocated(
input_nodes=input_nodes,
dataset=dataset,
local_rank=local_rank,
local_world_size=local_world_size,
device=self.to_device,
master_ip_address=master_ip_address,
node_rank=node_rank,
node_world_size=node_world_size,
num_workers=num_workers,
worker_concurrency=worker_concurrency,
channel_size=channel_size,
num_cpu_threads=num_cpu_threads,
)

self._is_input_labeled_homogeneous = (
dataset_metadata.is_homogeneous_with_labeled_edge_type
)
self._node_feature_info = dataset_metadata.node_feature_info
self._edge_feature_info = dataset_metadata.edge_feature_info

num_neighbors = patch_fanout_for_sampling(
dataset_metadata.edge_types, num_neighbors
)

if should_cleanup_distributed_context and torch.distributed.is_initialized():
logger.info(
f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__."
)
torch.distributed.destroy_process_group()

sampling_config = SamplingConfig(
sampling_type=SamplingType.NODE,
num_neighbors=num_neighbors,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
with_edge=True,
collect_features=True,
with_neg=False,
with_weight=False,
edge_dir=dataset_metadata.edge_dir,
seed=None, # it's actually optional - None means random.
)

if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
self._start_colocated_producers(
dataset=dataset,
rank=rank,
local_rank=local_rank,
process_start_gap_seconds=process_start_gap_seconds,
sampler_input=sampler_input,
sampling_config=sampling_config,
worker_options=worker_options,
)

def _setup_for_colocated(
self,
input_nodes: Optional[Union[torch.Tensor, tuple[NodeType, torch.Tensor]]],
dataset: DistDataset,
local_rank: int,
local_world_size: int,
device: torch.device,
master_ip_address: str,
node_rank: int,
node_world_size: int,
num_workers: int,
worker_concurrency: int,
channel_size: str,
num_cpu_threads: Optional[int],
) -> tuple[list[ABLPNodeSamplerInput], MpDistSamplingWorkerOptions, DatasetSchema]:
"""
Setup method for colocated (non-Graph Store) mode.

Args:
input_nodes: Input nodes for sampling (tensor or tuple of node type and tensor).
dataset: The DistDataset to sample from.
local_rank: Local rank of the current process.
local_world_size: Total number of processes on this machine.
device: Target device for sampled data.
master_ip_address: IP address of the master node.
node_rank: Rank of the current machine.
node_world_size: Total number of machines.
num_workers: Number of sampling workers.
worker_concurrency: Max sampling concurrency per worker.
channel_size: Size of shared memory channel.
num_cpu_threads: Number of CPU threads for PyTorch.

Returns:
Tuple of (list[ABLPNodeSamplerInput], MpDistSamplingWorkerOptions, DatasetSchema).
"""
# Validate input format - should not be Graph Store format
if isinstance(input_nodes, abc.Mapping):
raise ValueError(
f"When using Colocated mode, input_nodes must be of type "
f"(torch.Tensor | tuple[NodeType, torch.Tensor] | None), "
f"received {type(input_nodes)}"
)
elif isinstance(input_nodes, tuple) and isinstance(input_nodes[1], abc.Mapping):
raise ValueError(
f"When using Colocated mode, input_nodes must be of type "
f"(torch.Tensor | tuple[NodeType, torch.Tensor] | None), "
f"received tuple with second element of type {type(input_nodes[1])}"
)

if not isinstance(dataset.graph, abc.Mapping):
raise ValueError(
f"The dataset must be heterogeneous for ABLP. Recieved dataset with graph of type: {type(dataset.graph)}"
f"The dataset must be heterogeneous for ABLP. Received dataset with graph of type: {type(dataset.graph)}"
)
self._is_input_heterogeneous: bool = False

is_homogeneous_with_labeled_edge_type: bool = False
if isinstance(input_nodes, tuple):
if self._supervision_edge_types == [DEFAULT_HOMOGENEOUS_EDGE_TYPE]:
raise ValueError(
"When using heterogeneous ABLP, you must provide supervision_edge_types."
)
self._is_input_heterogeneous = True
is_homogeneous_with_labeled_edge_type = True
anchor_node_type, anchor_node_ids = input_nodes
# TODO (mkolodner-sc): We currently assume supervision edges are directed outward, revisit in future if
# this assumption is no longer valid and/or is too opinionated
@@ -306,9 +428,10 @@ def __init__(
raise ValueError(
f"Expected supervision edge type to be None for homogeneous input nodes, got {self._supervision_edge_types}"
)

anchor_node_ids = dataset.node_ids
anchor_node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
else:
raise ValueError(f"Unexpected input_nodes type: {type(input_nodes)}")

missing_edge_types = set(self._supervision_edge_types) - set(
dataset.graph.keys()
@@ -318,6 +441,9 @@ def __init__(
f"Missing edge types in dataset: {missing_edge_types}. Edge types in dataset: {dataset.graph.keys()}"
)

# Type narrowing - anchor_node_ids is always a Tensor in colocated mode
assert isinstance(anchor_node_ids, torch.Tensor)

if len(anchor_node_ids.shape) != 1:
raise ValueError(
f"input_nodes must be a 1D tensor, got {anchor_node_ids.shape}."
@@ -363,31 +489,17 @@ def __init__(
negative_label_by_edge_types=negative_labels_by_label_edge_type,
)

self.to_device = (
pin_memory_device
if pin_memory_device
else gigl.distributed.utils.get_available_device(
local_process_rank=local_rank
)
)

num_neighbors = patch_fanout_for_sampling(
dataset.get_edge_types(), num_neighbors
)

self._node_feature_info = dataset.node_feature_info
self._edge_feature_info = dataset.edge_feature_info

# Sets up processes and torch device for initializing the GLT DistNeighborLoader, setting up RPC and worker groups to minimize
# the memory overhead and CPU contention.
# Sets up processes and torch device for initializing the GLT DistNeighborLoader,
# setting up RPC and worker groups to minimize the memory overhead and CPU contention.
neighbor_loader_ports = gigl.distributed.utils.get_free_ports_from_master_node(
num_ports=local_world_size
)
neighbor_loader_port_for_current_rank = neighbor_loader_ports[local_rank]
logger.info(
f"Initializing neighbor loader worker in process: {local_rank}/{local_world_size} using device: {self.to_device} on port {neighbor_loader_port_for_current_rank}."
f"Initializing neighbor loader worker in process: {local_rank}/{local_world_size} "
f"using device: {device} on port {neighbor_loader_port_for_current_rank}."
)
should_use_cpu_workers = self.to_device.type == "cpu"
should_use_cpu_workers = device.type == "cpu"
if should_use_cpu_workers and num_cpu_threads is None:
logger.info(
"Using CPU workers, but found num_cpu_threads to be None. "
@@ -402,7 +514,7 @@ def __init__(
rank=node_rank,
world_size=node_world_size,
master_worker_port=neighbor_loader_port_for_current_rank,
device=self.to_device,
device=device,
should_use_cpu_workers=should_use_cpu_workers,
# Lever to explore tuning for CPU based inference
num_cpu_threads=num_cpu_threads,
@@ -428,39 +540,43 @@ def __init__(
# use different master ports.
master_addr=master_ip_address,
master_port=dist_sampling_port_for_current_rank,
# Load testing show that when num_rpc_threads exceed 16, the performance
# Load testing shows that when num_rpc_threads exceed 16, the performance
# will degrade.
num_rpc_threads=min(dataset.num_partitions, 16),
rpc_timeout=600,
channel_size=channel_size,
pin_memory=self.to_device.type == "cuda",
pin_memory=device.type == "cuda",
)

if should_cleanup_distributed_context and torch.distributed.is_initialized():
logger.info(
f"Cleaning up process group as it was initialized inside {self.__class__.__name__}.__init__."
)
torch.distributed.destroy_process_group()
edge_types = list(dataset.graph.keys())

sampling_config = SamplingConfig(
sampling_type=SamplingType.NODE,
num_neighbors=num_neighbors,
batch_size=batch_size,
shuffle=shuffle,
drop_last=drop_last,
with_edge=True,
collect_features=True,
with_neg=False,
with_weight=False,
edge_dir=dataset.edge_dir,
seed=None, # it's actually optional - None means random.
return (
[sampler_input],
worker_options,
DatasetSchema(
is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type,
edge_types=edge_types,
node_feature_info=dataset.node_feature_info,
edge_feature_info=dataset.edge_feature_info,
edge_dir=dataset.edge_dir,
),
)

def _start_colocated_producers(
self,
dataset: DistDataset,
rank: int,
local_rank: int,
process_start_gap_seconds: float,
sampler_input: list[ABLPNodeSamplerInput],
sampling_config: SamplingConfig,
worker_options: MpDistSamplingWorkerOptions,
) -> None:
# Code below this point is taken from the GLT DistNeighborLoader.__init__() function (graphlearn_torch/python/distributed/dist_neighbor_loader.py).
# We do this so that we may override the DistSamplingProducer that is used with the GiGL implementation.

self.data = dataset
self.input_data = sampler_input
self.input_data = sampler_input[0]
self.sampling_type = sampling_config.sampling_type
self.num_neighbors = sampling_config.num_neighbors
self.batch_size = sampling_config.batch_size
@@ -701,7 +817,7 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
)
if isinstance(data, HeteroData):
data = strip_label_edges(data)
if not self._is_input_heterogeneous:
if not self._is_input_labeled_homogeneous:
if len(self._supervision_edge_types) != 1:
raise ValueError(
f"Expected 1 supervision edge type, got {len(self._supervision_edge_types)}"
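Reviewer note: the refactor above splits the old monolithic __init__ into a pure setup step (_setup_for_colocated, which validates input_nodes and returns the sampler inputs, worker options, and a DatasetSchema) and a producer-start step (_start_colocated_producers), gated on SamplingClusterSetup.COLOCATED. A minimal, standalone sketch of that control flow follows — the class and member names come from this diff, but the GRAPH_STORE enum member and every body below are simplified stand-ins, not the library's implementation:

from dataclasses import dataclass
from enum import Enum, auto
from typing import Optional


class SamplingClusterSetup(Enum):
    # COLOCATED appears in this diff; GRAPH_STORE is an assumed sibling member.
    COLOCATED = auto()
    GRAPH_STORE = auto()


@dataclass
class DatasetSchema:
    # Trimmed stand-in for the real dataclass in gigl/distributed/utils/neighborloader.py.
    is_homogeneous_with_labeled_edge_type: bool
    edge_types: Optional[list[tuple[str, str, str]]]
    edge_dir: str


class ToyLoader:
    def __init__(self) -> None:
        self._sampling_cluster_setup = SamplingClusterSetup.COLOCATED
        # 1) Pure setup: validate inputs, build sampler inputs and worker
        #    options, and collect shared dataset metadata.
        sampler_input, worker_options, metadata = self._setup_for_colocated()
        # 2) Derive sampling parameters from the returned metadata (the real
        #    __init__ patches fanouts and builds a SamplingConfig here).
        self.edge_dir = metadata.edge_dir
        # 3) Start producers only for the colocated cluster setup.
        if self._sampling_cluster_setup == SamplingClusterSetup.COLOCATED:
            self._start_colocated_producers(sampler_input, worker_options)

    def _setup_for_colocated(self) -> tuple[list[int], dict, DatasetSchema]:
        metadata = DatasetSchema(
            is_homogeneous_with_labeled_edge_type=False,
            edge_types=[("user", "to", "user")],
            edge_dir="out",
        )
        return [0], {"num_workers": 2}, metadata

    def _start_colocated_producers(self, sampler_input: list, worker_options: dict) -> None:
        print(f"starting producers: input={sampler_input}, options={worker_options}")


ToyLoader()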
19 changes: 11 additions & 8 deletions gigl/distributed/distributed_neighborloader.py
@@ -241,7 +241,9 @@ def __init__(
num_workers,
)

self._is_labeled_heterogeneous = dataset_metadata.is_labeled_heterogeneous
self._is_homogeneous_with_labeled_edge_type = (
dataset_metadata.is_homogeneous_with_labeled_edge_type
)
self._node_feature_info = dataset_metadata.node_feature_info
self._edge_feature_info = dataset_metadata.edge_feature_info

@@ -380,7 +382,7 @@ def _setup_for_graph_store(
f"When using Graph Store mode, input nodes must be of type (dict[int, torch.Tensor] | (NodeType, dict[int, torch.Tensor])), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})"
)

is_labeled_heterogeneous = False
is_homogeneous_with_labeled_edge_type = False
node_feature_info = dataset.get_node_feature_info()
edge_feature_info = dataset.get_edge_feature_info()
edge_types = dataset.get_edge_types()
@@ -417,13 +419,14 @@ def _setup_for_graph_store(
require_edge_feature_info = True
else:
raise ValueError(
f"When using Graph Store mode, input nodes must be of type (list[torch.Tensor] | (NodeType, list[torch.Tensor])), received {type(input_nodes)}"
f"When using Graph Store mode, input nodes must be of type (dict[int, torch.Tensor] | (NodeType, dict[int, torch.Tensor])), received {type(input_nodes)}"
)

# Determine input_type based on edge_feature_info
if isinstance(edge_types, list):
if edge_types == [DEFAULT_HOMOGENEOUS_EDGE_TYPE]:
input_type: Optional[NodeType] = DEFAULT_HOMOGENEOUS_NODE_TYPE
is_homogeneous_with_labeled_edge_type = True
else:
input_type = fallback_input_type
elif require_edge_feature_info:
@@ -455,7 +458,7 @@ def _setup_for_graph_store(
input_data,
worker_options,
DatasetSchema(
is_labeled_heterogeneous=is_labeled_heterogeneous,
is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type,
edge_types=edge_types,
node_feature_info=node_feature_info,
edge_feature_info=edge_feature_info,
@@ -503,7 +506,7 @@ def _setup_for_colocated(
raise ValueError(
f"When using Colocated mode, input nodes must be of type (torch.Tensor | (NodeType, torch.Tensor)), received {type(input_nodes)} ({type(input_nodes[0])}, {type(input_nodes[1])})"
)
is_labeled_heterogeneous = False
is_homogeneous_with_labeled_edge_type = False
if isinstance(input_nodes, torch.Tensor):
node_ids = input_nodes

@@ -515,7 +518,7 @@ def _setup_for_colocated(
and DEFAULT_HOMOGENEOUS_NODE_TYPE in dataset.node_ids
):
node_type = DEFAULT_HOMOGENEOUS_NODE_TYPE
is_labeled_heterogeneous = True
is_homogeneous_with_labeled_edge_type = True
else:
raise ValueError(
f"For heterogeneous datasets, input_nodes must be a tuple of (node_type, node_ids) OR if it is a labeled homogeneous dataset, input_nodes may be a torch.Tensor. Received node types: {dataset.node_ids.keys()}"
@@ -605,7 +608,7 @@ def _setup_for_colocated(
input_data,
worker_options,
DatasetSchema(
is_labeled_heterogeneous=is_labeled_heterogeneous,
is_homogeneous_with_labeled_edge_type=is_homogeneous_with_labeled_edge_type,
edge_types=edge_types,
node_feature_info=dataset.node_feature_info,
edge_feature_info=dataset.edge_feature_info,
@@ -623,6 +626,6 @@ def _collate_fn(self, msg: SampleMessage) -> Union[Data, HeteroData]:
)
if isinstance(data, HeteroData):
data = strip_label_edges(data)
if self._is_labeled_heterogeneous:
if self._is_homogeneous_with_labeled_edge_type:
data = labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data)
return data
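Reviewer note on the renamed flag: when an otherwise homogeneous dataset carries a synthetic "label" edge type (the ABLP-after-splitting case), _collate_fn strips the label edges and collapses the HeteroData back into a plain Data object. A hedged illustration using torch_geometric directly — labeled_to_homogeneous and DEFAULT_HOMOGENEOUS_EDGE_TYPE are the real gigl names, but toy_collate below substitutes PyG's generic to_homogeneous() and is not the library's implementation:

import torch
from torch_geometric.data import Data, HeteroData


def toy_collate(data, is_homogeneous_with_labeled_edge_type: bool):
    if isinstance(data, HeteroData) and is_homogeneous_with_labeled_edge_type:
        # The real loader calls labeled_to_homogeneous(DEFAULT_HOMOGENEOUS_EDGE_TYPE, data)
        # after strip_label_edges(data); PyG's to_homogeneous() plays a similar
        # role here for a single-node-type, single-edge-type graph.
        return data.to_homogeneous()
    return data


hetero = HeteroData()
hetero["n"].x = torch.randn(4, 2)
hetero["n", "to", "n"].edge_index = torch.tensor([[0, 1, 2], [1, 2, 3]])
out = toy_collate(hetero, is_homogeneous_with_labeled_edge_type=True)
assert isinstance(out, Data)  # collapsed back to a homogeneous graph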
5 changes: 3 additions & 2 deletions gigl/distributed/utils/neighborloader.py
@@ -32,8 +32,9 @@ class DatasetSchema:
Shared metadata between the local and remote datasets.
"""

# If the dataset is labeled heterogeneous. E.g. one node type, one edge type, and "label" edges.
is_labeled_heterogeneous: bool
# If the dataset is homogeneous with labeled edge type. E.g. one node type, one edge type, and "label" edges.
# This happens in an otherwise homogeneous dataset when doing ABLP and when we split the dataset.
is_homogeneous_with_labeled_edge_type: bool
# List of all edge types in the graph.
edge_types: Optional[list[EdgeType]]
# Node feature info.
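A hedged construction example for the updated schema: the import path follows the file in this diff, and the keyword set matches the constructor call in dist_ablp_neighborloader.py above; passing None for the two feature-info fields is an assumption (their exact types sit outside this hunk), as is the string form of edge_dir:

from gigl.distributed.utils.neighborloader import DatasetSchema

schema = DatasetSchema(
    # True for the ABLP case: one "real" node/edge type plus synthetic label edges.
    is_homogeneous_with_labeled_edge_type=True,
    edge_types=None,
    node_feature_info=None,  # assumed to accept None
    edge_feature_info=None,  # assumed to accept None
    edge_dir="out",  # assumed string form; mirrors dataset.edge_dir in the diff
)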