diff --git a/Makefile b/Makefile index b86b1242f..910ad3933 100644 --- a/Makefile +++ b/Makefile @@ -278,6 +278,14 @@ run_het_dblp_sup_gs_e2e_test: --test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \ --test_names="het_dblp_sup_gs_test" +run_hom_cora_snc_e2e_test: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} +run_hom_cora_snc_e2e_test: compile_gigl_kubeflow_pipeline +run_hom_cora_snc_e2e_test: + uv run python tests/e2e_tests/e2e_test.py \ + --compiled_pipeline_path=$(compiled_pipeline_path) \ + --test_spec_uri="tests/e2e_tests/e2e_tests.yaml" \ + --test_names="hom_cora_snc_test" + run_all_e2e_tests: compiled_pipeline_path:=${GIGL_E2E_TEST_COMPILED_PIPELINE_PATH} run_all_e2e_tests: compile_gigl_kubeflow_pipeline run_all_e2e_tests: diff --git a/examples/node_classification/README.md b/examples/node_classification/README.md new file mode 100644 index 000000000..34dfbe759 --- /dev/null +++ b/examples/node_classification/README.md @@ -0,0 +1,18 @@ +# Examples for Supervised Node Classification on Homogeneous Graphs + +## Homogeneous (CORA) + +We use the CORA dataset as an example for supervised node classification on a homogeneous graph. + +[homogeneous_training.py](./homogeneous_training.py) and [homogeneous_inference.py](./homogeneous_inference.py) are +example training and inference loops for the CORA dataset, the MNIST of graph models, and available via the PyG +`Planetoid` +[dataset](https://pytorch-geometric.readthedocs.io/en/stable/generated/torch_geometric.datasets.Planetoid.html). + +```bash +make run_hom_cora_snc_e2e_test +``` + +The pipeline will run each component end-to-end: `config_populator` → `data_preprocessor` → `trainer` → `inferencer`, +exporting the per-anchor predicted class label (an integer in `[0, 7)` cast to `FLOAT64`) to a BigQuery table referenced +by `InferenceAssets.get_enumerated_predictions_table_path(...)`. diff --git a/examples/node_classification/__init__.py b/examples/node_classification/__init__.py new file mode 100644 index 000000000..5bf0ad1e2 --- /dev/null +++ b/examples/node_classification/__init__.py @@ -0,0 +1 @@ +"""Node Classification Examples""" diff --git a/examples/node_classification/configs/e2e_hom_cora_sup_task_config.yaml b/examples/node_classification/configs/e2e_hom_cora_sup_task_config.yaml new file mode 100644 index 000000000..e6108e8ef --- /dev/null +++ b/examples/node_classification/configs/e2e_hom_cora_sup_task_config.yaml @@ -0,0 +1,49 @@ +# This config is used to run homogeneous CORA supervised node classification training +# and inference using in-memory GiGL SGS. Run via `make run_hom_cora_snc_e2e_test`. +graphMetadata: + edgeTypes: + - dstNodeType: paper + relation: cites + srcNodeType: paper + nodeTypes: + - paper +taskMetadata: + nodeBasedTaskMetadata: + supervisionNodeTypes: + - paper +datasetConfig: + dataPreprocessorConfig: + dataPreprocessorConfigClsPath: gigl.src.mocking.mocking_assets.passthrough_preprocessor_config_for_mocked_assets.PassthroughPreprocessorConfigForMockedAssets + dataPreprocessorArgs: + mocked_dataset_name: 'cora_homogeneous_supervised_node_classification_edge_features' +trainerConfig: + trainerArgs: + log_every_n_batch: "25" + num_neighbors: "[10, 10]" + num_classes: "7" + train_batch_size: "16" + num_max_train_batches: "200" + num_val_batches: "20" + val_every_n_batch: "50" + command: python -m examples.node_classification.homogeneous_training +inferencerConfig: + inferencerArgs: + log_every_n_batch: "25" + num_neighbors: "[10, 10]" + num_classes: "7" + inferenceBatchSize: 512 + command: python -m examples.node_classification.homogeneous_inference +sharedConfig: + shouldSkipAutomaticTempAssetCleanup: false + shouldSkipInference: false + # Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. + shouldSkipModelEvaluation: true +featureFlags: + should_run_glt_backend: 'True' + data_preprocessor_num_shards: '2' + # NODE_BASED_TASK tasks unconditionally populate `predictions_path` in the inference metadata + # (see gigl/src/config_populator/config_populator.py:348-358); the post-processor's + # unenumerator then expects whichever paths are populated to point at real BQ tables. + # We disable the embeddings path so the post-processor only unenumerates the predictions + # table that this example actually writes via `PredictionExporter`. + should_populate_embeddings_path: 'False' diff --git a/examples/node_classification/homogeneous_inference.py b/examples/node_classification/homogeneous_inference.py new file mode 100644 index 000000000..476793f64 --- /dev/null +++ b/examples/node_classification/homogeneous_inference.py @@ -0,0 +1,426 @@ +""" +This file contains an example for how to run homogeneous supervised node classification (SNC) +inference using GiGL's GraphLearn-for-PyTorch (GLT) bindings. The example exports the per-anchor +argmax class label to a temporary GCS folder, then loads it into a BigQuery table at the end of +the run. While `_run_example_inference` is coupled with GiGL orchestration, the +`_inference_process` function is generic and can be used as a reference for writing inference for +pipelines not dependent on GiGL orchestration. + +To run this file with GiGL orchestration, set the fields similar to below: + +inferencerConfig: + inferencerArgs: + # Example argument to inferencer + log_every_n_batch: "25" + inferenceBatchSize: 512 + command: python -m examples.node_classification.homogeneous_inference +featureFlags: + should_run_glt_backend: 'True' + # Disable embeddings-path population so the post-processor's unenumerator only expects the + # predictions table this example actually writes. NODE_BASED_TASK tasks unconditionally + # populate `predictions_path` already. + should_populate_embeddings_path: 'False' + +You can run this example in a full pipeline with `make run_hom_cora_snc_e2e_test` +from GiGL root. + +Each anchor node is exported as a single `FLOAT64` `pred` field (the predicted class label, i.e. +`argmax(logits)` cast to float) into BigQuery via GiGL's `PredictionExporter`. +""" + +import argparse +import gc +import time +from dataclasses import dataclass + +import torch +import torch.multiprocessing as mp + +import gigl.distributed +import gigl.distributed.utils +from examples.node_classification.models import ( + init_example_gigl_homogeneous_node_classification_model, +) +from gigl.common import GcsUri, Uri, UriFactory +from gigl.common.data.export import PredictionExporter, load_predictions_to_bigquery +from gigl.common.logger import Logger +from gigl.common.utils.gcs import GcsUtils +from gigl.distributed import DistDataset, build_dataset_from_task_config_uri +from gigl.src.common.types import AppliedTaskIdentifier +from gigl.src.common.types.graph_data import EdgeType, NodeType +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.bq import BqUtils +from gigl.src.common.utils.model import load_state_dict_from_uri +from gigl.src.inference.lib.assets import InferenceAssets +from gigl.utils.sampling import parse_fanout + +logger = Logger() + +# Default number of inference processes per machine when one isn't provided via +# `local_world_size` in inferencer args and there are no GPUs available. +DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE = 4 + + +@dataclass(frozen=True) +class InferenceProcessArgs: + """ + Arguments for the homogeneous SNC inference process. + + Attributes: + local_world_size (int): Number of inference processes spawned by each machine. + machine_rank (int): Rank of the current machine in the cluster. + machine_world_size (int): Total number of machines in the cluster. + master_ip_address (str): IP address of the master node for process group initialization. + master_default_process_group_port (int): Port for the default process group. + dataset (DistDataset): Loaded Distributed Dataset for inference. + inference_node_type (NodeType): Node type that predicted class labels are generated for. + model_state_dict_uri (Uri): URI to load the trained model state dict from. + hid_dim (int): Encoder hidden dimension. + num_classes (int): Number of output classes. + node_feature_dim (int): Input node feature dimension. + prediction_gcs_path (GcsUri): GCS path to write predicted class labels to. + inference_batch_size (int): Batch size to use for inference. + num_neighbors (list[int]): Fanout for subgraph sampling. + sampling_workers_per_inference_process (int): Sampling workers per inference process. + sampling_worker_shared_channel_size (str): Shared-memory buffer size (e.g. ``"4GB"``). + log_every_n_batch (int): Frequency to log batch information during inference. + """ + + local_world_size: int + machine_rank: int + machine_world_size: int + master_ip_address: str + master_default_process_group_port: int + + dataset: DistDataset + inference_node_type: NodeType + + model_state_dict_uri: Uri + hid_dim: int + num_classes: int + node_feature_dim: int + + prediction_gcs_path: GcsUri + inference_batch_size: int + num_neighbors: list[int] | dict[EdgeType, list[int]] + sampling_workers_per_inference_process: int + sampling_worker_shared_channel_size: str + log_every_n_batch: int + + +@torch.no_grad() +def _inference_process( + local_rank: int, + args: InferenceProcessArgs, +) -> None: + """ + Spawned per local rank: initializes the dataloader, runs the inference loop, and writes + the per-anchor predicted class label to GCS. + + Args: + local_rank (int): Process number on the current machine. + args (InferenceProcessArgs): Dataclass containing all inference arguments. + """ + device = gigl.distributed.utils.get_available_device( + local_process_rank=local_rank, + ) + if torch.cuda.is_available(): + logger.info( + f"Using GPU {device} with index {device.index} on local rank: {local_rank} for inference" + ) + torch.cuda.set_device(device) + rank = args.machine_rank * args.local_world_size + local_rank + world_size = args.machine_world_size * args.local_world_size + logger.info( + f"Local rank {local_rank} in machine {args.machine_rank} has rank {rank}/{world_size} " + f"and using device {device} for inference" + ) + torch.distributed.init_process_group( + backend="gloo" if device.type == "cpu" else "nccl", + init_method=f"tcp://{args.master_ip_address}:{args.master_default_process_group_port}", + rank=rank, + world_size=world_size, + ) + + data_loader = gigl.distributed.DistNeighborLoader( + dataset=args.dataset, + num_neighbors=args.num_neighbors, + local_process_rank=local_rank, + local_process_world_size=args.local_world_size, + input_nodes=None, # Homogeneous case: `None` defaults to using all nodes for inference. + num_workers=args.sampling_workers_per_inference_process, + batch_size=args.inference_batch_size, + pin_memory_device=device, + worker_concurrency=args.sampling_workers_per_inference_process, + channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=0, + ) + + model_state_dict = load_state_dict_from_uri( + load_from_uri=args.model_state_dict_uri, device=device + ) + model = init_example_gigl_homogeneous_node_classification_model( + node_feature_dim=args.node_feature_dim, + num_classes=args.num_classes, + hid_dim=args.hid_dim, + device=device, + state_dict=model_state_dict, + ) + model.eval() + + logger.info(f"Model initialized on device {device}") + + output_filename = f"machine_{args.machine_rank}_local_process_{local_rank}" + + # Clean any stale files at the destination — GiGL orchestration cleans this automatically, + # but a local retry would otherwise leave stale files. + gcs_utils = GcsUtils() + gcs_base_uri = GcsUri.join(args.prediction_gcs_path, output_filename) + num_files_at_gcs_path = gcs_utils.count_blobs_in_gcs_path(gcs_base_uri) + if num_files_at_gcs_path > 0: + logger.warning( + f"{num_files_at_gcs_path} files already detected at base gcs path. " + f"Cleaning up files at path ... " + ) + gcs_utils.delete_files_in_bucket_dir(gcs_base_uri) + + # The BigQuery predictions schema stores a single FLOAT64 `pred` per node + # (see gigl/common/data/export.py:67); for multi-class classification we write + # `argmax(logits)` cast to float. + exporter = PredictionExporter(export_dir=gcs_base_uri) + + # Barrier so all processes have initialized their dataloader before the inference loop starts; + # otherwise on-the-fly subgraph sampling can fail. + torch.distributed.barrier() + + t = time.time() + data_loading_start_time = time.time() + cumulative_data_loading_time = 0.0 + cumulative_inference_time = 0.0 + + for batch_idx, data in enumerate(data_loader): + cumulative_data_loading_time += time.time() - data_loading_start_time + + inference_start_time = time.time() + + logits = model(data=data, device=device) + anchor_logits = logits[: data.batch_size] + anchor_predictions = anchor_logits.argmax(dim=-1).float().cpu() + node_ids = data.batch.cpu() + + exporter.add_prediction( + id_batch=node_ids, + prediction_batch=anchor_predictions, + prediction_type=str(args.inference_node_type), + ) + + cumulative_inference_time += time.time() - inference_start_time + + if batch_idx > 0 and batch_idx % args.log_every_n_batch == 0: + logger.info( + f"rank {rank} processed {batch_idx} batches. " + f"{args.log_every_n_batch} batches took {time.time() - t:.2f} seconds. " + f"Among them, data loading took {cumulative_data_loading_time:.2f} seconds " + f"and model inference took {cumulative_inference_time:.2f} seconds." + ) + t = time.time() + cumulative_data_loading_time = 0 + cumulative_inference_time = 0 + + data_loading_start_time = time.time() + + logger.info(f"--- Rank {rank} finished inference.") + + write_start_time = time.time() + exporter.flush_records() + logger.info( + f"--- Rank {rank} finished writing predictions to GCS, " + f"which took {time.time() - write_start_time:.2f} seconds" + ) + + # Barrier before shutting down so all processes finish sampling first; otherwise still-active + # samplers will fail when their peer's loader exits. + torch.distributed.barrier() + + data_loader.shutdown() + gc.collect() + torch.distributed.destroy_process_group() + + logger.info( + f"--- All machines local rank {local_rank} finished inference. Deleted data loader" + ) + + +def _run_example_inference( + job_name: str, + task_config_uri: str, +) -> None: + """ + Runs an example SNC inference pipeline using GiGL Orchestration. + + Args: + job_name (str): Name of current job. + task_config_uri (str): Path to frozen GbmlConfig URI. + """ + program_start_time = time.time() + + # One main process per machine needs to coordinate partitioning + synchronization; assuming + # spawn-via-Vertex sets up env:// init for us. + torch.distributed.init_process_group(backend="gloo") + + logger.info( + f"Took {time.time() - program_start_time:.2f} seconds to connect worker pool" + ) + + dataset = build_dataset_from_task_config_uri(task_config_uri=task_config_uri) + + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=UriFactory.create_uri(task_config_uri) + ) + # `model_state_dict_uri` is read from the same `trained_model_metadata.trained_model_uri` slot + # the trainer wrote to (see homogeneous_training.py `model_uri` derivation). + model_uri = UriFactory.create_uri( + gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri + ) + graph_metadata = gbml_config_pb_wrapper.graph_metadata_pb_wrapper + output_bq_table_path = InferenceAssets.get_enumerated_predictions_table_path( + gbml_config_pb_wrapper, graph_metadata.homogeneous_node_type + ) + bq_project_id, bq_dataset_id, bq_table_name = BqUtils.parse_bq_table_path( + bq_table_path=output_bq_table_path + ) + # Write to a temporary GCS folder during the inference loop, then load to BigQuery at the end. + prediction_output_gcs_folder = InferenceAssets.get_gcs_asset_write_path_prefix( + applied_task_identifier=AppliedTaskIdentifier(job_name), + bq_table_path=output_bq_table_path, + ) + preprocessed_metadata_pb_wrapper = ( + gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper + ) + node_feature_dim = ( + preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map[ + graph_metadata.homogeneous_condensed_node_type + ] + ) + + inferencer_args = dict(gbml_config_pb_wrapper.inferencer_config.inferencer_args) + # `inference_batch_size` lives on the dedicated proto field, NOT inside the `inferencer_args` + # string map. + inference_batch_size = gbml_config_pb_wrapper.inferencer_config.inference_batch_size + + hid_dim = int(inferencer_args.get("hid_dim", "16")) + num_classes = int(inferencer_args.get("num_classes", "7")) + + arg_local_world_size = inferencer_args.get("local_world_size") + if arg_local_world_size is not None: + local_world_size = int(arg_local_world_size) + logger.info(f"Using local_world_size from inferencer_args: {local_world_size}") + if torch.cuda.is_available() and local_world_size != torch.cuda.device_count(): + logger.warning( + f"local_world_size {local_world_size} does not match the number of GPUs " + f"{torch.cuda.device_count()}. This may lead to unexpected failures with NCCL " + f"communication. Consider setting local_world_size to the number of GPUs." + ) + else: + if torch.cuda.is_available() and torch.cuda.device_count() > 0: + local_world_size = torch.cuda.device_count() + logger.info( + f"Detected {local_world_size} GPUs. Setting local_world_size to {local_world_size}" + ) + else: + logger.info( + f"No GPUs detected. Setting local_world_size to " + f"`{DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE}`" + ) + local_world_size = DEFAULT_CPU_BASED_LOCAL_WORLD_SIZE + + master_ip_address = gigl.distributed.utils.get_internal_ip_from_master_node() + machine_rank = torch.distributed.get_rank() + machine_world_size = torch.distributed.get_world_size() + master_default_process_group_port = ( + gigl.distributed.utils.get_free_ports_from_master_node(num_ports=1)[0] + ) + torch.distributed.destroy_process_group() + + inference_start_time = time.time() + + num_neighbors = parse_fanout(inferencer_args.get("num_neighbors", "[10, 10]")) + + sampling_workers_per_inference_process = int( + inferencer_args.get("sampling_workers_per_inference_process", "4") + ) + + sampling_worker_shared_channel_size = inferencer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + log_every_n_batch = int(inferencer_args.get("log_every_n_batch", "25")) + + inference_args = InferenceProcessArgs( + local_world_size=local_world_size, + machine_rank=machine_rank, + machine_world_size=machine_world_size, + master_ip_address=master_ip_address, + master_default_process_group_port=master_default_process_group_port, + dataset=dataset, + inference_node_type=graph_metadata.homogeneous_node_type, + model_state_dict_uri=model_uri, + hid_dim=hid_dim, + num_classes=num_classes, + node_feature_dim=node_feature_dim, + prediction_gcs_path=prediction_output_gcs_folder, + inference_batch_size=inference_batch_size, + num_neighbors=num_neighbors, + sampling_workers_per_inference_process=sampling_workers_per_inference_process, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + log_every_n_batch=log_every_n_batch, + ) + + mp.spawn( + fn=_inference_process, + args=(inference_args,), + nprocs=local_world_size, + join=True, + ) + + logger.info( + f"--- Inference finished on rank {machine_rank}, which took " + f"{time.time() - inference_start_time:.2f} seconds" + ) + + # Machine 0 loads the per-rank GCS shards into BigQuery. + if machine_rank == 0: + logger.info("--- Machine 0 triggers loading predictions from GCS to BigQuery") + _ = load_predictions_to_bigquery( + gcs_folder=prediction_output_gcs_folder, + project_id=bq_project_id, + dataset_id=bq_dataset_id, + table_id=bq_table_name, + ) + + logger.info( + f"--- Program finished, which took {time.time() - program_start_time:.2f} seconds" + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Arguments for distributed SNC model inference on VertexAI" + ) + parser.add_argument( + "--job_name", + type=str, + help="Inference job name", + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + + # parse_known_args is required because Vertex AI's gigl/src/common/vertex_ai_launcher.py + # appends extra runtime flags (e.g. --resource_config_uri, --use_cuda) that this module + # does not declare. + args, unused_args = parser.parse_known_args() + logger.info(f"Unused arguments: {unused_args}") + + _run_example_inference( + job_name=args.job_name, + task_config_uri=args.task_config_uri, + ) diff --git a/examples/node_classification/homogeneous_training.py b/examples/node_classification/homogeneous_training.py new file mode 100644 index 000000000..107c4c6a1 --- /dev/null +++ b/examples/node_classification/homogeneous_training.py @@ -0,0 +1,742 @@ +""" +This file contains an example for how to run homogeneous supervised node classification (SNC) +training using live subgraph sampling powered by GraphLearn-for-PyTorch (GLT). While +`_run_example_training` is coupled with GiGL orchestration, the `_training_process` and +`_run_validation_loops` functions are generic and can be used as references for writing training +for pipelines not dependent on GiGL orchestration. + +To run this file with GiGL orchestration, set the fields similar to below: + +trainerConfig: + trainerArgs: + # Example argument to trainer + log_every_n_batch: "25" + command: python -m examples.node_classification.homogeneous_training +featureFlags: + should_run_glt_backend: 'True' + +You can run this example in a full pipeline with `make run_hom_cora_snc_e2e_test` +from GiGL root. + +Given a frozen task config with already-populated data preprocessor output, the following training +script can be run locally using: +WORLD_SIZE=1 RANK=0 MASTER_ADDR="localhost" MASTER_PORT=20000 \ + python -m examples.node_classification.homogeneous_training \ + --task_config_uri= + +A frozen task config with data preprocessor outputs can be generated by running an e2e pipeline +with `stop_after=data_preprocessor` and using the frozen config generated from the +`config_populator` component after the run has completed. +""" + +import argparse +import statistics +import time +from collections.abc import Iterator +from dataclasses import dataclass +from typing import Literal, Optional + +import torch +import torch.distributed +import torch.multiprocessing as mp +from torch_geometric.data import Data + +import gigl.distributed.utils +from examples.node_classification.models import ( + HomogeneousNodeClassificationGNN, + init_example_gigl_homogeneous_node_classification_model, +) +from gigl.common import Uri, UriFactory +from gigl.common.logger import Logger +from gigl.common.utils.torch_training import is_distributed_available_and_initialized +from gigl.distributed import DistDataset, build_dataset_from_task_config_uri +from gigl.distributed.distributed_neighborloader import DistNeighborLoader +from gigl.distributed.utils import get_available_device +from gigl.src.common.translators.model_eval_metrics_translator import ( + write_eval_metrics_to_uri, +) +from gigl.src.common.types.graph_data import EdgeType +from gigl.src.common.types.model_eval_metrics import ( + EvalMetric, + EvalMetricsCollection, + EvalMetricType, +) +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.utils.model import load_state_dict_from_uri, save_state_dict +from gigl.types.graph import to_homogeneous +from gigl.utils.iterator import InfiniteIterator +from gigl.utils.sampling import parse_fanout + +logger = Logger() + + +def _sync_metric_across_processes(metric: torch.Tensor) -> float: + """ + Takes the average of a training metric across multiple processes. Note that this function + requires DDP to be initialized. + + Args: + metric (torch.Tensor): The metric, expressed as a torch Tensor, which should be synced + across multiple processes. + + Returns: + float: The average of the provided metric across all training processes. + """ + assert is_distributed_available_and_initialized(), "DDP is not initialized" + metric_tensor = metric.detach().clone() + torch.distributed.all_reduce(metric_tensor, op=torch.distributed.ReduceOp.SUM) + return metric_tensor.item() / torch.distributed.get_world_size() + + +def _setup_dataloader( + dataset: DistDataset, + split: Literal["train", "val", "test"], + num_neighbors: list[int] | dict[EdgeType, list[int]], + sampling_workers_per_process: int, + batch_size: int, + device: torch.device, + sampling_worker_shared_channel_size: str, + process_start_gap_seconds: int, +) -> DistNeighborLoader: + """ + Sets up a single ``DistNeighborLoader`` for the given split. + + Args: + dataset (DistDataset): Loaded distributed dataset. + split (Literal["train", "val", "test"]): Which split to load. + num_neighbors (list[int]): Fanout per hop. + sampling_workers_per_process (int): Sampling workers per training process. + batch_size (int): Number of anchor nodes per batch. + device (torch.device): Device to pin sampled batches to. + sampling_worker_shared_channel_size (str): Shared-memory channel size (e.g. ``"4GB"``). + process_start_gap_seconds (int): Sleep gap between dataloader inits to reduce peak memory. + + Returns: + DistNeighborLoader: Configured loader over the chosen split's anchor nodes. + """ + rank = torch.distributed.get_rank() + + if split == "train": + input_nodes = to_homogeneous(dataset.train_node_ids) + shuffle = True + elif split == "val": + input_nodes = to_homogeneous(dataset.val_node_ids) + shuffle = False + else: # split == "test" + input_nodes = to_homogeneous(dataset.test_node_ids) + shuffle = False + + loader = DistNeighborLoader( + dataset=dataset, + num_neighbors=num_neighbors, + input_nodes=input_nodes, + num_workers=sampling_workers_per_process, + batch_size=batch_size, + pin_memory_device=device, + worker_concurrency=sampling_workers_per_process, + channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + shuffle=shuffle, + ) + + logger.info(f"---Rank {rank} finished setting up {split} loader") + + torch.distributed.barrier() + return loader + + +def _compute_loss_and_accuracy( + model: HomogeneousNodeClassificationGNN, + data: Data, + loss_fn: torch.nn.CrossEntropyLoss, + device: torch.device, +) -> tuple[torch.Tensor, torch.Tensor]: + """ + Runs a forward pass and returns per-batch loss + per-batch accuracy (mean over anchor nodes). + + Args: + model (HomogeneousNodeClassificationGNN): Model (possibly DDP-wrapped internally). + data (Data): Sampled subgraph batch. + loss_fn (torch.nn.CrossEntropyLoss): Classification loss. + device (torch.device): Compute device. + + Returns: + tuple[torch.Tensor, torch.Tensor]: Scalar ``(loss, accuracy)`` tensors on ``device``. + """ + logits = model(data=data, device=device) + anchor_logits = logits[: data.batch_size] + anchor_labels = data.y[: data.batch_size].long() + loss = loss_fn(anchor_logits, anchor_labels) + accuracy = (anchor_logits.argmax(dim=-1) == anchor_labels).float().mean() + return loss, accuracy + + +@dataclass(frozen=True) +class TrainingProcessArgs: + """ + Arguments for the homogeneous SNC training process. + + Attributes: + local_world_size (int): Number of training processes spawned by each machine. + machine_rank (int): Rank of the current machine in the cluster. + machine_world_size (int): Total number of machines in the cluster. + master_ip_address (str): IP address of the master node for process group initialization. + master_default_process_group_port (int): Port for the default process group. + dataset (DistDataset): Loaded Distributed Dataset for training and testing. + model_uri (Uri): URI to save/load the trained model state dict. + eval_metrics_uri (Optional[Uri]): Destination URI for eval metrics. If None, metrics are + not written. + hid_dim (int): Hidden dimension of the encoder. + num_classes (int): Number of output classes for the classifier head. + node_feature_dim (int): Input node feature dimension. + num_neighbors (list[int]): Fanout for subgraph sampling. + sampling_workers_per_process (int): Number of sampling workers per training/testing + process. + sampling_worker_shared_channel_size (str): Shared-memory buffer size (e.g. ``"4GB"``). + process_start_gap_seconds (int): Sleep gap between dataloader initializations. + train_batch_size (int): Number of anchor nodes per training batch. + learning_rate (float): Optimizer learning rate. + weight_decay (float): Optimizer weight decay. + num_max_train_batches (int): Maximum number of training batches across all processes. + num_val_batches (int): Number of validation batches across all processes. + val_every_n_batch (int): Frequency to run validation during training. + log_every_n_batch (int): Frequency to log batch information during training. + should_skip_training (bool): If True, skip training and only run testing on a loaded model. + """ + + local_world_size: int + machine_rank: int + machine_world_size: int + master_ip_address: str + master_default_process_group_port: int + + dataset: DistDataset + + model_uri: Uri + eval_metrics_uri: Optional[Uri] + hid_dim: int + num_classes: int + node_feature_dim: int + + num_neighbors: list[int] | dict[EdgeType, list[int]] + sampling_workers_per_process: int + sampling_worker_shared_channel_size: str + process_start_gap_seconds: int + + train_batch_size: int + learning_rate: float + weight_decay: float + num_max_train_batches: int + num_val_batches: int + val_every_n_batch: int + log_every_n_batch: int + should_skip_training: bool + + +def _training_process( + local_rank: int, + args: TrainingProcessArgs, +) -> None: + """ + Spawned per local rank to run a single training (and testing) process. + + Args: + local_rank (int): Process number on the current machine. + args (TrainingProcessArgs): Dataclass containing all training arguments. + """ + world_size = args.machine_world_size * args.local_world_size + rank = args.machine_rank * args.local_world_size + local_rank + logger.info( + f"---Current training process rank: {rank}, training process world size: {world_size}" + ) + + torch.distributed.init_process_group( + backend="nccl" if torch.cuda.is_available() else "gloo", + init_method=f"tcp://{args.master_ip_address}:{args.master_default_process_group_port}", + rank=rank, + world_size=world_size, + ) + + logger.info(f"---Rank {rank} training process started") + + device = get_available_device(local_process_rank=local_rank) + if torch.cuda.is_available(): + torch.cuda.set_device(device) + logger.info(f"---Rank {rank} training process set device {device}") + + loss_fn = torch.nn.CrossEntropyLoss(reduction="mean") + + if not args.should_skip_training: + train_loader = _setup_dataloader( + dataset=args.dataset, + split="train", + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + batch_size=args.train_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + train_loader_iter = InfiniteIterator(train_loader) + + val_loader = _setup_dataloader( + dataset=args.dataset, + split="val", + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + batch_size=args.train_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + val_loader_iter = InfiniteIterator(val_loader) + + model = init_example_gigl_homogeneous_node_classification_model( + node_feature_dim=args.node_feature_dim, + num_classes=args.num_classes, + hid_dim=args.hid_dim, + device=device, + wrap_with_ddp=True, + ) + + optimizer = torch.optim.AdamW( + params=model.parameters(), + lr=args.learning_rate, + weight_decay=args.weight_decay, + ) + logger.info( + f"Model initialized on rank {rank} training device {device}\n{model}" + ) + + # Wait for all processes to finish dataloader and model setup before starting training. + torch.distributed.barrier() + + training_start_time = time.time() + batch_idx = 0 + last_n_batch_avg_loss: list[float] = [] + last_n_batch_avg_acc: list[float] = [] + last_n_batch_time: list[float] = [] + num_max_train_batches_per_process = args.num_max_train_batches // world_size + num_val_batches_per_process = args.num_val_batches // world_size + logger.info( + f"num_max_train_batches_per_process is set to {num_max_train_batches_per_process}" + ) + + model.train() + batch_start = time.time() + for data in train_loader_iter: + if batch_idx >= num_max_train_batches_per_process: + logger.info( + f"num_max_train_batches_per_process={num_max_train_batches_per_process} reached, " + f"stopping training on machine {args.machine_rank} local rank {local_rank}" + ) + break + loss, accuracy = _compute_loss_and_accuracy( + model=model, + data=data, + loss_fn=loss_fn, + device=device, + ) + optimizer.zero_grad() + loss.backward() + optimizer.step() + avg_train_loss = _sync_metric_across_processes(metric=loss) + avg_train_acc = _sync_metric_across_processes(metric=accuracy) + last_n_batch_avg_loss.append(avg_train_loss) + last_n_batch_avg_acc.append(avg_train_acc) + last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + if batch_idx % args.log_every_n_batch == 0: + logger.info( + f"rank={rank}, batch={batch_idx}, latest local train_loss={loss:.6f}, " + f"latest local train_acc={accuracy:.4f}" + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + logger.info( + f"rank={rank}, mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, " + f"max(batch_time)={max(last_n_batch_time):.3f} sec, " + f"min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + logger.info( + f"rank={rank}, last {args.log_every_n_batch} " + f"mean(avg_train_loss)={statistics.mean(last_n_batch_avg_loss):.6f}, " + f"mean(avg_train_acc)={statistics.mean(last_n_batch_avg_acc):.4f}" + ) + last_n_batch_avg_loss.clear() + last_n_batch_avg_acc.clear() + + if batch_idx % args.val_every_n_batch == 0: + logger.info(f"rank={rank}, batch={batch_idx}, validating...") + model.eval() + _run_validation_loops( + model=model, + loader=val_loader_iter, + loss_fn=loss_fn, + device=device, + log_every_n_batch=args.log_every_n_batch, + num_batches=num_val_batches_per_process, + ) + model.train() + + logger.info(f"---Rank {rank} finished training") + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + train_loader.shutdown() + val_loader.shutdown() + + if args.machine_rank == 0 and local_rank == 0: + logger.info( + f"Training loop finished, took {time.time() - training_start_time:.3f} seconds, " + f"saving model to {args.model_uri}" + ) + # Save the unwrapped model so checkpoint keys have no `module.` prefix on the + # internally wrapped sub-modules. See examples/node_classification/models.py + # for the matching load order. + save_state_dict( + model=model.unwrap_from_ddp(), save_to_path_uri=args.model_uri + ) + else: + # should_skip_training=True: load a previously trained model and only run testing. + state_dict = load_state_dict_from_uri( + load_from_uri=args.model_uri, device=device + ) + model = init_example_gigl_homogeneous_node_classification_model( + node_feature_dim=args.node_feature_dim, + num_classes=args.num_classes, + hid_dim=args.hid_dim, + device=device, + wrap_with_ddp=True, + state_dict=state_dict, + ) + logger.info( + f"Model initialized on rank {rank} training device {device}\n{model}" + ) + + logger.info(f"---Rank {rank} started testing") + testing_start_time = time.time() + model.eval() + + test_loader = _setup_dataloader( + dataset=args.dataset, + split="test", + num_neighbors=args.num_neighbors, + sampling_workers_per_process=args.sampling_workers_per_process, + batch_size=args.train_batch_size, + device=device, + sampling_worker_shared_channel_size=args.sampling_worker_shared_channel_size, + process_start_gap_seconds=args.process_start_gap_seconds, + ) + # One pass through the test set; do NOT wrap in InfiniteIterator. + test_loader_iter = iter(test_loader) + + global_avg_test_loss, global_avg_test_acc = _run_validation_loops( + model=model, + loader=test_loader_iter, + loss_fn=loss_fn, + device=device, + log_every_n_batch=args.log_every_n_batch, + ) + + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.synchronize() + torch.distributed.barrier() + + test_loader.shutdown() + + # Write eval metrics on the lead process only. These get logged as a metrics artifact by the + # "Log Trainer Eval Metrics" component in the KFP pipeline UI. + if args.machine_rank == 0 and local_rank == 0 and args.eval_metrics_uri is not None: + eval_metrics = EvalMetricsCollection( + metrics=[ + EvalMetric.from_eval_metric_type( + EvalMetricType.loss, global_avg_test_loss + ), + EvalMetric.from_eval_metric_type( + EvalMetricType.acc, global_avg_test_acc + ), + ] + ) + write_eval_metrics_to_uri( + eval_metrics=eval_metrics, eval_metrics_uri=args.eval_metrics_uri + ) + + logger.info( + f"---Rank {rank} finished testing in {time.time() - testing_start_time:.3f} seconds" + ) + + torch.distributed.destroy_process_group() + + +@torch.inference_mode() +def _run_validation_loops( + model: HomogeneousNodeClassificationGNN, + loader: Iterator[Data], + loss_fn: torch.nn.CrossEntropyLoss, + device: torch.device, + log_every_n_batch: int, + num_batches: Optional[int] = None, +) -> tuple[float, float]: + """ + Runs a validation or test pass. + + Used for both validation while training (when ``loader`` is wrapped with ``InfiniteIterator`` + and ``num_batches`` is provided) and for the final one-pass test loop (when ``loader`` is a + plain iterator and ``num_batches`` is ``None``). + + Args: + model (HomogeneousNodeClassificationGNN): Possibly DDP-wrapped model. + loader (Iterator[Data]): Iterator over sampled batches. + loss_fn (torch.nn.CrossEntropyLoss): Classification loss. + device (torch.device): Compute device. + log_every_n_batch (int): Logging frequency. + num_batches (Optional[int]): Cap on iteration count. Required when ``loader`` is an + ``InfiniteIterator``. + + Returns: + tuple[float, float]: Global average ``(loss, accuracy)`` across all processes. + """ + rank = torch.distributed.get_rank() + + logger.info( + f"Running validation loop on rank={rank}, log_every_n_batch={log_every_n_batch}, " + f"num_batches={num_batches}" + ) + if num_batches is None and isinstance(loader, InfiniteIterator): + raise ValueError( + "Must set `num_batches` field when the provided data loader is wrapped with InfiniteIterator" + ) + + batch_idx = 0 + batch_losses: list[float] = [] + batch_accs: list[float] = [] + last_n_batch_time: list[float] = [] + batch_start = time.time() + + while True: + if num_batches and batch_idx >= num_batches: + break + try: + data = next(loader) + except StopIteration: + break + + loss, accuracy = _compute_loss_and_accuracy( + model=model, + data=data, + loss_fn=loss_fn, + device=device, + ) + + batch_losses.append(loss.item()) + batch_accs.append(accuracy.item()) + last_n_batch_time.append(time.time() - batch_start) + batch_start = time.time() + batch_idx += 1 + if batch_idx % log_every_n_batch == 0: + logger.info( + f"rank={rank}, batch={batch_idx}, latest val_loss={loss:.6f}, " + f"latest val_acc={accuracy:.4f}" + ) + if torch.cuda.is_available(): + torch.cuda.synchronize() + logger.info( + f"rank={rank}, batch={batch_idx}, " + f"mean(batch_time)={statistics.mean(last_n_batch_time):.3f} sec, " + f"max(batch_time)={max(last_n_batch_time):.3f} sec, " + f"min(batch_time)={min(last_n_batch_time):.3f} sec" + ) + last_n_batch_time.clear() + if len(batch_losses): + local_avg_loss = statistics.mean(batch_losses) + else: + local_avg_loss = 0.0 + if len(batch_accs): + local_avg_acc = statistics.mean(batch_accs) + else: + local_avg_acc = 0.0 + logger.info( + f"rank={rank} finished validation loop, local loss={local_avg_loss:.6f}, " + f"local acc={local_avg_acc:.4f}" + ) + global_avg_loss = _sync_metric_across_processes( + metric=torch.tensor(local_avg_loss, device=device) + ) + global_avg_acc = _sync_metric_across_processes( + metric=torch.tensor(local_avg_acc, device=device) + ) + logger.info( + f"rank={rank} got global validation loss={global_avg_loss:.6f}, " + f"acc={global_avg_acc:.4f}" + ) + + return global_avg_loss, global_avg_acc + + +def _run_example_training( + task_config_uri: str, +) -> None: + """ + Runs an example SNC training + testing loop using GiGL Orchestration. + + Args: + task_config_uri (str): Path to YAML-serialized GbmlConfig proto. + """ + start_time = time.time() + mp.set_start_method("spawn") + logger.info(f"Starting sub process method: {mp.get_start_method()}") + + gbml_config_pb_wrapper = GbmlConfigPbWrapper.get_gbml_config_pb_wrapper_from_uri( + gbml_config_uri=UriFactory.create_uri(task_config_uri) + ) + + trainer_args = dict(gbml_config_pb_wrapper.trainer_config.trainer_args) + + local_world_size = int(trainer_args.get("local_world_size", "2")) + if torch.cuda.is_available(): + if local_world_size > torch.cuda.device_count(): + raise ValueError( + f"Specified a local world size of {local_world_size} which exceeds the " + f"number of devices {torch.cuda.device_count()}" + ) + + fanout = trainer_args.get("num_neighbors", "[10, 10]") + num_neighbors = parse_fanout(fanout) + + sampling_workers_per_process = int( + trainer_args.get("sampling_workers_per_process", "4") + ) + + train_batch_size = int(trainer_args.get("train_batch_size", "16")) + + hid_dim = int(trainer_args.get("hid_dim", "16")) + num_classes = int(trainer_args.get("num_classes", "7")) + + sampling_worker_shared_channel_size = trainer_args.get( + "sampling_worker_shared_channel_size", "4GB" + ) + + process_start_gap_seconds = int(trainer_args.get("process_start_gap_seconds", "0")) + log_every_n_batch = int(trainer_args.get("log_every_n_batch", "25")) + + learning_rate = float(trainer_args.get("learning_rate", "0.001")) + weight_decay = float(trainer_args.get("weight_decay", "0.0005")) + num_max_train_batches = int(trainer_args.get("num_max_train_batches", "200")) + num_val_batches = int(trainer_args.get("num_val_batches", "20")) + val_every_n_batch = int(trainer_args.get("val_every_n_batch", "50")) + + logger.info( + f"Got training args local_world_size={local_world_size}, " + f"num_neighbors={num_neighbors}, " + f"sampling_workers_per_process={sampling_workers_per_process}, " + f"train_batch_size={train_batch_size}, " + f"hid_dim={hid_dim}, num_classes={num_classes}, " + f"sampling_worker_shared_channel_size={sampling_worker_shared_channel_size}, " + f"process_start_gap_seconds={process_start_gap_seconds}, " + f"log_every_n_batch={log_every_n_batch}, " + f"learning_rate={learning_rate}, weight_decay={weight_decay}, " + f"num_max_train_batches={num_max_train_batches}, " + f"num_val_batches={num_val_batches}, val_every_n_batch={val_every_n_batch}" + ) + + # Initialize a temporary process group just to discover the master IP and a free port to use + # for the per-rank training process groups. Gloo is sufficient — no GPU comms required here. + torch.distributed.init_process_group(backend="gloo") + + master_ip_address = gigl.distributed.utils.get_internal_ip_from_master_node() + machine_rank = torch.distributed.get_rank() + machine_world_size = torch.distributed.get_world_size() + master_default_process_group_port = ( + gigl.distributed.utils.get_free_ports_from_master_node(num_ports=1) + )[0] + torch.distributed.destroy_process_group() + + logger.info("--- Launching data loading process ---") + dataset = build_dataset_from_task_config_uri( + task_config_uri=task_config_uri, + is_inference=False, + ) + logger.info( + f"--- Data loading process finished, took {time.time() - start_time:.3f} seconds" + ) + + graph_metadata = gbml_config_pb_wrapper.graph_metadata_pb_wrapper + preprocessed_metadata_pb_wrapper = ( + gbml_config_pb_wrapper.preprocessed_metadata_pb_wrapper + ) + + node_feature_dim = ( + preprocessed_metadata_pb_wrapper.condensed_node_type_to_feature_dim_map[ + graph_metadata.homogeneous_condensed_node_type + ] + ) + + model_uri = UriFactory.create_uri( + gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.trained_model_uri + ) + + raw_eval_metrics_uri = gbml_config_pb_wrapper.gbml_config_pb.shared_config.trained_model_metadata.eval_metrics_uri + eval_metrics_uri: Optional[Uri] = ( + UriFactory.create_uri(raw_eval_metrics_uri) if raw_eval_metrics_uri else None + ) + + should_skip_training = gbml_config_pb_wrapper.shared_config.should_skip_training + + logger.info("--- Launching training processes ...\n") + start_time = time.time() + + training_args = TrainingProcessArgs( + local_world_size=local_world_size, + machine_rank=machine_rank, + machine_world_size=machine_world_size, + master_ip_address=master_ip_address, + master_default_process_group_port=master_default_process_group_port, + dataset=dataset, + model_uri=model_uri, + eval_metrics_uri=eval_metrics_uri, + hid_dim=hid_dim, + num_classes=num_classes, + node_feature_dim=node_feature_dim, + num_neighbors=num_neighbors, + sampling_workers_per_process=sampling_workers_per_process, + sampling_worker_shared_channel_size=sampling_worker_shared_channel_size, + process_start_gap_seconds=process_start_gap_seconds, + train_batch_size=train_batch_size, + learning_rate=learning_rate, + weight_decay=weight_decay, + num_max_train_batches=num_max_train_batches, + num_val_batches=num_val_batches, + val_every_n_batch=val_every_n_batch, + log_every_n_batch=log_every_n_batch, + should_skip_training=should_skip_training, + ) + + torch.multiprocessing.spawn( + _training_process, + args=(training_args,), + nprocs=local_world_size, + join=True, + ) + logger.info(f"--- Training finished, took {time.time() - start_time} seconds") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser( + description="Arguments for distributed SNC model training on VertexAI" + ) + parser.add_argument("--task_config_uri", type=str, help="Gbml config uri") + + # parse_known_args is required because Vertex AI's gigl/src/common/vertex_ai_launcher.py + # appends extra runtime flags (e.g. --job_name, --resource_config_uri, --use_cuda) that this + # module does not declare. + args, unused_args = parser.parse_known_args() + logger.info(f"Unused arguments: {unused_args}") + + _run_example_training(task_config_uri=args.task_config_uri) diff --git a/examples/node_classification/models.py b/examples/node_classification/models.py new file mode 100644 index 000000000..b55921963 --- /dev/null +++ b/examples/node_classification/models.py @@ -0,0 +1,173 @@ +from typing import Optional + +import torch +import torch.nn as nn +from torch.nn.parallel import DistributedDataParallel +from torch_geometric.data import Data +from typing_extensions import Self + +from gigl.src.common.models.pyg.homogeneous import GraphSAGE + + +class HomogeneousNodeClassificationGNN(nn.Module): + """ + Example homogeneous node-classification model wrapping a GraphSAGE encoder with a + Linear classifier head. + + The outer object stays a plain ``nn.Module``; DDP wraps the internal ``_encoder`` and + ``_head`` in place via :py:meth:`to_ddp`. See + :py:class:`gigl.nn.models.LinkPredictionGNN` for the analogous pattern. + + Args: + encoder (nn.Module): GNN encoder producing per-node embeddings of shape ``[N, hid_dim]``. + head (nn.Module): Linear classifier mapping embeddings to ``[N, num_classes]`` logits. + """ + + def __init__(self, encoder: nn.Module, head: nn.Module) -> None: + super().__init__() + self._encoder = encoder + self._head = head + + @property + def encoder(self) -> nn.Module: + return self._encoder + + @property + def head(self) -> nn.Module: + return self._head + + def forward(self, data: Data, device: torch.device) -> torch.Tensor: + """ + Runs the encoder then the classifier head on a sampled subgraph batch. + + Args: + data (Data): Sampled subgraph batch. + device (torch.device): Compute device for the forward pass. + + Returns: + torch.Tensor: Logits of shape ``[num_sampled_nodes, num_classes]``. + """ + node_embeddings = self._encoder(data=data, device=device) + logits = self._head(node_embeddings) + return logits + + def to_ddp( + self, + device: torch.device, + find_unused_encoder_parameters: bool = False, + ) -> Self: + """ + Wraps the internal encoder and classifier head in ``DistributedDataParallel`` in place. + + Mirrors :py:meth:`gigl.nn.models.LinkPredictionGNN.to_ddp`: the outer module stays a plain + ``nn.Module`` so saved state-dict keys remain unprefixed after a subsequent + :py:meth:`unwrap_from_ddp`. + + Args: + device (torch.device): The device DDP should bind to. + find_unused_encoder_parameters (bool): Forwarded as ``find_unused_parameters`` to the + encoder's ``DistributedDataParallel`` wrapper. + + Returns: + Self: This instance with ``_encoder`` and ``_head`` replaced by DDP wrappers. + """ + self._encoder = DistributedDataParallel( + self._encoder.to(device), + device_ids=[device] if device.type != "cpu" else None, + find_unused_parameters=find_unused_encoder_parameters, + ) + self._head = DistributedDataParallel( + self._head.to(device), + device_ids=[device] if device.type != "cpu" else None, + ) + return self + + def unwrap_from_ddp(self) -> "HomogeneousNodeClassificationGNN": + """ + Returns a fresh instance with unwrapped sub-modules suitable for saving. + + Returns: + HomogeneousNodeClassificationGNN: New instance where ``_encoder`` and ``_head`` are + the original plain modules. + """ + encoder = ( + self._encoder.module + if isinstance(self._encoder, DistributedDataParallel) + else self._encoder + ) + head = ( + self._head.module + if isinstance(self._head, DistributedDataParallel) + else self._head + ) + return HomogeneousNodeClassificationGNN(encoder=encoder, head=head) + + +def init_example_gigl_homogeneous_node_classification_model( + node_feature_dim: int, + num_classes: int, + hid_dim: int = 16, + num_layers: int = 2, + device: Optional[torch.device] = None, + state_dict: Optional[dict[str, torch.Tensor]] = None, + wrap_with_ddp: bool = False, + find_unused_encoder_parameters: bool = False, +) -> HomogeneousNodeClassificationGNN: + """ + Initializes a homogeneous node-classification model: ``GraphSAGE`` encoder + ``Linear`` head. + + The factory order is deliberately: + + 1. Construct the model with plain sub-modules. + 2. Move to ``device``. + 3. If ``state_dict`` is provided, load it against the unwrapped model so that the saved keys + (e.g. ``_encoder.``, ``_head.``) match the live model's keys. + 4. If ``wrap_with_ddp`` is ``True``, wrap the sub-modules in ``DistributedDataParallel``. + + This deviates from :py:func:`examples.link_prediction.models.init_example_gigl_homogeneous_model`, + which calls ``to_ddp`` *before* ``load_state_dict``. That ordering has a latent key-mismatch + bug on the ``should_skip_training=True`` eval-only path because DDP-wrapped sub-modules expose + state-dict keys with a ``module.`` prefix. + + Args: + node_feature_dim (int): Input node feature dimension. + num_classes (int): Number of output classes for the classifier head. + hid_dim (int): Encoder hidden and output dimension (head input dim). + num_layers (int): Number of GraphSAGE convolution layers. + device (Optional[torch.device]): Target device; defaults to CPU. + state_dict (Optional[dict[str, torch.Tensor]]): Optional pretrained weights. + wrap_with_ddp (bool): If ``True``, internally wrap ``_encoder`` and ``_head`` in DDP. + find_unused_encoder_parameters (bool): Forwarded to the encoder's DDP wrapper. + + Returns: + HomogeneousNodeClassificationGNN: Ready-to-train (or ready-to-infer) model. + """ + # `GraphSAGE.supports_edge_attr` is False (see `gigl/src/common/models/pyg/homogeneous.py:173`), + # so any edge features present in the dataset are ignored by this encoder. We pass `edge_dim=None` + # to make that explicit. An edge-aware variant (e.g. `GINE`, `EdgeAttrGAT`) is a natural follow-up. + encoder = GraphSAGE( + in_dim=node_feature_dim, + hid_dim=hid_dim, + out_dim=hid_dim, + edge_dim=None, + num_layers=num_layers, + conv_kwargs={}, + should_l2_normalize_embedding_layer_output=False, + ) + head = nn.Linear(hid_dim, num_classes) + model = HomogeneousNodeClassificationGNN(encoder=encoder, head=head) + + if device is None: + device = torch.device("cpu") + model.to(device) + + if state_dict is not None: + model.load_state_dict(state_dict) + + if wrap_with_ddp: + model.to_ddp( + device=device, + find_unused_encoder_parameters=find_unused_encoder_parameters, + ) + + return model diff --git a/tests/e2e_tests/e2e_tests.yaml b/tests/e2e_tests/e2e_tests.yaml index 0f3691e80..61fc4f311 100644 --- a/tests/e2e_tests/e2e_tests.yaml +++ b/tests/e2e_tests/e2e_tests.yaml @@ -25,3 +25,6 @@ tests: het_dblp_sup_gs_test: task_config_uri: "examples/link_prediction/graph_store/configs/e2e_het_dblp_sup_gs_task_config.yaml" resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_GRAPH_STORE_RESOURCE_CONFIG,deployment/configs/e2e_glt_gs_resource_config.yaml}" + hom_cora_snc_test: + task_config_uri: "examples/node_classification/configs/e2e_hom_cora_sup_task_config.yaml" + resource_config_uri: "${oc.env:GIGL_TEST_IN_MEMORY_DEFAULT_RESOURCE_CONFIG,deployment/configs/e2e_glt_resource_config.yaml}"