@@ -30,6 +30,10 @@ inferencerConfig:
num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case
inferenceBatchSize: 512
command: python -m examples.link_prediction.graph_store.homogeneous_inference
graphStoreStorageConfig:
storageCommand: python -m examples.link_prediction.graph_store.storage_main
storageArgs:
is_inference: "True"
sharedConfig:
shouldSkipInference: false
# Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. This will soon be added for in-mem SGS GiGL pipelines.
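Note that storageArgs values are plain YAML strings. The storage entrypoint added in this PR (examples/link_prediction/graph_store/storage_main.py below) converts is_inference back to a bool with strtobool; a minimal sketch of that conversion:

from distutils.util import strtobool

# storageArgs values arrive on the command line as strings; strtobool accepts
# lenient boolean spellings such as "True", "true", "1", "false", "0".
assert bool(strtobool("True")) is True
assert bool(strtobool("false")) is False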
153 changes: 153 additions & 0 deletions examples/link_prediction/graph_store/storage_main.py
@@ -0,0 +1,153 @@
"""Built-in GiGL Graph Store Server.

Derived from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/examples/distributed/server_client_mode/sage_supervised_server.py

# TODO(kmonte): Figure out how we should split out common utils from this file.

"""
import argparse
import os
from distutils.util import strtobool
from typing import Optional

import graphlearn_torch as glt
import torch

from gigl.common import Uri, UriFactory
from gigl.common.logger import Logger
from gigl.distributed import build_dataset_from_task_config_uri
from gigl.distributed.dist_dataset import DistDataset
from gigl.distributed.graph_store.storage_utils import register_dataset
from gigl.distributed.utils import get_graph_store_info
from gigl.distributed.utils.networking import get_free_ports_from_master_node
from gigl.env.distributed import GraphStoreInfo

logger = Logger()


def _run_storage_process(
storage_rank: int,
cluster_info: GraphStoreInfo,
dataset: DistDataset,
torch_process_port: int,
storage_world_backend: Optional[str],
) -> None:
register_dataset(dataset)
cluster_master_ip = cluster_info.storage_cluster_master_ip
logger.info(
f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}"
)
# Initialize the GLT server before starting the Torch Distributed process group.
# Otherwise, we saw intermittent hangs when initializing the server.
glt.distributed.init_server(
num_servers=cluster_info.num_storage_nodes,
server_rank=storage_rank,
dataset=dataset,
master_addr=cluster_master_ip,
master_port=cluster_info.rpc_master_port,
num_clients=cluster_info.compute_cluster_world_size,
)

init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{torch_process_port}"
logger.info(
f"Initializing storage node process group {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {init_method}"
)
torch.distributed.init_process_group(
backend=storage_world_backend,
world_size=cluster_info.num_storage_nodes,
rank=storage_rank,
init_method=init_method,
)

logger.info(
f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit"
)
glt.distributed.wait_and_shutdown_server()
logger.info(f"Storage node {storage_rank} exited")


def storage_node_process(
storage_rank: int,
cluster_info: GraphStoreInfo,
task_config_uri: Uri,
is_inference: bool = True,
tf_record_uri_pattern: str = r".*-of-.*\.tfrecord(\.gz)?$",
storage_world_backend: Optional[str] = None,
) -> None:
"""Run a storage node process

Should be called *once* per storage node (machine).

Args:
storage_rank (int): The rank of the storage node.
cluster_info (GraphStoreInfo): The cluster information.
task_config_uri (Uri): The task config URI.
is_inference (bool): Whether the process is an inference process. Defaults to True.
tf_record_uri_pattern (str): The TF Record URI pattern.
storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group.
"""
init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}"
logger.info(
f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes}. OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']} init method: {init_method}"
)
torch.distributed.init_process_group(
backend="gloo",
world_size=cluster_info.num_storage_nodes,
rank=storage_rank,
init_method=init_method,
group_name="gigl_server_comms",
)
logger.info(
f"Storage node {storage_rank} / {cluster_info.num_storage_nodes} process group initialized"
)
dataset = build_dataset_from_task_config_uri(
task_config_uri=task_config_uri,
is_inference=is_inference,
_tfrecord_uri_pattern=tf_record_uri_pattern,
)
torch_process_port = get_free_ports_from_master_node(num_ports=1)[0]
torch.distributed.destroy_process_group()
server_processes = []
mp_context = torch.multiprocessing.get_context("spawn")
# TODO(kmonte): Enable more than one server process per machine
for i in range(1):
server_process = mp_context.Process(
target=_run_storage_process,
args=(
storage_rank + i, # storage_rank
cluster_info, # cluster_info
dataset, # dataset
torch_process_port, # torch_process_port
storage_world_backend, # storage_world_backend
),
)
server_processes.append(server_process)
for server_process in server_processes:
server_process.start()
for server_process in server_processes:
server_process.join()


if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--task_config_uri", type=str, required=True)
parser.add_argument("--resource_config_uri", type=str, required=True)
parser.add_argument("--job_name", type=str, required=True)
parser.add_argument("--is_inference", type=str, required=True)
args = parser.parse_args()
logger.info(f"Running storage node with arguments: {args}")
is_inference = bool(strtobool(args.is_inference))
torch.distributed.init_process_group(backend="gloo")
cluster_info = get_graph_store_info()
logger.info(f"Cluster info: {cluster_info}")
logger.info(
f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}"
)
# Tear down the """"global""" process group so we can have a server-specific process group.
torch.distributed.destroy_process_group()
storage_node_process(
storage_rank=cluster_info.storage_node_rank,
cluster_info=cluster_info,
task_config_uri=UriFactory.create_uri(args.task_config_uri),
is_inference=is_inference,
)
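
For context, the flags parsed in the __main__ block above imply a launch command along the following lines. This is only a hedged sketch for orientation: the Vertex AI launcher composes the real command, the URI values and job name here are placeholders, and because the entrypoint initializes a gloo process group with the default env:// method, the usual MASTER_ADDR / MASTER_PORT / RANK / WORLD_SIZE environment variables must also be present.

import subprocess

# Hypothetical invocation mirroring the required argparse flags above.
# The URIs and job name are illustrative placeholders, not real artifacts.
cmd = [
    "python", "-m", "examples.link_prediction.graph_store.storage_main",
    "--job_name=my_job",
    "--task_config_uri=gs://my-bucket/frozen_task_config.yaml",
    "--resource_config_uri=gs://my-bucket/resource_config.yaml",
    "--is_inference=True",  # appended from storageArgs in the YAML config above
]
subprocess.run(cmd, check=True)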
19 changes: 19 additions & 0 deletions proto/snapchat/research/gbml/gbml_config.proto
@@ -173,6 +173,17 @@ message GbmlConfig {
}
}

// Configuration for GraphStore storage.
message GraphStoreStorageConfig {
// Command to use for launching the storage job.
// e.g. "python -m gigl.distributed.graph_store.storage_main".
string storage_command = 1;
// Arguments to instantiate the concrete BaseStorage instance with.
// These will be appended to storage_command.
// e.g. {"my_dataset_uri": "gs://my_dataset"} -> "python -m gigl.distributed.graph_store.storage_main --my_dataset_uri=gs://my_dataset"
map<string, string> storage_args = 2;
}
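
For illustration, the append behavior described above amounts to roughly the following. This is a hedged sketch, not the launcher's actual implementation, and the helper name is hypothetical:

from typing import Mapping

def append_storage_args(storage_command: str, storage_args: Mapping[str, str]) -> str:
    # Render each storage_args entry as a --key=value flag and append it to
    # storage_command, matching the example in the storage_args comment above.
    flags = " ".join(f"--{key}={value}" for key, value in storage_args.items())
    return f"{storage_command} {flags}" if flags else storage_command

# -> "python -m gigl.distributed.graph_store.storage_main --my_dataset_uri=gs://my_dataset"
print(append_storage_args(
    "python -m gigl.distributed.graph_store.storage_main",
    {"my_dataset_uri": "gs://my_dataset"},
))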

message TrainerConfig {
// (deprecated)
// Uri pointing to user-written BaseTrainer class definition. Used for the subgraph-sampling-based training process.
@@ -188,6 +199,10 @@
}
// Whether to log to tensorboard or not (defaults to false)
bool should_log_to_tensorboard = 12;

// Configuration for GraphStore storage.
// If set, then GiGLResourceConfig.trainer_resource_config.vertex_ai_graph_store_trainer_config must be set.
GraphStoreStorageConfig graph_store_storage_config = 13;
}

message InferencerConfig {
@@ -205,6 +220,10 @@
// Optional. If set, will be used to batch inference samples to a specific size before call for inference is made
// Defaults to setting in python/gigl/src/inference/gnn_inferencer.py
uint32 inference_batch_size = 5;

// Configuration for GraphStore storage.
// If set, then GiGLResourceConfig.inferencer_resource_config.vertex_ai_graph_store_inferencer_config must be set.
GraphStoreStorageConfig graph_store_storage_config = 6;
}

message PostProcessorConfig {
2 changes: 2 additions & 0 deletions python/gigl/distributed/graph_store/storage_main.py
@@ -2,6 +2,8 @@

Derived from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/examples/distributed/server_client_mode/sage_supervised_server.py

TODO(kmonte): Remove this, and only expose utils.
We keep this around so we can use the utils in tests/integration/distributed/graph_store/graph_store_integration_test.py.
"""
import argparse
import os
11 changes: 6 additions & 5 deletions python/gigl/src/common/vertex_ai_launcher.py
@@ -98,6 +98,8 @@ def launch_graph_store_enabled_job(
resource_config_uri: Uri,
process_command: str,
process_runtime_args: Mapping[str, str],
storage_command: str,
storage_args: Mapping[str, str],
resource_config_wrapper: GiglResourceConfigWrapper,
cpu_docker_uri: Optional[str],
cuda_docker_uri: Optional[str],
@@ -112,6 +114,8 @@
resource_config_uri: URI to the resource configuration
process_command: Command to run in the compute container
process_runtime_args: Runtime arguments for the process
storage_command: Command to run in the storage container
storage_args: Arguments to pass to the storage command
resource_config_wrapper: Wrapper for the resource configuration
cpu_docker_uri: Docker image URI for CPU execution
cuda_docker_uri: Docker image URI for GPU execution
@@ -173,8 +177,8 @@
job_name=job_name,
task_config_uri=task_config_uri,
resource_config_uri=resource_config_uri,
command_str=f"python -m gigl.distributed.graph_store.storage_main",
args={}, # No extra args for storage pool
command_str=storage_command,
args=storage_args,
use_cuda=is_cpu_execution,
container_uri=container_uri,
vertex_ai_resource_config=storage_pool_config,
@@ -272,9 +276,6 @@ def _build_job_config(
)
if vertex_ai_resource_config.scheduling_strategy
else None,
boot_disk_size_gb=vertex_ai_resource_config.boot_disk_size_gb
if vertex_ai_resource_config.boot_disk_size_gb
else 100, # Default to 100 GB for backward compatibility
)
return job_config

2 changes: 2 additions & 0 deletions python/gigl/src/inference/v2/glt_inferencer.py
@@ -83,6 +83,8 @@ def __execute_VAI_inference(
resource_config_uri=resource_config_uri,
process_command=inference_process_command,
process_runtime_args=inference_process_runtime_args,
storage_command=gbml_config_pb_wrapper.inferencer_config.graph_store_storage_config.storage_command,
storage_args=gbml_config_pb_wrapper.inferencer_config.graph_store_storage_config.storage_args,
resource_config_wrapper=resource_config_wrapper,
cpu_docker_uri=cpu_docker_uri,
cuda_docker_uri=cuda_docker_uri,
2 changes: 2 additions & 0 deletions python/gigl/src/training/v2/glt_trainer.py
@@ -79,6 +79,8 @@ def __execute_VAI_training(
resource_config_uri=resource_config_uri,
process_command=training_process_command,
process_runtime_args=training_process_runtime_args,
storage_command=gbml_config_pb_wrapper.trainer_config.graph_store_storage_config.storage_command,
storage_args=gbml_config_pb_wrapper.trainer_config.graph_store_storage_config.storage_args,
resource_config_wrapper=resource_config,
cpu_docker_uri=cpu_docker_uri,
cuda_docker_uri=cuda_docker_uri,
86 changes: 86 additions & 0 deletions python/gigl/src/validation_check/config_validator.py
@@ -15,6 +15,10 @@
assert_subgraph_sampler_output_exists,
assert_trained_model_exists,
)
from gigl.src.validation_check.libs.gbml_and_resource_config_compatibility_checks import (
check_inferencer_graph_store_compatibility,
check_trainer_graph_store_compatibility,
)
from gigl.src.validation_check.libs.name_checks import (
check_if_kfp_pipeline_job_name_valid,
)
@@ -191,6 +195,79 @@

logger = Logger()

# Map of start components to graph store compatibility checks to run
# Only run trainer checks when starting at or before Trainer
# Only run inferencer checks when starting at or before Inferencer
START_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS = {
GiGLComponents.ConfigPopulator.value: [
check_trainer_graph_store_compatibility,
check_inferencer_graph_store_compatibility,
],
GiGLComponents.DataPreprocessor.value: [
check_trainer_graph_store_compatibility,
check_inferencer_graph_store_compatibility,
],
GiGLComponents.SubgraphSampler.value: [
check_trainer_graph_store_compatibility,
check_inferencer_graph_store_compatibility,
],
GiGLComponents.SplitGenerator.value: [
check_trainer_graph_store_compatibility,
check_inferencer_graph_store_compatibility,
],
GiGLComponents.Trainer.value: [
check_trainer_graph_store_compatibility,
check_inferencer_graph_store_compatibility,
],
GiGLComponents.Inferencer.value: [
check_inferencer_graph_store_compatibility,
],
# PostProcessor doesn't need graph store compatibility checks
}

# Map of stop components to graph store compatibility checks to skip when stopping after that component
STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP = {
GiGLComponents.Trainer.value: [
check_inferencer_graph_store_compatibility,
],
}


def _run_gbml_and_resource_config_compatibility_checks(
start_at: str,
stop_after: Optional[str],
gbml_config_pb_wrapper: GbmlConfigPbWrapper,
resource_config_wrapper: GiglResourceConfigWrapper,
) -> None:
"""
Run compatibility checks between GbmlConfig and GiglResourceConfig.

These checks verify that graph store mode configurations are consistent
across both the template config (GbmlConfig) and resource config (GiglResourceConfig).

Args:
start_at: The component to start at.
stop_after: Optional component to stop after.
gbml_config_pb_wrapper: The GbmlConfig wrapper (template config).
resource_config_wrapper: The GiglResourceConfig wrapper (resource config).
"""
# Get the appropriate compatibility checks based on start/stop components
compatibility_checks = set(
START_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS.get(start_at, [])
)
if stop_after in STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP:
for skipped_check in STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP[
stop_after
]:
compatibility_checks.discard(skipped_check)

for check in compatibility_checks:
check(
gbml_config_pb_wrapper=gbml_config_pb_wrapper,
resource_config_wrapper=resource_config_wrapper,
)


def kfp_validation_checks(
job_name: str,
@@ -261,6 +338,15 @@ def kfp_validation_checks(
f"Skipping resource config check {resource_config_check.__name__} because we are using live subgraph sampling backend."
)

# check compatibility between template config and resource config for graph store mode
# These checks ensure that if graph store mode is enabled in one config, it's also enabled in the other
_run_gbml_and_resource_config_compatibility_checks(
start_at=start_at,
stop_after=stop_after,
gbml_config_pb_wrapper=gbml_config_pb_wrapper,
resource_config_wrapper=resource_config_wrapper,
)

# check if trained model file exists when skipping training
if gbml_config_pb.shared_config.should_skip_training == True:
assert_trained_model_exists(gbml_config_pb=gbml_config_pb)
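To make the start/stop resolution above concrete, here is a hypothetical walk-through. It assumes the two maps are importable from gigl.src.validation_check.config_validator (the module modified in this PR); the component value strings come from the GiGLComponents enum and are not spelled out here:

from typing import Optional

from gigl.src.validation_check.config_validator import (
    START_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS,
    STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP,
)

def resolve_graph_store_checks(start_at: str, stop_after: Optional[str]) -> set:
    # Mirrors _run_gbml_and_resource_config_compatibility_checks: the start
    # component selects the candidate checks, then the stop component may
    # discard some of them.
    checks = set(START_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS.get(start_at, []))
    for skipped in STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP.get(stop_after, []):
        checks.discard(skipped)
    return checks

# e.g. starting at Trainer and stopping after Trainer leaves only
# check_trainer_graph_store_compatibility; starting at Inferencer (or later)
# never includes the trainer check.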