From e5656bbdfa41024c03cef7d9d877971ef47d04d2 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 22 Jan 2026 18:19:41 +0000 Subject: [PATCH 1/4] Add defs for custom storage main --- .../snapchat/research/gbml/gbml_config.proto | 19 + .../snapchat/research/gbml/gbml_config_pb2.py | 72 ++-- .../research/gbml/gbml_config_pb2.pyi | 62 ++- .../gbml/gbml_config/GbmlConfig.scala | 398 +++++++++++++++++- .../gbml/gbml_config/GbmlConfigProto.scala | 70 +-- .../gbml/gbml_config/GbmlConfig.scala | 398 +++++++++++++++++- .../gbml/gbml_config/GbmlConfigProto.scala | 70 +-- 7 files changed, 967 insertions(+), 122 deletions(-) diff --git a/proto/snapchat/research/gbml/gbml_config.proto b/proto/snapchat/research/gbml/gbml_config.proto index f06bd3d45..fc92e1afe 100644 --- a/proto/snapchat/research/gbml/gbml_config.proto +++ b/proto/snapchat/research/gbml/gbml_config.proto @@ -173,6 +173,17 @@ message GbmlConfig { } } + // Configuration for GraphStore storage. + message GraphStoreStorageConfig { + // Command to use for launching storage job. + // e.g. "python -m gigl.distributed.graph_store.storage_main". + string storage_command = 1; + // Arguments to instantiate concrete BaseStorage instance with. + // Will be appended to the storage_command. + // e.g. {"my_dataset_uri": "gs://my_dataset"} -> "python -m gigl.distributed.graph_store.storage_main --my_dataset_uri=gs://my_dataset" + map<string, string> storage_args = 2; + } + message TrainerConfig { // (deprecated) // Uri pointing to user-written BaseTrainer class definition. Used for the subgraph-sampling-based training process. @@ -188,6 +199,10 @@ } // Weather to log to tensorboard or not (defaults to false) bool should_log_to_tensorboard = 12; + + // Configuration for GraphStore storage. + // If setup, then GiGLResourceConfig.trainer_resource_config.vertex_ai_graph_store_trainer_config must be set. 
+ GraphStoreStorageConfig graph_store_storage_config = 13; } message InferencerConfig { @@ -205,6 +220,10 @@ message GbmlConfig { // Optional. If set, will be used to batch inference samples to a specific size before call for inference is made // Defaults to setting in python/gigl/src/inference/gnn_inferencer.py uint32 inference_batch_size = 5; + + // Configuration for GraphStore storage. + // If setup, then GiGLResourceConfig.inferencer_resource_config.vertex_ai_graph_store_inferencer_config must be set. + GraphStoreStorageConfig graph_store_storage_config = 6; } message PostProcessorConfig { diff --git a/python/snapchat/research/gbml/gbml_config_pb2.py b/python/snapchat/research/gbml/gbml_config_pb2.py index 5188f1f4d..6f983bfa4 100644 --- a/python/snapchat/research/gbml/gbml_config_pb2.py +++ b/python/snapchat/research/gbml/gbml_config_pb2.py @@ -21,7 +21,7 @@ from snapchat.research.gbml import subgraph_sampling_strategy_pb2 as snapchat_dot_research_dot_gbml_dot_subgraph__sampling__strategy__pb2 -DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n(snapchat/research/gbml/gbml_config.proto\x12\x16snapchat.research.gbml\x1a)snapchat/research/gbml/graph_schema.proto\x1a\x35snapchat/research/gbml/flattened_graph_metadata.proto\x1a-snapchat/research/gbml/dataset_metadata.proto\x1a\x33snapchat/research/gbml/trained_model_metadata.proto\x1a/snapchat/research/gbml/inference_metadata.proto\x1a\x33snapchat/research/gbml/postprocessed_metadata.proto\x1a\x37snapchat/research/gbml/subgraph_sampling_strategy.proto\"\xe7+\n\nGbmlConfig\x12\x46\n\rtask_metadata\x18\x01 \x01(\x0b\x32/.snapchat.research.gbml.GbmlConfig.TaskMetadata\x12=\n\x0egraph_metadata\x18\x02 \x01(\x0b\x32%.snapchat.research.gbml.GraphMetadata\x12\x46\n\rshared_config\x18\x03 \x01(\x0b\x32/.snapchat.research.gbml.GbmlConfig.SharedConfig\x12H\n\x0e\x64\x61taset_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.GbmlConfig.DatasetConfig\x12H\n\x0etrainer_config\x18\x05 
\x01(\x0b\x32\x30.snapchat.research.gbml.GbmlConfig.TrainerConfig\x12N\n\x11inferencer_config\x18\x06 \x01(\x0b\x32\x33.snapchat.research.gbml.GbmlConfig.InferencerConfig\x12U\n\x15post_processor_config\x18\t \x01(\x0b\x32\x36.snapchat.research.gbml.GbmlConfig.PostProcessorConfig\x12H\n\x0emetrics_config\x18\x07 \x01(\x0b\x32\x30.snapchat.research.gbml.GbmlConfig.MetricsConfig\x12J\n\x0fprofiler_config\x18\x08 \x01(\x0b\x32\x31.snapchat.research.gbml.GbmlConfig.ProfilerConfig\x12K\n\rfeature_flags\x18\n \x03(\x0b\x32\x34.snapchat.research.gbml.GbmlConfig.FeatureFlagsEntry\x1a\x8f\x05\n\x0cTaskMetadata\x12i\n\x18node_based_task_metadata\x18\x01 \x01(\x0b\x32\x45.snapchat.research.gbml.GbmlConfig.TaskMetadata.NodeBasedTaskMetadataH\x00\x12\x94\x01\n/node_anchor_based_link_prediction_task_metadata\x18\x02 \x01(\x0b\x32Y.snapchat.research.gbml.GbmlConfig.TaskMetadata.NodeAnchorBasedLinkPredictionTaskMetadataH\x00\x12i\n\x18link_based_task_metadata\x18\x03 \x01(\x0b\x32\x45.snapchat.research.gbml.GbmlConfig.TaskMetadata.LinkBasedTaskMetadataH\x00\x1a\x37\n\x15NodeBasedTaskMetadata\x12\x1e\n\x16supervision_node_types\x18\x01 \x03(\t\x1am\n)NodeAnchorBasedLinkPredictionTaskMetadata\x12@\n\x16supervision_edge_types\x18\x01 \x03(\x0b\x32 .snapchat.research.gbml.EdgeType\x1aY\n\x15LinkBasedTaskMetadata\x12@\n\x16supervision_edge_types\x18\x01 \x03(\x0b\x32 .snapchat.research.gbml.EdgeTypeB\x0f\n\rtask_metadata\x1a\x96\x06\n\x0cSharedConfig\x12!\n\x19preprocessed_metadata_uri\x18\x01 \x01(\t\x12P\n\x18\x66lattened_graph_metadata\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.FlattenedGraphMetadata\x12\x41\n\x10\x64\x61taset_metadata\x18\x03 \x01(\x0b\x32\'.snapchat.research.gbml.DatasetMetadata\x12L\n\x16trained_model_metadata\x18\x04 \x01(\x0b\x32,.snapchat.research.gbml.TrainedModelMetadata\x12\x45\n\x12inference_metadata\x18\x05 \x01(\x0b\x32).snapchat.research.gbml.InferenceMetadata\x12M\n\x16postprocessed_metadata\x18\x0c 
\x01(\x0b\x32-.snapchat.research.gbml.PostProcessedMetadata\x12T\n\x0bshared_args\x18\x06 \x03(\x0b\x32?.snapchat.research.gbml.GbmlConfig.SharedConfig.SharedArgsEntry\x12\x19\n\x11is_graph_directed\x18\x07 \x01(\x08\x12\x1c\n\x14should_skip_training\x18\x08 \x01(\x08\x12\x30\n(should_skip_automatic_temp_asset_cleanup\x18\t \x01(\x08\x12\x1d\n\x15should_skip_inference\x18\n \x01(\x08\x12$\n\x1cshould_skip_model_evaluation\x18\x0b \x01(\x08\x12\x31\n)should_include_isolated_nodes_in_training\x18\r \x01(\x08\x1a\x31\n\x0fSharedArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xd3\x0c\n\rDatasetConfig\x12i\n\x18\x64\x61ta_preprocessor_config\x18\x01 \x01(\x0b\x32G.snapchat.research.gbml.GbmlConfig.DatasetConfig.DataPreprocessorConfig\x12g\n\x17subgraph_sampler_config\x18\x02 \x01(\x0b\x32\x46.snapchat.research.gbml.GbmlConfig.DatasetConfig.SubgraphSamplerConfig\x12\x65\n\x16split_generator_config\x18\x03 \x01(\x0b\x32\x45.snapchat.research.gbml.GbmlConfig.DatasetConfig.SplitGeneratorConfig\x1a\x84\x02\n\x16\x44\x61taPreprocessorConfig\x12)\n!data_preprocessor_config_cls_path\x18\x01 \x01(\t\x12\x81\x01\n\x16\x64\x61ta_preprocessor_args\x18\x02 \x03(\x0b\x32\x61.snapchat.research.gbml.GbmlConfig.DatasetConfig.DataPreprocessorConfig.DataPreprocessorArgsEntry\x1a;\n\x19\x44\x61taPreprocessorArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xd0\x04\n\x15SubgraphSamplerConfig\x12\x14\n\x08num_hops\x18\x01 \x01(\rB\x02\x18\x01\x12#\n\x17num_neighbors_to_sample\x18\x02 \x01(\x05\x42\x02\x18\x01\x12T\n\x1asubgraph_sampling_strategy\x18\n \x01(\x0b\x32\x30.snapchat.research.gbml.SubgraphSamplingStrategy\x12\x1c\n\x14num_positive_samples\x18\x03 \x01(\r\x12y\n\x12\x65xperimental_flags\x18\x05 \x03(\x0b\x32].snapchat.research.gbml.GbmlConfig.DatasetConfig.SubgraphSamplerConfig.ExperimentalFlagsEntry\x12*\n\"num_max_training_samples_to_output\x18\x06 
\x01(\r\x12-\n!num_user_defined_positive_samples\x18\x07 \x01(\rB\x02\x18\x01\x12-\n!num_user_defined_negative_samples\x18\x08 \x01(\rB\x02\x18\x01\x12I\n\x0fgraph_db_config\x18\t \x01(\x0b\x32\x30.snapchat.research.gbml.GbmlConfig.GraphDBConfig\x1a\x38\n\x16\x45xperimentalFlagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xac\x03\n\x14SplitGeneratorConfig\x12\x1f\n\x17split_strategy_cls_path\x18\x01 \x01(\t\x12y\n\x13split_strategy_args\x18\x02 \x03(\x0b\x32\\.snapchat.research.gbml.GbmlConfig.DatasetConfig.SplitGeneratorConfig.SplitStrategyArgsEntry\x12\x19\n\x11\x61ssigner_cls_path\x18\x03 \x01(\t\x12n\n\rassigner_args\x18\x04 \x03(\x0b\x32W.snapchat.research.gbml.GbmlConfig.DatasetConfig.SplitGeneratorConfig.AssignerArgsEntry\x1a\x38\n\x16SplitStrategyArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x33\n\x11\x41ssignerArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x90\x04\n\rGraphDBConfig\x12#\n\x1bgraph_db_ingestion_cls_path\x18\x01 \x01(\t\x12k\n\x17graph_db_ingestion_args\x18\x02 \x03(\x0b\x32J.snapchat.research.gbml.GbmlConfig.GraphDBConfig.GraphDbIngestionArgsEntry\x12X\n\rgraph_db_args\x18\x03 \x03(\x0b\x32\x41.snapchat.research.gbml.GbmlConfig.GraphDBConfig.GraphDbArgsEntry\x12\x66\n\x17graph_db_sampler_config\x18\x04 \x01(\x0b\x32\x45.snapchat.research.gbml.GbmlConfig.GraphDBConfig.GraphDBServiceConfig\x1a;\n\x19GraphDbIngestionArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x32\n\x10GraphDbArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a:\n\x14GraphDBServiceConfig\x12\"\n\x1agraph_db_client_class_path\x18\x01 \x01(\t\x1a\x8e\x02\n\rTrainerConfig\x12\x18\n\x10trainer_cls_path\x18\x01 \x01(\t\x12W\n\x0ctrainer_args\x18\x02 
\x03(\x0b\x32\x41.snapchat.research.gbml.GbmlConfig.TrainerConfig.TrainerArgsEntry\x12\x12\n\x08\x63ls_path\x18\x64 \x01(\tH\x00\x12\x11\n\x07\x63ommand\x18\x65 \x01(\tH\x00\x12!\n\x19should_log_to_tensorboard\x18\x0c \x01(\x08\x1a\x32\n\x10TrainerArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\x0c\n\nexecutable\x1a\x9b\x02\n\x10InferencerConfig\x12`\n\x0finferencer_args\x18\x01 \x03(\x0b\x32G.snapchat.research.gbml.GbmlConfig.InferencerConfig.InferencerArgsEntry\x12\x1b\n\x13inferencer_cls_path\x18\x02 \x01(\t\x12\x12\n\x08\x63ls_path\x18\x64 \x01(\tH\x00\x12\x11\n\x07\x63ommand\x18\x65 \x01(\tH\x00\x12\x1c\n\x14inference_batch_size\x18\x05 \x01(\r\x1a\x35\n\x13InferencerArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\x0c\n\nexecutable\x1a\xdc\x01\n\x13PostProcessorConfig\x12j\n\x13post_processor_args\x18\x01 \x03(\x0b\x32M.snapchat.research.gbml.GbmlConfig.PostProcessorConfig.PostProcessorArgsEntry\x12\x1f\n\x17post_processor_cls_path\x18\x02 \x01(\t\x1a\x38\n\x16PostProcessorArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xb6\x01\n\rMetricsConfig\x12\x18\n\x10metrics_cls_path\x18\x01 \x01(\t\x12W\n\x0cmetrics_args\x18\x02 \x03(\x0b\x32\x41.snapchat.research.gbml.GbmlConfig.MetricsConfig.MetricsArgsEntry\x1a\x32\n\x10MetricsArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xdb\x01\n\x0eProfilerConfig\x12\x1e\n\x16should_enable_profiler\x18\x01 \x01(\x08\x12\x18\n\x10profiler_log_dir\x18\x02 \x01(\t\x12Z\n\rprofiler_args\x18\x03 \x03(\x0b\x32\x43.snapchat.research.gbml.GbmlConfig.ProfilerConfig.ProfilerArgsEntry\x1a\x33\n\x11ProfilerArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x33\n\x11\x46\x65\x61tureFlagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x62\x06proto3') +DESCRIPTOR = 
_descriptor_pool.Default().AddSerializedFile(b'\n(snapchat/research/gbml/gbml_config.proto\x12\x16snapchat.research.gbml\x1a)snapchat/research/gbml/graph_schema.proto\x1a\x35snapchat/research/gbml/flattened_graph_metadata.proto\x1a-snapchat/research/gbml/dataset_metadata.proto\x1a\x33snapchat/research/gbml/trained_model_metadata.proto\x1a/snapchat/research/gbml/inference_metadata.proto\x1a\x33snapchat/research/gbml/postprocessed_metadata.proto\x1a\x37snapchat/research/gbml/subgraph_sampling_strategy.proto\"\xf3.\n\nGbmlConfig\x12\x46\n\rtask_metadata\x18\x01 \x01(\x0b\x32/.snapchat.research.gbml.GbmlConfig.TaskMetadata\x12=\n\x0egraph_metadata\x18\x02 \x01(\x0b\x32%.snapchat.research.gbml.GraphMetadata\x12\x46\n\rshared_config\x18\x03 \x01(\x0b\x32/.snapchat.research.gbml.GbmlConfig.SharedConfig\x12H\n\x0e\x64\x61taset_config\x18\x04 \x01(\x0b\x32\x30.snapchat.research.gbml.GbmlConfig.DatasetConfig\x12H\n\x0etrainer_config\x18\x05 \x01(\x0b\x32\x30.snapchat.research.gbml.GbmlConfig.TrainerConfig\x12N\n\x11inferencer_config\x18\x06 \x01(\x0b\x32\x33.snapchat.research.gbml.GbmlConfig.InferencerConfig\x12U\n\x15post_processor_config\x18\t \x01(\x0b\x32\x36.snapchat.research.gbml.GbmlConfig.PostProcessorConfig\x12H\n\x0emetrics_config\x18\x07 \x01(\x0b\x32\x30.snapchat.research.gbml.GbmlConfig.MetricsConfig\x12J\n\x0fprofiler_config\x18\x08 \x01(\x0b\x32\x31.snapchat.research.gbml.GbmlConfig.ProfilerConfig\x12K\n\rfeature_flags\x18\n \x03(\x0b\x32\x34.snapchat.research.gbml.GbmlConfig.FeatureFlagsEntry\x1a\x8f\x05\n\x0cTaskMetadata\x12i\n\x18node_based_task_metadata\x18\x01 \x01(\x0b\x32\x45.snapchat.research.gbml.GbmlConfig.TaskMetadata.NodeBasedTaskMetadataH\x00\x12\x94\x01\n/node_anchor_based_link_prediction_task_metadata\x18\x02 \x01(\x0b\x32Y.snapchat.research.gbml.GbmlConfig.TaskMetadata.NodeAnchorBasedLinkPredictionTaskMetadataH\x00\x12i\n\x18link_based_task_metadata\x18\x03 
\x01(\x0b\x32\x45.snapchat.research.gbml.GbmlConfig.TaskMetadata.LinkBasedTaskMetadataH\x00\x1a\x37\n\x15NodeBasedTaskMetadata\x12\x1e\n\x16supervision_node_types\x18\x01 \x03(\t\x1am\n)NodeAnchorBasedLinkPredictionTaskMetadata\x12@\n\x16supervision_edge_types\x18\x01 \x03(\x0b\x32 .snapchat.research.gbml.EdgeType\x1aY\n\x15LinkBasedTaskMetadata\x12@\n\x16supervision_edge_types\x18\x01 \x03(\x0b\x32 .snapchat.research.gbml.EdgeTypeB\x0f\n\rtask_metadata\x1a\x96\x06\n\x0cSharedConfig\x12!\n\x19preprocessed_metadata_uri\x18\x01 \x01(\t\x12P\n\x18\x66lattened_graph_metadata\x18\x02 \x01(\x0b\x32..snapchat.research.gbml.FlattenedGraphMetadata\x12\x41\n\x10\x64\x61taset_metadata\x18\x03 \x01(\x0b\x32\'.snapchat.research.gbml.DatasetMetadata\x12L\n\x16trained_model_metadata\x18\x04 \x01(\x0b\x32,.snapchat.research.gbml.TrainedModelMetadata\x12\x45\n\x12inference_metadata\x18\x05 \x01(\x0b\x32).snapchat.research.gbml.InferenceMetadata\x12M\n\x16postprocessed_metadata\x18\x0c \x01(\x0b\x32-.snapchat.research.gbml.PostProcessedMetadata\x12T\n\x0bshared_args\x18\x06 \x03(\x0b\x32?.snapchat.research.gbml.GbmlConfig.SharedConfig.SharedArgsEntry\x12\x19\n\x11is_graph_directed\x18\x07 \x01(\x08\x12\x1c\n\x14should_skip_training\x18\x08 \x01(\x08\x12\x30\n(should_skip_automatic_temp_asset_cleanup\x18\t \x01(\x08\x12\x1d\n\x15should_skip_inference\x18\n \x01(\x08\x12$\n\x1cshould_skip_model_evaluation\x18\x0b \x01(\x08\x12\x31\n)should_include_isolated_nodes_in_training\x18\r \x01(\x08\x1a\x31\n\x0fSharedArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xd3\x0c\n\rDatasetConfig\x12i\n\x18\x64\x61ta_preprocessor_config\x18\x01 \x01(\x0b\x32G.snapchat.research.gbml.GbmlConfig.DatasetConfig.DataPreprocessorConfig\x12g\n\x17subgraph_sampler_config\x18\x02 \x01(\x0b\x32\x46.snapchat.research.gbml.GbmlConfig.DatasetConfig.SubgraphSamplerConfig\x12\x65\n\x16split_generator_config\x18\x03 
\x01(\x0b\x32\x45.snapchat.research.gbml.GbmlConfig.DatasetConfig.SplitGeneratorConfig\x1a\x84\x02\n\x16\x44\x61taPreprocessorConfig\x12)\n!data_preprocessor_config_cls_path\x18\x01 \x01(\t\x12\x81\x01\n\x16\x64\x61ta_preprocessor_args\x18\x02 \x03(\x0b\x32\x61.snapchat.research.gbml.GbmlConfig.DatasetConfig.DataPreprocessorConfig.DataPreprocessorArgsEntry\x1a;\n\x19\x44\x61taPreprocessorArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xd0\x04\n\x15SubgraphSamplerConfig\x12\x14\n\x08num_hops\x18\x01 \x01(\rB\x02\x18\x01\x12#\n\x17num_neighbors_to_sample\x18\x02 \x01(\x05\x42\x02\x18\x01\x12T\n\x1asubgraph_sampling_strategy\x18\n \x01(\x0b\x32\x30.snapchat.research.gbml.SubgraphSamplingStrategy\x12\x1c\n\x14num_positive_samples\x18\x03 \x01(\r\x12y\n\x12\x65xperimental_flags\x18\x05 \x03(\x0b\x32].snapchat.research.gbml.GbmlConfig.DatasetConfig.SubgraphSamplerConfig.ExperimentalFlagsEntry\x12*\n\"num_max_training_samples_to_output\x18\x06 \x01(\r\x12-\n!num_user_defined_positive_samples\x18\x07 \x01(\rB\x02\x18\x01\x12-\n!num_user_defined_negative_samples\x18\x08 \x01(\rB\x02\x18\x01\x12I\n\x0fgraph_db_config\x18\t \x01(\x0b\x32\x30.snapchat.research.gbml.GbmlConfig.GraphDBConfig\x1a\x38\n\x16\x45xperimentalFlagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xac\x03\n\x14SplitGeneratorConfig\x12\x1f\n\x17split_strategy_cls_path\x18\x01 \x01(\t\x12y\n\x13split_strategy_args\x18\x02 \x03(\x0b\x32\\.snapchat.research.gbml.GbmlConfig.DatasetConfig.SplitGeneratorConfig.SplitStrategyArgsEntry\x12\x19\n\x11\x61ssigner_cls_path\x18\x03 \x01(\t\x12n\n\rassigner_args\x18\x04 \x03(\x0b\x32W.snapchat.research.gbml.GbmlConfig.DatasetConfig.SplitGeneratorConfig.AssignerArgsEntry\x1a\x38\n\x16SplitStrategyArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x33\n\x11\x41ssignerArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 
\x01(\t:\x02\x38\x01\x1a\x90\x04\n\rGraphDBConfig\x12#\n\x1bgraph_db_ingestion_cls_path\x18\x01 \x01(\t\x12k\n\x17graph_db_ingestion_args\x18\x02 \x03(\x0b\x32J.snapchat.research.gbml.GbmlConfig.GraphDBConfig.GraphDbIngestionArgsEntry\x12X\n\rgraph_db_args\x18\x03 \x03(\x0b\x32\x41.snapchat.research.gbml.GbmlConfig.GraphDBConfig.GraphDbArgsEntry\x12\x66\n\x17graph_db_sampler_config\x18\x04 \x01(\x0b\x32\x45.snapchat.research.gbml.GbmlConfig.GraphDBConfig.GraphDBServiceConfig\x1a;\n\x19GraphDbIngestionArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x32\n\x10GraphDbArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a:\n\x14GraphDBServiceConfig\x12\"\n\x1agraph_db_client_class_path\x18\x01 \x01(\t\x1a\xc9\x01\n\x17GraphStoreStorageConfig\x12\x17\n\x0fstorage_command\x18\x01 \x01(\t\x12\x61\n\x0cstorage_args\x18\x02 \x03(\x0b\x32K.snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry\x1a\x32\n\x10StorageArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xee\x02\n\rTrainerConfig\x12\x18\n\x10trainer_cls_path\x18\x01 \x01(\t\x12W\n\x0ctrainer_args\x18\x02 \x03(\x0b\x32\x41.snapchat.research.gbml.GbmlConfig.TrainerConfig.TrainerArgsEntry\x12\x12\n\x08\x63ls_path\x18\x64 \x01(\tH\x00\x12\x11\n\x07\x63ommand\x18\x65 \x01(\tH\x00\x12!\n\x19should_log_to_tensorboard\x18\x0c \x01(\x08\x12^\n\x1agraph_store_storage_config\x18\r \x01(\x0b\x32:.snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig\x1a\x32\n\x10TrainerArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\x0c\n\nexecutable\x1a\xfb\x02\n\x10InferencerConfig\x12`\n\x0finferencer_args\x18\x01 \x03(\x0b\x32G.snapchat.research.gbml.GbmlConfig.InferencerConfig.InferencerArgsEntry\x12\x1b\n\x13inferencer_cls_path\x18\x02 \x01(\t\x12\x12\n\x08\x63ls_path\x18\x64 \x01(\tH\x00\x12\x11\n\x07\x63ommand\x18\x65 
\x01(\tH\x00\x12\x1c\n\x14inference_batch_size\x18\x05 \x01(\r\x12^\n\x1agraph_store_storage_config\x18\x06 \x01(\x0b\x32:.snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig\x1a\x35\n\x13InferencerArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x42\x0c\n\nexecutable\x1a\xdc\x01\n\x13PostProcessorConfig\x12j\n\x13post_processor_args\x18\x01 \x03(\x0b\x32M.snapchat.research.gbml.GbmlConfig.PostProcessorConfig.PostProcessorArgsEntry\x12\x1f\n\x17post_processor_cls_path\x18\x02 \x01(\t\x1a\x38\n\x16PostProcessorArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xb6\x01\n\rMetricsConfig\x12\x18\n\x10metrics_cls_path\x18\x01 \x01(\t\x12W\n\x0cmetrics_args\x18\x02 \x03(\x0b\x32\x41.snapchat.research.gbml.GbmlConfig.MetricsConfig.MetricsArgsEntry\x1a\x32\n\x10MetricsArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\xdb\x01\n\x0eProfilerConfig\x12\x1e\n\x16should_enable_profiler\x18\x01 \x01(\x08\x12\x18\n\x10profiler_log_dir\x18\x02 \x01(\t\x12Z\n\rprofiler_args\x18\x03 \x03(\x0b\x32\x43.snapchat.research.gbml.GbmlConfig.ProfilerConfig.ProfilerArgsEntry\x1a\x33\n\x11ProfilerArgsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x1a\x33\n\x11\x46\x65\x61tureFlagsEntry\x12\x0b\n\x03key\x18\x01 \x01(\t\x12\r\n\x05value\x18\x02 \x01(\t:\x02\x38\x01\x62\x06proto3') @@ -44,6 +44,8 @@ _GBMLCONFIG_GRAPHDBCONFIG_GRAPHDBINGESTIONARGSENTRY = _GBMLCONFIG_GRAPHDBCONFIG.nested_types_by_name['GraphDbIngestionArgsEntry'] _GBMLCONFIG_GRAPHDBCONFIG_GRAPHDBARGSENTRY = _GBMLCONFIG_GRAPHDBCONFIG.nested_types_by_name['GraphDbArgsEntry'] _GBMLCONFIG_GRAPHDBCONFIG_GRAPHDBSERVICECONFIG = _GBMLCONFIG_GRAPHDBCONFIG.nested_types_by_name['GraphDBServiceConfig'] +_GBMLCONFIG_GRAPHSTORESTORAGECONFIG = _GBMLCONFIG.nested_types_by_name['GraphStoreStorageConfig'] +_GBMLCONFIG_GRAPHSTORESTORAGECONFIG_STORAGEARGSENTRY = 
_GBMLCONFIG_GRAPHSTORESTORAGECONFIG.nested_types_by_name['StorageArgsEntry'] _GBMLCONFIG_TRAINERCONFIG = _GBMLCONFIG.nested_types_by_name['TrainerConfig'] _GBMLCONFIG_TRAINERCONFIG_TRAINERARGSENTRY = _GBMLCONFIG_TRAINERCONFIG.nested_types_by_name['TrainerArgsEntry'] _GBMLCONFIG_INFERENCERCONFIG = _GBMLCONFIG.nested_types_by_name['InferencerConfig'] @@ -183,6 +185,20 @@ }) , + 'GraphStoreStorageConfig' : _reflection.GeneratedProtocolMessageType('GraphStoreStorageConfig', (_message.Message,), { + + 'StorageArgsEntry' : _reflection.GeneratedProtocolMessageType('StorageArgsEntry', (_message.Message,), { + 'DESCRIPTOR' : _GBMLCONFIG_GRAPHSTORESTORAGECONFIG_STORAGEARGSENTRY, + '__module__' : 'snapchat.research.gbml.gbml_config_pb2' + # @@protoc_insertion_point(class_scope:snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry) + }) + , + 'DESCRIPTOR' : _GBMLCONFIG_GRAPHSTORESTORAGECONFIG, + '__module__' : 'snapchat.research.gbml.gbml_config_pb2' + # @@protoc_insertion_point(class_scope:snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig) + }) + , + 'TrainerConfig' : _reflection.GeneratedProtocolMessageType('TrainerConfig', (_message.Message,), { 'TrainerArgsEntry' : _reflection.GeneratedProtocolMessageType('TrainerArgsEntry', (_message.Message,), { @@ -282,6 +298,8 @@ _sym_db.RegisterMessage(GbmlConfig.GraphDBConfig.GraphDbIngestionArgsEntry) _sym_db.RegisterMessage(GbmlConfig.GraphDBConfig.GraphDbArgsEntry) _sym_db.RegisterMessage(GbmlConfig.GraphDBConfig.GraphDBServiceConfig) +_sym_db.RegisterMessage(GbmlConfig.GraphStoreStorageConfig) +_sym_db.RegisterMessage(GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry) _sym_db.RegisterMessage(GbmlConfig.TrainerConfig) _sym_db.RegisterMessage(GbmlConfig.TrainerConfig.TrainerArgsEntry) _sym_db.RegisterMessage(GbmlConfig.InferencerConfig) @@ -319,6 +337,8 @@ _GBMLCONFIG_GRAPHDBCONFIG_GRAPHDBINGESTIONARGSENTRY._serialized_options = b'8\001' _GBMLCONFIG_GRAPHDBCONFIG_GRAPHDBARGSENTRY._options = None 
_GBMLCONFIG_GRAPHDBCONFIG_GRAPHDBARGSENTRY._serialized_options = b'8\001' + _GBMLCONFIG_GRAPHSTORESTORAGECONFIG_STORAGEARGSENTRY._options = None + _GBMLCONFIG_GRAPHSTORESTORAGECONFIG_STORAGEARGSENTRY._serialized_options = b'8\001' _GBMLCONFIG_TRAINERCONFIG_TRAINERARGSENTRY._options = None _GBMLCONFIG_TRAINERCONFIG_TRAINERARGSENTRY._serialized_options = b'8\001' _GBMLCONFIG_INFERENCERCONFIG_INFERENCERARGSENTRY._options = None @@ -332,7 +352,7 @@ _GBMLCONFIG_FEATUREFLAGSENTRY._options = None _GBMLCONFIG_FEATUREFLAGSENTRY._serialized_options = b'8\001' _GBMLCONFIG._serialized_start=426 - _GBMLCONFIG._serialized_end=6033 + _GBMLCONFIG._serialized_end=6429 _GBMLCONFIG_TASKMETADATA._serialized_start=1190 _GBMLCONFIG_TASKMETADATA._serialized_end=1845 _GBMLCONFIG_TASKMETADATA_NODEBASEDTASKMETADATA._serialized_start=1571 @@ -369,26 +389,30 @@ _GBMLCONFIG_GRAPHDBCONFIG_GRAPHDBARGSENTRY._serialized_end=4731 _GBMLCONFIG_GRAPHDBCONFIG_GRAPHDBSERVICECONFIG._serialized_start=4733 _GBMLCONFIG_GRAPHDBCONFIG_GRAPHDBSERVICECONFIG._serialized_end=4791 - _GBMLCONFIG_TRAINERCONFIG._serialized_start=4794 - _GBMLCONFIG_TRAINERCONFIG._serialized_end=5064 - _GBMLCONFIG_TRAINERCONFIG_TRAINERARGSENTRY._serialized_start=5000 - _GBMLCONFIG_TRAINERCONFIG_TRAINERARGSENTRY._serialized_end=5050 - _GBMLCONFIG_INFERENCERCONFIG._serialized_start=5067 - _GBMLCONFIG_INFERENCERCONFIG._serialized_end=5350 - _GBMLCONFIG_INFERENCERCONFIG_INFERENCERARGSENTRY._serialized_start=5283 - _GBMLCONFIG_INFERENCERCONFIG_INFERENCERARGSENTRY._serialized_end=5336 - _GBMLCONFIG_POSTPROCESSORCONFIG._serialized_start=5353 - _GBMLCONFIG_POSTPROCESSORCONFIG._serialized_end=5573 - _GBMLCONFIG_POSTPROCESSORCONFIG_POSTPROCESSORARGSENTRY._serialized_start=5517 - _GBMLCONFIG_POSTPROCESSORCONFIG_POSTPROCESSORARGSENTRY._serialized_end=5573 - _GBMLCONFIG_METRICSCONFIG._serialized_start=5576 - _GBMLCONFIG_METRICSCONFIG._serialized_end=5758 - _GBMLCONFIG_METRICSCONFIG_METRICSARGSENTRY._serialized_start=5708 - 
_GBMLCONFIG_METRICSCONFIG_METRICSARGSENTRY._serialized_end=5758 - _GBMLCONFIG_PROFILERCONFIG._serialized_start=5761 - _GBMLCONFIG_PROFILERCONFIG._serialized_end=5980 - _GBMLCONFIG_PROFILERCONFIG_PROFILERARGSENTRY._serialized_start=5929 - _GBMLCONFIG_PROFILERCONFIG_PROFILERARGSENTRY._serialized_end=5980 - _GBMLCONFIG_FEATUREFLAGSENTRY._serialized_start=5982 - _GBMLCONFIG_FEATUREFLAGSENTRY._serialized_end=6033 + _GBMLCONFIG_GRAPHSTORESTORAGECONFIG._serialized_start=4794 + _GBMLCONFIG_GRAPHSTORESTORAGECONFIG._serialized_end=4995 + _GBMLCONFIG_GRAPHSTORESTORAGECONFIG_STORAGEARGSENTRY._serialized_start=4945 + _GBMLCONFIG_GRAPHSTORESTORAGECONFIG_STORAGEARGSENTRY._serialized_end=4995 + _GBMLCONFIG_TRAINERCONFIG._serialized_start=4998 + _GBMLCONFIG_TRAINERCONFIG._serialized_end=5364 + _GBMLCONFIG_TRAINERCONFIG_TRAINERARGSENTRY._serialized_start=5300 + _GBMLCONFIG_TRAINERCONFIG_TRAINERARGSENTRY._serialized_end=5350 + _GBMLCONFIG_INFERENCERCONFIG._serialized_start=5367 + _GBMLCONFIG_INFERENCERCONFIG._serialized_end=5746 + _GBMLCONFIG_INFERENCERCONFIG_INFERENCERARGSENTRY._serialized_start=5679 + _GBMLCONFIG_INFERENCERCONFIG_INFERENCERARGSENTRY._serialized_end=5732 + _GBMLCONFIG_POSTPROCESSORCONFIG._serialized_start=5749 + _GBMLCONFIG_POSTPROCESSORCONFIG._serialized_end=5969 + _GBMLCONFIG_POSTPROCESSORCONFIG_POSTPROCESSORARGSENTRY._serialized_start=5913 + _GBMLCONFIG_POSTPROCESSORCONFIG_POSTPROCESSORARGSENTRY._serialized_end=5969 + _GBMLCONFIG_METRICSCONFIG._serialized_start=5972 + _GBMLCONFIG_METRICSCONFIG._serialized_end=6154 + _GBMLCONFIG_METRICSCONFIG_METRICSARGSENTRY._serialized_start=6104 + _GBMLCONFIG_METRICSCONFIG_METRICSARGSENTRY._serialized_end=6154 + _GBMLCONFIG_PROFILERCONFIG._serialized_start=6157 + _GBMLCONFIG_PROFILERCONFIG._serialized_end=6376 + _GBMLCONFIG_PROFILERCONFIG_PROFILERARGSENTRY._serialized_start=6325 + _GBMLCONFIG_PROFILERCONFIG_PROFILERARGSENTRY._serialized_end=6376 + _GBMLCONFIG_FEATUREFLAGSENTRY._serialized_start=6378 + 
_GBMLCONFIG_FEATUREFLAGSENTRY._serialized_end=6429 # @@protoc_insertion_point(module_scope) diff --git a/python/snapchat/research/gbml/gbml_config_pb2.pyi b/python/snapchat/research/gbml/gbml_config_pb2.pyi index 3ce465beb..fecb00bf2 100644 --- a/python/snapchat/research/gbml/gbml_config_pb2.pyi +++ b/python/snapchat/research/gbml/gbml_config_pb2.pyi @@ -477,6 +477,46 @@ class GbmlConfig(google.protobuf.message.Message): def HasField(self, field_name: typing_extensions.Literal["graph_db_sampler_config", b"graph_db_sampler_config"]) -> builtins.bool: ... def ClearField(self, field_name: typing_extensions.Literal["graph_db_args", b"graph_db_args", "graph_db_ingestion_args", b"graph_db_ingestion_args", "graph_db_ingestion_cls_path", b"graph_db_ingestion_cls_path", "graph_db_sampler_config", b"graph_db_sampler_config"]) -> None: ... + class GraphStoreStorageConfig(google.protobuf.message.Message): + """Configuration for GraphStore storage.""" + + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + class StorageArgsEntry(google.protobuf.message.Message): + DESCRIPTOR: google.protobuf.descriptor.Descriptor + + KEY_FIELD_NUMBER: builtins.int + VALUE_FIELD_NUMBER: builtins.int + key: builtins.str + value: builtins.str + def __init__( + self, + *, + key: builtins.str = ..., + value: builtins.str = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["key", b"key", "value", b"value"]) -> None: ... + + STORAGE_COMMAND_FIELD_NUMBER: builtins.int + STORAGE_ARGS_FIELD_NUMBER: builtins.int + storage_command: builtins.str + """Command to use for launching storage job. + e.g. "python -m gigl.distributed.graph_store.storage_main". + """ + @property + def storage_args(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: + """Arguments to instantiate concrete BaseStorage instance with. + Will be appended to the storage_command. + e.g. 
{"my_dataset_uri": "gs://my_dataset"} -> "python -m gigl.distributed.graph_store.storage_main --my_dataset_uri=gs://my_dataset" + """ + def __init__( + self, + *, + storage_command: builtins.str = ..., + storage_args: collections.abc.Mapping[builtins.str, builtins.str] | None = ..., + ) -> None: ... + def ClearField(self, field_name: typing_extensions.Literal["storage_args", b"storage_args", "storage_command", b"storage_command"]) -> None: ... + class TrainerConfig(google.protobuf.message.Message): DESCRIPTOR: google.protobuf.descriptor.Descriptor @@ -500,6 +540,7 @@ class GbmlConfig(google.protobuf.message.Message): CLS_PATH_FIELD_NUMBER: builtins.int COMMAND_FIELD_NUMBER: builtins.int SHOULD_LOG_TO_TENSORBOARD_FIELD_NUMBER: builtins.int + GRAPH_STORE_STORAGE_CONFIG_FIELD_NUMBER: builtins.int trainer_cls_path: builtins.str """(deprecated) Uri pointing to user-written BaseTrainer class definition. Used for the subgraph-sampling-based training process. @@ -513,6 +554,11 @@ class GbmlConfig(google.protobuf.message.Message): """Command to use for launching trainer job""" should_log_to_tensorboard: builtins.bool """Weather to log to tensorboard or not (defaults to false)""" + @property + def graph_store_storage_config(self) -> global___GbmlConfig.GraphStoreStorageConfig: + """Configuration for GraphStore storage. + If setup, then GiGLResourceConfig.trainer_resource_config.vertex_ai_graph_store_trainer_config must be set. + """ def __init__( self, *, @@ -521,9 +567,10 @@ class GbmlConfig(google.protobuf.message.Message): cls_path: builtins.str = ..., command: builtins.str = ..., should_log_to_tensorboard: builtins.bool = ..., + graph_store_storage_config: global___GbmlConfig.GraphStoreStorageConfig | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["cls_path", b"cls_path", "command", b"command", "executable", b"executable"]) -> builtins.bool: ... 
- def ClearField(self, field_name: typing_extensions.Literal["cls_path", b"cls_path", "command", b"command", "executable", b"executable", "should_log_to_tensorboard", b"should_log_to_tensorboard", "trainer_args", b"trainer_args", "trainer_cls_path", b"trainer_cls_path"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["cls_path", b"cls_path", "command", b"command", "executable", b"executable", "graph_store_storage_config", b"graph_store_storage_config"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["cls_path", b"cls_path", "command", b"command", "executable", b"executable", "graph_store_storage_config", b"graph_store_storage_config", "should_log_to_tensorboard", b"should_log_to_tensorboard", "trainer_args", b"trainer_args", "trainer_cls_path", b"trainer_cls_path"]) -> None: ... def WhichOneof(self, oneof_group: typing_extensions.Literal["executable", b"executable"]) -> typing_extensions.Literal["cls_path", "command"] | None: ... class InferencerConfig(google.protobuf.message.Message): @@ -549,6 +596,7 @@ class GbmlConfig(google.protobuf.message.Message): CLS_PATH_FIELD_NUMBER: builtins.int COMMAND_FIELD_NUMBER: builtins.int INFERENCE_BATCH_SIZE_FIELD_NUMBER: builtins.int + GRAPH_STORE_STORAGE_CONFIG_FIELD_NUMBER: builtins.int @property def inferencer_args(self) -> google.protobuf.internal.containers.ScalarMap[builtins.str, builtins.str]: ... inferencer_cls_path: builtins.str @@ -563,6 +611,11 @@ class GbmlConfig(google.protobuf.message.Message): """Optional. If set, will be used to batch inference samples to a specific size before call for inference is made Defaults to setting in python/gigl/src/inference/gnn_inferencer.py """ + @property + def graph_store_storage_config(self) -> global___GbmlConfig.GraphStoreStorageConfig: + """Configuration for GraphStore storage. + If setup, then GiGLResourceConfig.inferencer_resource_config.vertex_ai_graph_store_inferencer_config must be set. 
+ """ def __init__( self, *, @@ -571,9 +624,10 @@ class GbmlConfig(google.protobuf.message.Message): cls_path: builtins.str = ..., command: builtins.str = ..., inference_batch_size: builtins.int = ..., + graph_store_storage_config: global___GbmlConfig.GraphStoreStorageConfig | None = ..., ) -> None: ... - def HasField(self, field_name: typing_extensions.Literal["cls_path", b"cls_path", "command", b"command", "executable", b"executable"]) -> builtins.bool: ... - def ClearField(self, field_name: typing_extensions.Literal["cls_path", b"cls_path", "command", b"command", "executable", b"executable", "inference_batch_size", b"inference_batch_size", "inferencer_args", b"inferencer_args", "inferencer_cls_path", b"inferencer_cls_path"]) -> None: ... + def HasField(self, field_name: typing_extensions.Literal["cls_path", b"cls_path", "command", b"command", "executable", b"executable", "graph_store_storage_config", b"graph_store_storage_config"]) -> builtins.bool: ... + def ClearField(self, field_name: typing_extensions.Literal["cls_path", b"cls_path", "command", b"command", "executable", b"executable", "graph_store_storage_config", b"graph_store_storage_config", "inference_batch_size", b"inference_batch_size", "inferencer_args", b"inferencer_args", "inferencer_cls_path", b"inferencer_cls_path"]) -> None: ... def WhichOneof(self, oneof_group: typing_extensions.Literal["executable", b"executable"]) -> typing_extensions.Literal["cls_path", "command"] | None: ... 
class PostProcessorConfig(google.protobuf.message.Message): diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfig.scala b/scala/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfig.scala index b3f110643..c4c2ca892 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfig.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfig.scala @@ -306,6 +306,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb _root_.snapchat.research.gbml.gbml_config.GbmlConfig.SharedConfig, _root_.snapchat.research.gbml.gbml_config.GbmlConfig.DatasetConfig, _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphDBConfig, + _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig, _root_.snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig, _root_.snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig, _root_.snapchat.research.gbml.gbml_config.GbmlConfig.PostProcessorConfig, @@ -3654,6 +3655,309 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GbmlConfig.GraphDBConfig]) } + /** Configuration for GraphStore storage. + * + * @param storageCommand + * Command to use for launching storage job. + * e.g. "python -m gigl.distributed.graph_store.storage_main". + * @param storageArgs + * Arguments to instantiate concrete BaseStorage instance with. + * Will be appended to the storage_command. + * e.g. 
{"my_dataset_uri": "gs://my_dataset"} -> "python -m gigl.distributed.graph_store.storage_main --my_dataset_uri=gs://my_dataset" + */ + @SerialVersionUID(0L) + final case class GraphStoreStorageConfig( + storageCommand: _root_.scala.Predef.String = "", + storageArgs: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String] = _root_.scala.collection.immutable.Map.empty, + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[GraphStoreStorageConfig] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + + { + val __value = storageCommand + if (!__value.isEmpty) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) + } + }; + storageArgs.foreach { __item => + val __value = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toBase(__item) + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + } + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + { + val __v = storageCommand + if (!__v.isEmpty) { + _output__.writeString(1, __v) + } + }; + storageArgs.foreach { __v => + val __m = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toBase(__v) + _output__.writeTag(2, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; + unknownFields.writeTo(_output__) + } + def withStorageCommand(__v: 
_root_.scala.Predef.String): GraphStoreStorageConfig = copy(storageCommand = __v) + def clearStorageArgs = copy(storageArgs = _root_.scala.collection.immutable.Map.empty) + def addStorageArgs(__vs: (_root_.scala.Predef.String, _root_.scala.Predef.String) *): GraphStoreStorageConfig = addAllStorageArgs(__vs) + def addAllStorageArgs(__vs: Iterable[(_root_.scala.Predef.String, _root_.scala.Predef.String)]): GraphStoreStorageConfig = copy(storageArgs = storageArgs ++ __vs) + def withStorageArgs(__v: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]): GraphStoreStorageConfig = copy(storageArgs = __v) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) + def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => { + val __t = storageCommand + if (__t != "") __t else null + } + case 2 => storageArgs.iterator.map(snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toBase(_)).toSeq + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => _root_.scalapb.descriptors.PString(storageCommand) + case 2 => _root_.scalapb.descriptors.PRepeated(storageArgs.iterator.map(snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toBase(_).toPMessage).toVector) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.type = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig + // 
@@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig]) + } + + object GraphStoreStorageConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] { + implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig = { + var __storageCommand: _root_.scala.Predef.String = "" + val __storageArgs: _root_.scala.collection.mutable.Builder[(_root_.scala.Predef.String, _root_.scala.Predef.String), _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = _root_.scala.collection.immutable.Map.newBuilder[_root_.scala.Predef.String, _root_.scala.Predef.String] + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __storageCommand = _input__.readStringRequireUtf8() + case 18 => + __storageArgs += snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toCustom(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry](_input__)) + case tag => + if (_unknownFields__ == null) { + _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig( + storageCommand = __storageCommand, + storageArgs = __storageArgs.result(), + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: 
_root_.scalapb.descriptors.Reads[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + _root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig( + storageCommand = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""), + storageArgs = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Seq[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry]]).getOrElse(_root_.scala.Seq.empty).iterator.map(snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toCustom(_)).toMap + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(4) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(4) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { + var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null + (__number: @_root_.scala.unchecked) match { + case 2 => __out = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry + } + __out + } + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = + Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]]( + _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry + ) + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): 
_root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig( + storageCommand = "", + storageArgs = _root_.scala.collection.immutable.Map.empty + ) + @SerialVersionUID(0L) + final case class StorageArgsEntry( + key: _root_.scala.Predef.String = "", + value: _root_.scala.Predef.String = "", + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[StorageArgsEntry] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + + { + val __value = key + if (!__value.isEmpty) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) + } + }; + + { + val __value = value + if (!__value.isEmpty) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) + } + }; + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + { + val __v = key + if (!__v.isEmpty) { + _output__.writeString(1, __v) + } + }; + { + val __v = value + if (!__v.isEmpty) { + _output__.writeString(2, __v) + } + }; + unknownFields.writeTo(_output__) + } + def withKey(__v: _root_.scala.Predef.String): StorageArgsEntry = copy(key = __v) + def withValue(__v: _root_.scala.Predef.String): StorageArgsEntry = copy(value = __v) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) + def 
getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => { + val __t = key + if (__t != "") __t else null + } + case 2 => { + val __t = value + if (__t != "") __t else null + } + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => _root_.scalapb.descriptors.PString(key) + case 2 => _root_.scalapb.descriptors.PString(value) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry.type = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry + // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry]) + } + + object StorageArgsEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry] { + implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry = { + var __key: _root_.scala.Predef.String = "" + var __value: _root_.scala.Predef.String = "" + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __key = _input__.readStringRequireUtf8() + case 18 => + __value = _input__.readStringRequireUtf8() + case tag => + if (_unknownFields__ == null) { + 
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry( + key = __key, + value = __value, + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + _root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry( + key = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""), + value = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Predef.String]).getOrElse("") + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.javaDescriptor.getNestedTypes().get(0) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.scalaDescriptor.nestedMessages(0) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = throw new MatchError(__number) + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = 
snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry( + key = "", + value = "" + ) + implicit class StorageArgsEntryLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry](_l) { + def key: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.key)((c_, f_) => c_.copy(key = f_)) + def value: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.value)((c_, f_) => c_.copy(value = f_)) + } + final val KEY_FIELD_NUMBER = 1 + final val VALUE_FIELD_NUMBER = 2 + @transient + implicit val keyValueMapper: _root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)] = + _root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)](__m => (__m.key, __m.value))(__p => snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry(__p._1, __p._2)) + def of( + key: _root_.scala.Predef.String, + value: _root_.scala.Predef.String + ): _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry = _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry( + key, + value + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry]) + } + + implicit class GraphStoreStorageConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, 
snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig](_l) { + def storageCommand: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.storageCommand)((c_, f_) => c_.copy(storageCommand = f_)) + def storageArgs: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = field(_.storageArgs)((c_, f_) => c_.copy(storageArgs = f_)) + } + final val STORAGE_COMMAND_FIELD_NUMBER = 1 + final val STORAGE_ARGS_FIELD_NUMBER = 2 + @transient + private[gbml_config] val _typemapper_storageArgs: _root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)] = implicitly[_root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)]] + def of( + storageCommand: _root_.scala.Predef.String, + storageArgs: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String] + ): _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig = _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig( + storageCommand, + storageArgs + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig]) + } + /** @param trainerClsPath * (deprecated) * Uri pointing to user-written BaseTrainer class definition. Used for the subgraph-sampling-based training process. @@ -3661,6 +3965,9 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb * Arguments to parameterize training process with. * @param shouldLogToTensorboard * Weather to log to tensorboard or not (defaults to false) + * @param graphStoreStorageConfig + * Configuration for GraphStore storage. 
+ * If setup, then GiGLResourceConfig.trainer_resource_config.vertex_ai_graph_store_trainer_config must be set. */ @SerialVersionUID(0L) final case class TrainerConfig( @@ -3668,6 +3975,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb trainerArgs: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String] = _root_.scala.collection.immutable.Map.empty, executable: snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Empty, shouldLogToTensorboard: _root_.scala.Boolean = false, + graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = _root_.scala.None, unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[TrainerConfig] { @transient @@ -3700,6 +4008,10 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb __size += _root_.com.google.protobuf.CodedOutputStream.computeBoolSize(12, __value) } }; + if (graphStoreStorageConfig.isDefined) { + val __value = graphStoreStorageConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -3731,6 +4043,12 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb _output__.writeBool(12, __v) } }; + graphStoreStorageConfig.foreach { __v => + val __m = __v + _output__.writeTag(13, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; executable.clsPath.foreach { __v => val __m = __v _output__.writeString(100, __m) @@ -3751,6 +4069,9 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb def getCommand: _root_.scala.Predef.String = executable.command.getOrElse("") def 
withCommand(__v: _root_.scala.Predef.String): TrainerConfig = copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Command(__v)) def withShouldLogToTensorboard(__v: _root_.scala.Boolean): TrainerConfig = copy(shouldLogToTensorboard = __v) + def getGraphStoreStorageConfig: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig = graphStoreStorageConfig.getOrElse(snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.defaultInstance) + def clearGraphStoreStorageConfig: TrainerConfig = copy(graphStoreStorageConfig = _root_.scala.None) + def withGraphStoreStorageConfig(__v: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig): TrainerConfig = copy(graphStoreStorageConfig = Option(__v)) def clearExecutable: TrainerConfig = copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Empty) def withExecutable(__v: snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable): TrainerConfig = copy(executable = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -3768,6 +4089,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb val __t = shouldLogToTensorboard if (__t != false) __t else null } + case 13 => graphStoreStorageConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -3778,6 +4100,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb case 100 => executable.clsPath.map(_root_.scalapb.descriptors.PString(_)).getOrElse(_root_.scalapb.descriptors.PEmpty) case 101 => executable.command.map(_root_.scalapb.descriptors.PString(_)).getOrElse(_root_.scalapb.descriptors.PEmpty) case 12 => _root_.scalapb.descriptors.PBoolean(shouldLogToTensorboard) + case 13 => graphStoreStorageConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: 
_root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -3791,6 +4114,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb var __trainerClsPath: _root_.scala.Predef.String = "" val __trainerArgs: _root_.scala.collection.mutable.Builder[(_root_.scala.Predef.String, _root_.scala.Predef.String), _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = _root_.scala.collection.immutable.Map.newBuilder[_root_.scala.Predef.String, _root_.scala.Predef.String] var __shouldLogToTensorboard: _root_.scala.Boolean = false + var __graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = _root_.scala.None var __executable: snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Empty var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null var _done__ = false @@ -3808,6 +4132,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb __executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Command(_input__.readStringRequireUtf8()) case 96 => __shouldLogToTensorboard = _input__.readBool() + case 106 => + __graphStoreStorageConfig = Option(__graphStoreStorageConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -3819,6 +4145,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb trainerClsPath = __trainerClsPath, trainerArgs = __trainerArgs.result(), shouldLogToTensorboard = __shouldLogToTensorboard, + graphStoreStorageConfig = __graphStoreStorageConfig, executable = __executable, unknownFields = if 
(_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() ) @@ -3830,18 +4157,20 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb trainerClsPath = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""), trainerArgs = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Seq[snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.TrainerArgsEntry]]).getOrElse(_root_.scala.Seq.empty).iterator.map(snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig._typemapper_trainerArgs.toCustom(_)).toMap, shouldLogToTensorboard = __fieldsMap.get(scalaDescriptor.findFieldByNumber(12).get).map(_.as[_root_.scala.Boolean]).getOrElse(false), + graphStoreStorageConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(13).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig]]), executable = __fieldsMap.get(scalaDescriptor.findFieldByNumber(100).get).flatMap(_.as[_root_.scala.Option[_root_.scala.Predef.String]]).map(snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.ClsPath(_)) .orElse[snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable](__fieldsMap.get(scalaDescriptor.findFieldByNumber(101).get).flatMap(_.as[_root_.scala.Option[_root_.scala.Predef.String]]).map(snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Command(_))) .getOrElse(snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(4) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(4) + def javaDescriptor: 
_root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(5) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(5) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { case 2 => __out = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.TrainerArgsEntry + case 13 => __out = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig } __out } @@ -3854,6 +4183,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb trainerClsPath = "", trainerArgs = _root_.scala.collection.immutable.Map.empty, shouldLogToTensorboard = false, + graphStoreStorageConfig = _root_.scala.None, executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Empty ) sealed trait Executable extends _root_.scalapb.GeneratedOneof { @@ -4039,6 +4369,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb def clsPath: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.getClsPath)((c_, f_) => c_.copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.ClsPath(f_))) def command: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.getCommand)((c_, f_) => c_.copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Command(f_))) def shouldLogToTensorboard: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Boolean] = field(_.shouldLogToTensorboard)((c_, f_) => c_.copy(shouldLogToTensorboard = f_)) + def graphStoreStorageConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = field(_.getGraphStoreStorageConfig)((c_, 
f_) => c_.copy(graphStoreStorageConfig = Option(f_))) + def optionalGraphStoreStorageConfig: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig]] = field(_.graphStoreStorageConfig)((c_, f_) => c_.copy(graphStoreStorageConfig = f_)) def executable: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable] = field(_.executable)((c_, f_) => c_.copy(executable = f_)) } final val TRAINER_CLS_PATH_FIELD_NUMBER = 1 @@ -4046,18 +4378,21 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb final val CLS_PATH_FIELD_NUMBER = 100 final val COMMAND_FIELD_NUMBER = 101 final val SHOULD_LOG_TO_TENSORBOARD_FIELD_NUMBER = 12 + final val GRAPH_STORE_STORAGE_CONFIG_FIELD_NUMBER = 13 @transient private[gbml_config] val _typemapper_trainerArgs: _root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.TrainerArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)] = implicitly[_root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.TrainerArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)]] def of( trainerClsPath: _root_.scala.Predef.String, trainerArgs: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String], executable: snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable, - shouldLogToTensorboard: _root_.scala.Boolean + shouldLogToTensorboard: _root_.scala.Boolean, + graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] ): _root_.snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig = _root_.snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig( trainerClsPath, trainerArgs, executable, - shouldLogToTensorboard + shouldLogToTensorboard, + graphStoreStorageConfig ) // 
@@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GbmlConfig.TrainerConfig]) } @@ -4068,6 +4403,9 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb * @param inferenceBatchSize * Optional. If set, will be used to batch inference samples to a specific size before call for inference is made * Defaults to setting in python/gigl/src/inference/gnn_inferencer.py + * @param graphStoreStorageConfig + * Configuration for GraphStore storage. + * If setup, then GiGLResourceConfig.inferencer_resource_config.vertex_ai_graph_store_inferencer_config must be set. */ @SerialVersionUID(0L) final case class InferencerConfig( @@ -4075,6 +4413,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb inferencerClsPath: _root_.scala.Predef.String = "", executable: snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Empty, inferenceBatchSize: _root_.scala.Int = 0, + graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = _root_.scala.None, unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[InferencerConfig] { @transient @@ -4107,6 +4446,10 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(5, __value) } }; + if (graphStoreStorageConfig.isDefined) { + val __value = graphStoreStorageConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -4138,6 +4481,12 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb _output__.writeUInt32(5, __v) } }; + graphStoreStorageConfig.foreach { 
__v => + val __m = __v + _output__.writeTag(6, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; executable.clsPath.foreach { __v => val __m = __v _output__.writeString(100, __m) @@ -4158,6 +4507,9 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb def getCommand: _root_.scala.Predef.String = executable.command.getOrElse("") def withCommand(__v: _root_.scala.Predef.String): InferencerConfig = copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Command(__v)) def withInferenceBatchSize(__v: _root_.scala.Int): InferencerConfig = copy(inferenceBatchSize = __v) + def getGraphStoreStorageConfig: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig = graphStoreStorageConfig.getOrElse(snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.defaultInstance) + def clearGraphStoreStorageConfig: InferencerConfig = copy(graphStoreStorageConfig = _root_.scala.None) + def withGraphStoreStorageConfig(__v: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig): InferencerConfig = copy(graphStoreStorageConfig = Option(__v)) def clearExecutable: InferencerConfig = copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Empty) def withExecutable(__v: snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable): InferencerConfig = copy(executable = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -4175,6 +4527,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb val __t = inferenceBatchSize if (__t != 0) __t else null } + case 6 => graphStoreStorageConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -4185,6 +4538,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb case 100 => 
executable.clsPath.map(_root_.scalapb.descriptors.PString(_)).getOrElse(_root_.scalapb.descriptors.PEmpty) case 101 => executable.command.map(_root_.scalapb.descriptors.PString(_)).getOrElse(_root_.scalapb.descriptors.PEmpty) case 5 => _root_.scalapb.descriptors.PInt(inferenceBatchSize) + case 6 => graphStoreStorageConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -4198,6 +4552,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb val __inferencerArgs: _root_.scala.collection.mutable.Builder[(_root_.scala.Predef.String, _root_.scala.Predef.String), _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = _root_.scala.collection.immutable.Map.newBuilder[_root_.scala.Predef.String, _root_.scala.Predef.String] var __inferencerClsPath: _root_.scala.Predef.String = "" var __inferenceBatchSize: _root_.scala.Int = 0 + var __graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = _root_.scala.None var __executable: snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Empty var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null var _done__ = false @@ -4215,6 +4570,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb __executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Command(_input__.readStringRequireUtf8()) case 40 => __inferenceBatchSize = _input__.readUInt32() + case 50 => + __graphStoreStorageConfig = Option(__graphStoreStorageConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if 
(_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -4226,6 +4583,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb inferencerArgs = __inferencerArgs.result(), inferencerClsPath = __inferencerClsPath, inferenceBatchSize = __inferenceBatchSize, + graphStoreStorageConfig = __graphStoreStorageConfig, executable = __executable, unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() ) @@ -4237,18 +4595,20 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb inferencerArgs = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Seq[snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.InferencerArgsEntry]]).getOrElse(_root_.scala.Seq.empty).iterator.map(snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig._typemapper_inferencerArgs.toCustom(_)).toMap, inferencerClsPath = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""), inferenceBatchSize = __fieldsMap.get(scalaDescriptor.findFieldByNumber(5).get).map(_.as[_root_.scala.Int]).getOrElse(0), + graphStoreStorageConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(6).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig]]), executable = __fieldsMap.get(scalaDescriptor.findFieldByNumber(100).get).flatMap(_.as[_root_.scala.Option[_root_.scala.Predef.String]]).map(snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.ClsPath(_)) .orElse[snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable](__fieldsMap.get(scalaDescriptor.findFieldByNumber(101).get).flatMap(_.as[_root_.scala.Option[_root_.scala.Predef.String]]).map(snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Command(_))) 
.getOrElse(snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(5) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(5) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(6) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(6) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { case 1 => __out = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.InferencerArgsEntry + case 6 => __out = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig } __out } @@ -4261,6 +4621,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb inferencerArgs = _root_.scala.collection.immutable.Map.empty, inferencerClsPath = "", inferenceBatchSize = 0, + graphStoreStorageConfig = _root_.scala.None, executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Empty ) sealed trait Executable extends _root_.scalapb.GeneratedOneof { @@ -4446,6 +4807,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb def clsPath: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.getClsPath)((c_, f_) => c_.copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.ClsPath(f_))) def command: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = 
field(_.getCommand)((c_, f_) => c_.copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Command(f_))) def inferenceBatchSize: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Int] = field(_.inferenceBatchSize)((c_, f_) => c_.copy(inferenceBatchSize = f_)) + def graphStoreStorageConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = field(_.getGraphStoreStorageConfig)((c_, f_) => c_.copy(graphStoreStorageConfig = Option(f_))) + def optionalGraphStoreStorageConfig: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig]] = field(_.graphStoreStorageConfig)((c_, f_) => c_.copy(graphStoreStorageConfig = f_)) def executable: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable] = field(_.executable)((c_, f_) => c_.copy(executable = f_)) } final val INFERENCER_ARGS_FIELD_NUMBER = 1 @@ -4453,18 +4816,21 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb final val CLS_PATH_FIELD_NUMBER = 100 final val COMMAND_FIELD_NUMBER = 101 final val INFERENCE_BATCH_SIZE_FIELD_NUMBER = 5 + final val GRAPH_STORE_STORAGE_CONFIG_FIELD_NUMBER = 6 @transient private[gbml_config] val _typemapper_inferencerArgs: _root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.InferencerArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)] = implicitly[_root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.InferencerArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)]] def of( inferencerArgs: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String], inferencerClsPath: _root_.scala.Predef.String, executable: snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable, - inferenceBatchSize: 
_root_.scala.Int + inferenceBatchSize: _root_.scala.Int, + graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] ): _root_.snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig = _root_.snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig( inferencerArgs, inferencerClsPath, executable, - inferenceBatchSize + inferenceBatchSize, + graphStoreStorageConfig ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GbmlConfig.InferencerConfig]) } @@ -4582,8 +4948,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(6) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(6) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(7) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(7) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -4875,8 +5241,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(7) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = 
snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(7) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(8) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(8) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -5193,8 +5559,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(8) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(8) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(9) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(9) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -5494,8 +5860,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(9) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = 
snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(9) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(10) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(10) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = throw new MatchError(__number) lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) diff --git a/scala/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfigProto.scala b/scala/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfigProto.scala index 1109981b0..1b11b1f79 100644 --- a/scala/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfigProto.scala +++ b/scala/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfigProto.scala @@ -26,7 +26,7 @@ object GbmlConfigProto extends _root_.scalapb.GeneratedFileObject { GhfbWV0YWRhdGEucHJvdG8aLXNuYXBjaGF0L3Jlc2VhcmNoL2dibWwvZGF0YXNldF9tZXRhZGF0YS5wcm90bxozc25hcGNoYXQvc mVzZWFyY2gvZ2JtbC90cmFpbmVkX21vZGVsX21ldGFkYXRhLnByb3RvGi9zbmFwY2hhdC9yZXNlYXJjaC9nYm1sL2luZmVyZW5jZ V9tZXRhZGF0YS5wcm90bxozc25hcGNoYXQvcmVzZWFyY2gvZ2JtbC9wb3N0cHJvY2Vzc2VkX21ldGFkYXRhLnByb3RvGjdzbmFwY - 2hhdC9yZXNlYXJjaC9nYm1sL3N1YmdyYXBoX3NhbXBsaW5nX3N0cmF0ZWd5LnByb3RvItVHCgpHYm1sQ29uZmlnEmcKDXRhc2tfb + 2hhdC9yZXNlYXJjaC9nYm1sL3N1YmdyYXBoX3NhbXBsaW5nX3N0cmF0ZWd5LnByb3RvIrhMCgpHYm1sQ29uZmlnEmcKDXRhc2tfb WV0YWRhdGEYASABKAsyLy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkdibWxDb25maWcuVGFza01ldGFkYXRhQhHiPw4SDHRhc2tNZ XRhZGF0YVIMdGFza01ldGFkYXRhEmAKDmdyYXBoX21ldGFkYXRhGAIgASgLMiUuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HcmFwa 
E1ldGFkYXRhQhLiPw8SDWdyYXBoTWV0YWRhdGFSDWdyYXBoTWV0YWRhdGESZwoNc2hhcmVkX2NvbmZpZxgDIAEoCzIvLnNuYXBja @@ -121,35 +121,43 @@ object GbmlConfigProto extends _root_.scalapb.GeneratedFileObject { 2tleRgBIAEoCUII4j8FEgNrZXlSA2tleRIgCgV2YWx1ZRgCIAEoCUIK4j8HEgV2YWx1ZVIFdmFsdWU6AjgBGlQKEEdyYXBoRGJBc mdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEab woUR3JhcGhEQlNlcnZpY2VDb25maWcSVwoaZ3JhcGhfZGJfY2xpZW50X2NsYXNzX3BhdGgYASABKAlCG+I/GBIWZ3JhcGhEYkNsa - WVudENsYXNzUGF0aFIWZ3JhcGhEYkNsaWVudENsYXNzUGF0aBrXAwoNVHJhaW5lckNvbmZpZxI9ChB0cmFpbmVyX2Nsc19wYXRoG - AEgASgJQhPiPxASDnRyYWluZXJDbHNQYXRoUg50cmFpbmVyQ2xzUGF0aBJ2Cgx0cmFpbmVyX2FyZ3MYAiADKAsyQS5zbmFwY2hhd - C5yZXNlYXJjaC5nYm1sLkdibWxDb25maWcuVHJhaW5lckNvbmZpZy5UcmFpbmVyQXJnc0VudHJ5QhDiPw0SC3RyYWluZXJBcmdzU - gt0cmFpbmVyQXJncxIpCghjbHNfcGF0aBhkIAEoCUIM4j8JEgdjbHNQYXRoSABSB2Nsc1BhdGgSKAoHY29tbWFuZBhlIAEoCUIM4 - j8JEgdjb21tYW5kSABSB2NvbW1hbmQSVgoZc2hvdWxkX2xvZ190b190ZW5zb3Jib2FyZBgMIAEoCEIb4j8YEhZzaG91bGRMb2dUb - 1RlbnNvcmJvYXJkUhZzaG91bGRMb2dUb1RlbnNvcmJvYXJkGlQKEFRyYWluZXJBcmdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA - 2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAFCDAoKZXhlY3V0YWJsZRrpAwoQSW5mZXJlbmNlc - kNvbmZpZxKFAQoPaW5mZXJlbmNlcl9hcmdzGAEgAygLMkcuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZmlnLkluZmVyZ - W5jZXJDb25maWcuSW5mZXJlbmNlckFyZ3NFbnRyeUIT4j8QEg5pbmZlcmVuY2VyQXJnc1IOaW5mZXJlbmNlckFyZ3MSRgoTaW5mZ - XJlbmNlcl9jbHNfcGF0aBgCIAEoCUIW4j8TEhFpbmZlcmVuY2VyQ2xzUGF0aFIRaW5mZXJlbmNlckNsc1BhdGgSKQoIY2xzX3Bhd - GgYZCABKAlCDOI/CRIHY2xzUGF0aEgAUgdjbHNQYXRoEigKB2NvbW1hbmQYZSABKAlCDOI/CRIHY29tbWFuZEgAUgdjb21tYW5kE - kkKFGluZmVyZW5jZV9iYXRjaF9zaXplGAUgASgNQhfiPxQSEmluZmVyZW5jZUJhdGNoU2l6ZVISaW5mZXJlbmNlQmF0Y2hTaXplG - lcKE0luZmVyZW5jZXJBcmdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhb - HVlUgV2YWx1ZToCOAFCDAoKZXhlY3V0YWJsZRrbAgoTUG9zdFByb2Nlc3NvckNvbmZpZxKVAQoTcG9zdF9wcm9jZXNzb3JfYXJnc - 
xgBIAMoCzJNLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuR2JtbENvbmZpZy5Qb3N0UHJvY2Vzc29yQ29uZmlnLlBvc3RQcm9jZXNzb - 3JBcmdzRW50cnlCFuI/ExIRcG9zdFByb2Nlc3NvckFyZ3NSEXBvc3RQcm9jZXNzb3JBcmdzElAKF3Bvc3RfcHJvY2Vzc29yX2Nsc - 19wYXRoGAIgASgJQhniPxYSFHBvc3RQcm9jZXNzb3JDbHNQYXRoUhRwb3N0UHJvY2Vzc29yQ2xzUGF0aBpaChZQb3N0UHJvY2Vzc - 29yQXJnc0VudHJ5EhoKA2tleRgBIAEoCUII4j8FEgNrZXlSA2tleRIgCgV2YWx1ZRgCIAEoCUIK4j8HEgV2YWx1ZVIFdmFsdWU6A - jgBGpwCCg1NZXRyaWNzQ29uZmlnEj0KEG1ldHJpY3NfY2xzX3BhdGgYASABKAlCE+I/EBIObWV0cmljc0Nsc1BhdGhSDm1ldHJpY - 3NDbHNQYXRoEnYKDG1ldHJpY3NfYXJncxgCIAMoCzJBLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuR2JtbENvbmZpZy5NZXRyaWNzQ - 29uZmlnLk1ldHJpY3NBcmdzRW50cnlCEOI/DRILbWV0cmljc0FyZ3NSC21ldHJpY3NBcmdzGlQKEE1ldHJpY3NBcmdzRW50cnkSG - goDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEa9AIKDlByb2Zpb - GVyQ29uZmlnEk8KFnNob3VsZF9lbmFibGVfcHJvZmlsZXIYASABKAhCGeI/FhIUc2hvdWxkRW5hYmxlUHJvZmlsZXJSFHNob3VsZ - EVuYWJsZVByb2ZpbGVyEj0KEHByb2ZpbGVyX2xvZ19kaXIYAiABKAlCE+I/EBIOcHJvZmlsZXJMb2dEaXJSDnByb2ZpbGVyTG9nR - GlyEnsKDXByb2ZpbGVyX2FyZ3MYAyADKAsyQy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkdibWxDb25maWcuUHJvZmlsZXJDb25ma - WcuUHJvZmlsZXJBcmdzRW50cnlCEeI/DhIMcHJvZmlsZXJBcmdzUgxwcm9maWxlckFyZ3MaVQoRUHJvZmlsZXJBcmdzRW50cnkSG - goDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEaVQoRRmVhdHVyZ - UZsYWdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCO - AFiBnByb3RvMw==""" + WVudENsYXNzUGF0aFIWZ3JhcGhEYkNsaWVudENsYXNzUGF0aBqwAgoXR3JhcGhTdG9yZVN0b3JhZ2VDb25maWcSPAoPc3RvcmFnZ + V9jb21tYW5kGAEgASgJQhPiPxASDnN0b3JhZ2VDb21tYW5kUg5zdG9yYWdlQ29tbWFuZBKAAQoMc3RvcmFnZV9hcmdzGAIgAygLM + ksuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZmlnLkdyYXBoU3RvcmVTdG9yYWdlQ29uZmlnLlN0b3JhZ2VBcmdzRW50c + nlCEOI/DRILc3RvcmFnZUFyZ3NSC3N0b3JhZ2VBcmdzGlQKEFN0b3JhZ2VBcmdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tle + VIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEa7wQKDVRyYWluZXJDb25maWcSPQoQdHJhaW5lcl9jb + 
HNfcGF0aBgBIAEoCUIT4j8QEg50cmFpbmVyQ2xzUGF0aFIOdHJhaW5lckNsc1BhdGgSdgoMdHJhaW5lcl9hcmdzGAIgAygLMkEuc + 25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZmlnLlRyYWluZXJDb25maWcuVHJhaW5lckFyZ3NFbnRyeUIQ4j8NEgt0cmFpb + mVyQXJnc1ILdHJhaW5lckFyZ3MSKQoIY2xzX3BhdGgYZCABKAlCDOI/CRIHY2xzUGF0aEgAUgdjbHNQYXRoEigKB2NvbW1hbmQYZ + SABKAlCDOI/CRIHY29tbWFuZEgAUgdjb21tYW5kElYKGXNob3VsZF9sb2dfdG9fdGVuc29yYm9hcmQYDCABKAhCG+I/GBIWc2hvd + WxkTG9nVG9UZW5zb3Jib2FyZFIWc2hvdWxkTG9nVG9UZW5zb3Jib2FyZBKVAQoaZ3JhcGhfc3RvcmVfc3RvcmFnZV9jb25maWcYD + SABKAsyOi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkdibWxDb25maWcuR3JhcGhTdG9yZVN0b3JhZ2VDb25maWdCHOI/GRIXZ3Jhc + GhTdG9yZVN0b3JhZ2VDb25maWdSF2dyYXBoU3RvcmVTdG9yYWdlQ29uZmlnGlQKEFRyYWluZXJBcmdzRW50cnkSGgoDa2V5GAEgA + SgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAFCDAoKZXhlY3V0YWJsZRqBBQoQS + W5mZXJlbmNlckNvbmZpZxKFAQoPaW5mZXJlbmNlcl9hcmdzGAEgAygLMkcuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZ + mlnLkluZmVyZW5jZXJDb25maWcuSW5mZXJlbmNlckFyZ3NFbnRyeUIT4j8QEg5pbmZlcmVuY2VyQXJnc1IOaW5mZXJlbmNlckFyZ + 3MSRgoTaW5mZXJlbmNlcl9jbHNfcGF0aBgCIAEoCUIW4j8TEhFpbmZlcmVuY2VyQ2xzUGF0aFIRaW5mZXJlbmNlckNsc1BhdGgSK + QoIY2xzX3BhdGgYZCABKAlCDOI/CRIHY2xzUGF0aEgAUgdjbHNQYXRoEigKB2NvbW1hbmQYZSABKAlCDOI/CRIHY29tbWFuZEgAU + gdjb21tYW5kEkkKFGluZmVyZW5jZV9iYXRjaF9zaXplGAUgASgNQhfiPxQSEmluZmVyZW5jZUJhdGNoU2l6ZVISaW5mZXJlbmNlQ + mF0Y2hTaXplEpUBChpncmFwaF9zdG9yZV9zdG9yYWdlX2NvbmZpZxgGIAEoCzI6LnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuR2Jtb + ENvbmZpZy5HcmFwaFN0b3JlU3RvcmFnZUNvbmZpZ0Ic4j8ZEhdncmFwaFN0b3JlU3RvcmFnZUNvbmZpZ1IXZ3JhcGhTdG9yZVN0b + 3JhZ2VDb25maWcaVwoTSW5mZXJlbmNlckFyZ3NFbnRyeRIaCgNrZXkYASABKAlCCOI/BRIDa2V5UgNrZXkSIAoFdmFsdWUYAiABK + AlCCuI/BxIFdmFsdWVSBXZhbHVlOgI4AUIMCgpleGVjdXRhYmxlGtsCChNQb3N0UHJvY2Vzc29yQ29uZmlnEpUBChNwb3N0X3Byb + 2Nlc3Nvcl9hcmdzGAEgAygLMk0uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZmlnLlBvc3RQcm9jZXNzb3JDb25maWcuU + G9zdFByb2Nlc3NvckFyZ3NFbnRyeUIW4j8TEhFwb3N0UHJvY2Vzc29yQXJnc1IRcG9zdFByb2Nlc3NvckFyZ3MSUAoXcG9zdF9wc + 
m9jZXNzb3JfY2xzX3BhdGgYAiABKAlCGeI/FhIUcG9zdFByb2Nlc3NvckNsc1BhdGhSFHBvc3RQcm9jZXNzb3JDbHNQYXRoGloKF + lBvc3RQcm9jZXNzb3JBcmdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhb + HVlUgV2YWx1ZToCOAEanAIKDU1ldHJpY3NDb25maWcSPQoQbWV0cmljc19jbHNfcGF0aBgBIAEoCUIT4j8QEg5tZXRyaWNzQ2xzU + GF0aFIObWV0cmljc0Nsc1BhdGgSdgoMbWV0cmljc19hcmdzGAIgAygLMkEuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZ + mlnLk1ldHJpY3NDb25maWcuTWV0cmljc0FyZ3NFbnRyeUIQ4j8NEgttZXRyaWNzQXJnc1ILbWV0cmljc0FyZ3MaVAoQTWV0cmljc + 0FyZ3NFbnRyeRIaCgNrZXkYASABKAlCCOI/BRIDa2V5UgNrZXkSIAoFdmFsdWUYAiABKAlCCuI/BxIFdmFsdWVSBXZhbHVlOgI4A + Rr0AgoOUHJvZmlsZXJDb25maWcSTwoWc2hvdWxkX2VuYWJsZV9wcm9maWxlchgBIAEoCEIZ4j8WEhRzaG91bGRFbmFibGVQcm9ma + WxlclIUc2hvdWxkRW5hYmxlUHJvZmlsZXISPQoQcHJvZmlsZXJfbG9nX2RpchgCIAEoCUIT4j8QEg5wcm9maWxlckxvZ0RpclIOc + HJvZmlsZXJMb2dEaXISewoNcHJvZmlsZXJfYXJncxgDIAMoCzJDLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuR2JtbENvbmZpZy5Qc + m9maWxlckNvbmZpZy5Qcm9maWxlckFyZ3NFbnRyeUIR4j8OEgxwcm9maWxlckFyZ3NSDHByb2ZpbGVyQXJncxpVChFQcm9maWxlc + kFyZ3NFbnRyeRIaCgNrZXkYASABKAlCCOI/BRIDa2V5UgNrZXkSIAoFdmFsdWUYAiABKAlCCuI/BxIFdmFsdWVSBXZhbHVlOgI4A + RpVChFGZWF0dXJlRmxhZ3NFbnRyeRIaCgNrZXkYASABKAlCCOI/BRIDa2V5UgNrZXkSIAoFdmFsdWUYAiABKAlCCuI/BxIFdmFsd + WVSBXZhbHVlOgI4AWIGcHJvdG8z""" ).mkString) lazy val scalaDescriptor: _root_.scalapb.descriptors.FileDescriptor = { val scalaProto = com.google.protobuf.descriptor.FileDescriptorProto.parseFrom(ProtoBytes) @@ -169,4 +177,4 @@ object GbmlConfigProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfig.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfig.scala index b3f110643..c4c2ca892 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfig.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfig.scala @@ -306,6 +306,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb _root_.snapchat.research.gbml.gbml_config.GbmlConfig.SharedConfig, _root_.snapchat.research.gbml.gbml_config.GbmlConfig.DatasetConfig, _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphDBConfig, + _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig, _root_.snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig, _root_.snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig, _root_.snapchat.research.gbml.gbml_config.GbmlConfig.PostProcessorConfig, @@ -3654,6 +3655,309 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GbmlConfig.GraphDBConfig]) } + /** Configuration for GraphStore storage. + * + * @param storageCommand + * Command to use for launching storage job. + * e.g. "python -m gigl.distributed.graph_store.storage_main". + * @param storageArgs + * Arguments to instantiate concrete BaseStorage instance with. + * Will be appended to the storage_command. + * e.g. 
{"my_dataset_uri": "gs://my_dataset"} -> "python -m gigl.distributed.graph_store.storage_main --my_dataset_uri=gs://my_dataset" + */ + @SerialVersionUID(0L) + final case class GraphStoreStorageConfig( + storageCommand: _root_.scala.Predef.String = "", + storageArgs: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String] = _root_.scala.collection.immutable.Map.empty, + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[GraphStoreStorageConfig] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + + { + val __value = storageCommand + if (!__value.isEmpty) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) + } + }; + storageArgs.foreach { __item => + val __value = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toBase(__item) + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + } + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + { + val __v = storageCommand + if (!__v.isEmpty) { + _output__.writeString(1, __v) + } + }; + storageArgs.foreach { __v => + val __m = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toBase(__v) + _output__.writeTag(2, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; + unknownFields.writeTo(_output__) + } + def withStorageCommand(__v: 
_root_.scala.Predef.String): GraphStoreStorageConfig = copy(storageCommand = __v) + def clearStorageArgs = copy(storageArgs = _root_.scala.collection.immutable.Map.empty) + def addStorageArgs(__vs: (_root_.scala.Predef.String, _root_.scala.Predef.String) *): GraphStoreStorageConfig = addAllStorageArgs(__vs) + def addAllStorageArgs(__vs: Iterable[(_root_.scala.Predef.String, _root_.scala.Predef.String)]): GraphStoreStorageConfig = copy(storageArgs = storageArgs ++ __vs) + def withStorageArgs(__v: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]): GraphStoreStorageConfig = copy(storageArgs = __v) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) + def getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => { + val __t = storageCommand + if (__t != "") __t else null + } + case 2 => storageArgs.iterator.map(snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toBase(_)).toSeq + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => _root_.scalapb.descriptors.PString(storageCommand) + case 2 => _root_.scalapb.descriptors.PRepeated(storageArgs.iterator.map(snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toBase(_).toPMessage).toVector) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.type = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig + // 
@@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig]) + } + + object GraphStoreStorageConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] { + implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig = { + var __storageCommand: _root_.scala.Predef.String = "" + val __storageArgs: _root_.scala.collection.mutable.Builder[(_root_.scala.Predef.String, _root_.scala.Predef.String), _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = _root_.scala.collection.immutable.Map.newBuilder[_root_.scala.Predef.String, _root_.scala.Predef.String] + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __storageCommand = _input__.readStringRequireUtf8() + case 18 => + __storageArgs += snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toCustom(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry](_input__)) + case tag => + if (_unknownFields__ == null) { + _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig( + storageCommand = __storageCommand, + storageArgs = __storageArgs.result(), + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: 
_root_.scalapb.descriptors.Reads[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + _root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig( + storageCommand = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""), + storageArgs = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Seq[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry]]).getOrElse(_root_.scala.Seq.empty).iterator.map(snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig._typemapper_storageArgs.toCustom(_)).toMap + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(4) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(4) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { + var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null + (__number: @_root_.scala.unchecked) match { + case 2 => __out = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry + } + __out + } + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = + Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]]( + _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry + ) + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): 
_root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig( + storageCommand = "", + storageArgs = _root_.scala.collection.immutable.Map.empty + ) + @SerialVersionUID(0L) + final case class StorageArgsEntry( + key: _root_.scala.Predef.String = "", + value: _root_.scala.Predef.String = "", + unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty + ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[StorageArgsEntry] { + @transient + private[this] var __serializedSizeMemoized: _root_.scala.Int = 0 + private[this] def __computeSerializedSize(): _root_.scala.Int = { + var __size = 0 + + { + val __value = key + if (!__value.isEmpty) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(1, __value) + } + }; + + { + val __value = value + if (!__value.isEmpty) { + __size += _root_.com.google.protobuf.CodedOutputStream.computeStringSize(2, __value) + } + }; + __size += unknownFields.serializedSize + __size + } + override def serializedSize: _root_.scala.Int = { + var __size = __serializedSizeMemoized + if (__size == 0) { + __size = __computeSerializedSize() + 1 + __serializedSizeMemoized = __size + } + __size - 1 + + } + def writeTo(`_output__`: _root_.com.google.protobuf.CodedOutputStream): _root_.scala.Unit = { + { + val __v = key + if (!__v.isEmpty) { + _output__.writeString(1, __v) + } + }; + { + val __v = value + if (!__v.isEmpty) { + _output__.writeString(2, __v) + } + }; + unknownFields.writeTo(_output__) + } + def withKey(__v: _root_.scala.Predef.String): StorageArgsEntry = copy(key = __v) + def withValue(__v: _root_.scala.Predef.String): StorageArgsEntry = copy(value = __v) + def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) + def discardUnknownFields = copy(unknownFields = _root_.scalapb.UnknownFieldSet.empty) + def 
getFieldByNumber(__fieldNumber: _root_.scala.Int): _root_.scala.Any = { + (__fieldNumber: @_root_.scala.unchecked) match { + case 1 => { + val __t = key + if (__t != "") __t else null + } + case 2 => { + val __t = value + if (__t != "") __t else null + } + } + } + def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { + _root_.scala.Predef.require(__field.containingMessage eq companion.scalaDescriptor) + (__field.number: @_root_.scala.unchecked) match { + case 1 => _root_.scalapb.descriptors.PString(key) + case 2 => _root_.scalapb.descriptors.PString(value) + } + } + def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) + def companion: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry.type = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry + // @@protoc_insertion_point(GeneratedMessage[snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry]) + } + + object StorageArgsEntry extends scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry] { + implicit def messageCompanion: scalapb.GeneratedMessageCompanion[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry] = this + def parseFrom(`_input__`: _root_.com.google.protobuf.CodedInputStream): snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry = { + var __key: _root_.scala.Predef.String = "" + var __value: _root_.scala.Predef.String = "" + var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null + var _done__ = false + while (!_done__) { + val _tag__ = _input__.readTag() + _tag__ match { + case 0 => _done__ = true + case 10 => + __key = _input__.readStringRequireUtf8() + case 18 => + __value = _input__.readStringRequireUtf8() + case tag => + if (_unknownFields__ == null) { + 
_unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() + } + _unknownFields__.parseField(tag, _input__) + } + } + snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry( + key = __key, + value = __value, + unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() + ) + } + implicit def messageReads: _root_.scalapb.descriptors.Reads[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry] = _root_.scalapb.descriptors.Reads{ + case _root_.scalapb.descriptors.PMessage(__fieldsMap) => + _root_.scala.Predef.require(__fieldsMap.keys.forall(_.containingMessage eq scalaDescriptor), "FieldDescriptor does not match message type.") + snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry( + key = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""), + value = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Predef.String]).getOrElse("") + ) + case _ => throw new RuntimeException("Expected PMessage") + } + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.javaDescriptor.getNestedTypes().get(0) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.scalaDescriptor.nestedMessages(0) + def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = throw new MatchError(__number) + lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty + def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) + lazy val defaultInstance = 
snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry( + key = "", + value = "" + ) + implicit class StorageArgsEntryLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry](_l) { + def key: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.key)((c_, f_) => c_.copy(key = f_)) + def value: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.value)((c_, f_) => c_.copy(value = f_)) + } + final val KEY_FIELD_NUMBER = 1 + final val VALUE_FIELD_NUMBER = 2 + @transient + implicit val keyValueMapper: _root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)] = + _root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)](__m => (__m.key, __m.value))(__p => snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry(__p._1, __p._2)) + def of( + key: _root_.scala.Predef.String, + value: _root_.scala.Predef.String + ): _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry = _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry( + key, + value + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry]) + } + + implicit class GraphStoreStorageConfigLens[UpperPB](_l: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig]) extends _root_.scalapb.lenses.ObjectLens[UpperPB, 
snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig](_l) { + def storageCommand: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.storageCommand)((c_, f_) => c_.copy(storageCommand = f_)) + def storageArgs: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = field(_.storageArgs)((c_, f_) => c_.copy(storageArgs = f_)) + } + final val STORAGE_COMMAND_FIELD_NUMBER = 1 + final val STORAGE_ARGS_FIELD_NUMBER = 2 + @transient + private[gbml_config] val _typemapper_storageArgs: _root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)] = implicitly[_root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.StorageArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)]] + def of( + storageCommand: _root_.scala.Predef.String, + storageArgs: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String] + ): _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig = _root_.snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig( + storageCommand, + storageArgs + ) + // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GbmlConfig.GraphStoreStorageConfig]) + } + /** @param trainerClsPath * (deprecated) * Uri pointing to user-written BaseTrainer class definition. Used for the subgraph-sampling-based training process. @@ -3661,6 +3965,9 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb * Arguments to parameterize training process with. * @param shouldLogToTensorboard * Weather to log to tensorboard or not (defaults to false) + * @param graphStoreStorageConfig + * Configuration for GraphStore storage. 
+ * If setup, then GiGLResourceConfig.trainer_resource_config.vertex_ai_graph_store_trainer_config must be set. */ @SerialVersionUID(0L) final case class TrainerConfig( @@ -3668,6 +3975,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb trainerArgs: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String] = _root_.scala.collection.immutable.Map.empty, executable: snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Empty, shouldLogToTensorboard: _root_.scala.Boolean = false, + graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = _root_.scala.None, unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[TrainerConfig] { @transient @@ -3700,6 +4008,10 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb __size += _root_.com.google.protobuf.CodedOutputStream.computeBoolSize(12, __value) } }; + if (graphStoreStorageConfig.isDefined) { + val __value = graphStoreStorageConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -3731,6 +4043,12 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb _output__.writeBool(12, __v) } }; + graphStoreStorageConfig.foreach { __v => + val __m = __v + _output__.writeTag(13, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; executable.clsPath.foreach { __v => val __m = __v _output__.writeString(100, __m) @@ -3751,6 +4069,9 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb def getCommand: _root_.scala.Predef.String = executable.command.getOrElse("") def 
withCommand(__v: _root_.scala.Predef.String): TrainerConfig = copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Command(__v)) def withShouldLogToTensorboard(__v: _root_.scala.Boolean): TrainerConfig = copy(shouldLogToTensorboard = __v) + def getGraphStoreStorageConfig: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig = graphStoreStorageConfig.getOrElse(snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.defaultInstance) + def clearGraphStoreStorageConfig: TrainerConfig = copy(graphStoreStorageConfig = _root_.scala.None) + def withGraphStoreStorageConfig(__v: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig): TrainerConfig = copy(graphStoreStorageConfig = Option(__v)) def clearExecutable: TrainerConfig = copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Empty) def withExecutable(__v: snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable): TrainerConfig = copy(executable = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -3768,6 +4089,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb val __t = shouldLogToTensorboard if (__t != false) __t else null } + case 13 => graphStoreStorageConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -3778,6 +4100,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb case 100 => executable.clsPath.map(_root_.scalapb.descriptors.PString(_)).getOrElse(_root_.scalapb.descriptors.PEmpty) case 101 => executable.command.map(_root_.scalapb.descriptors.PString(_)).getOrElse(_root_.scalapb.descriptors.PEmpty) case 12 => _root_.scalapb.descriptors.PBoolean(shouldLogToTensorboard) + case 13 => graphStoreStorageConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: 
_root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -3791,6 +4114,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb var __trainerClsPath: _root_.scala.Predef.String = "" val __trainerArgs: _root_.scala.collection.mutable.Builder[(_root_.scala.Predef.String, _root_.scala.Predef.String), _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = _root_.scala.collection.immutable.Map.newBuilder[_root_.scala.Predef.String, _root_.scala.Predef.String] var __shouldLogToTensorboard: _root_.scala.Boolean = false + var __graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = _root_.scala.None var __executable: snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Empty var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null var _done__ = false @@ -3808,6 +4132,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb __executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Command(_input__.readStringRequireUtf8()) case 96 => __shouldLogToTensorboard = _input__.readBool() + case 106 => + __graphStoreStorageConfig = Option(__graphStoreStorageConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if (_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -3819,6 +4145,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb trainerClsPath = __trainerClsPath, trainerArgs = __trainerArgs.result(), shouldLogToTensorboard = __shouldLogToTensorboard, + graphStoreStorageConfig = __graphStoreStorageConfig, executable = __executable, unknownFields = if 
(_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() ) @@ -3830,18 +4157,20 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb trainerClsPath = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""), trainerArgs = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Seq[snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.TrainerArgsEntry]]).getOrElse(_root_.scala.Seq.empty).iterator.map(snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig._typemapper_trainerArgs.toCustom(_)).toMap, shouldLogToTensorboard = __fieldsMap.get(scalaDescriptor.findFieldByNumber(12).get).map(_.as[_root_.scala.Boolean]).getOrElse(false), + graphStoreStorageConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(13).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig]]), executable = __fieldsMap.get(scalaDescriptor.findFieldByNumber(100).get).flatMap(_.as[_root_.scala.Option[_root_.scala.Predef.String]]).map(snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.ClsPath(_)) .orElse[snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable](__fieldsMap.get(scalaDescriptor.findFieldByNumber(101).get).flatMap(_.as[_root_.scala.Option[_root_.scala.Predef.String]]).map(snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Command(_))) .getOrElse(snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(4) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(4) + def javaDescriptor: 
_root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(5) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(5) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { case 2 => __out = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.TrainerArgsEntry + case 13 => __out = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig } __out } @@ -3854,6 +4183,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb trainerClsPath = "", trainerArgs = _root_.scala.collection.immutable.Map.empty, shouldLogToTensorboard = false, + graphStoreStorageConfig = _root_.scala.None, executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Empty ) sealed trait Executable extends _root_.scalapb.GeneratedOneof { @@ -4039,6 +4369,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb def clsPath: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.getClsPath)((c_, f_) => c_.copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.ClsPath(f_))) def command: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.getCommand)((c_, f_) => c_.copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable.Command(f_))) def shouldLogToTensorboard: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Boolean] = field(_.shouldLogToTensorboard)((c_, f_) => c_.copy(shouldLogToTensorboard = f_)) + def graphStoreStorageConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = field(_.getGraphStoreStorageConfig)((c_, 
f_) => c_.copy(graphStoreStorageConfig = Option(f_))) + def optionalGraphStoreStorageConfig: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig]] = field(_.graphStoreStorageConfig)((c_, f_) => c_.copy(graphStoreStorageConfig = f_)) def executable: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable] = field(_.executable)((c_, f_) => c_.copy(executable = f_)) } final val TRAINER_CLS_PATH_FIELD_NUMBER = 1 @@ -4046,18 +4378,21 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb final val CLS_PATH_FIELD_NUMBER = 100 final val COMMAND_FIELD_NUMBER = 101 final val SHOULD_LOG_TO_TENSORBOARD_FIELD_NUMBER = 12 + final val GRAPH_STORE_STORAGE_CONFIG_FIELD_NUMBER = 13 @transient private[gbml_config] val _typemapper_trainerArgs: _root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.TrainerArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)] = implicitly[_root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.TrainerArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)]] def of( trainerClsPath: _root_.scala.Predef.String, trainerArgs: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String], executable: snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig.Executable, - shouldLogToTensorboard: _root_.scala.Boolean + shouldLogToTensorboard: _root_.scala.Boolean, + graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] ): _root_.snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig = _root_.snapchat.research.gbml.gbml_config.GbmlConfig.TrainerConfig( trainerClsPath, trainerArgs, executable, - shouldLogToTensorboard + shouldLogToTensorboard, + graphStoreStorageConfig ) // 
@@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GbmlConfig.TrainerConfig]) } @@ -4068,6 +4403,9 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb * @param inferenceBatchSize * Optional. If set, will be used to batch inference samples to a specific size before call for inference is made * Defaults to setting in python/gigl/src/inference/gnn_inferencer.py + * @param graphStoreStorageConfig + * Configuration for GraphStore storage. + * If setup, then GiGLResourceConfig.inferencer_resource_config.vertex_ai_graph_store_inferencer_config must be set. */ @SerialVersionUID(0L) final case class InferencerConfig( @@ -4075,6 +4413,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb inferencerClsPath: _root_.scala.Predef.String = "", executable: snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Empty, inferenceBatchSize: _root_.scala.Int = 0, + graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = _root_.scala.None, unknownFields: _root_.scalapb.UnknownFieldSet = _root_.scalapb.UnknownFieldSet.empty ) extends scalapb.GeneratedMessage with scalapb.lenses.Updatable[InferencerConfig] { @transient @@ -4107,6 +4446,10 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb __size += _root_.com.google.protobuf.CodedOutputStream.computeUInt32Size(5, __value) } }; + if (graphStoreStorageConfig.isDefined) { + val __value = graphStoreStorageConfig.get + __size += 1 + _root_.com.google.protobuf.CodedOutputStream.computeUInt32SizeNoTag(__value.serializedSize) + __value.serializedSize + }; __size += unknownFields.serializedSize __size } @@ -4138,6 +4481,12 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb _output__.writeUInt32(5, __v) } }; + graphStoreStorageConfig.foreach { 
__v => + val __m = __v + _output__.writeTag(6, 2) + _output__.writeUInt32NoTag(__m.serializedSize) + __m.writeTo(_output__) + }; executable.clsPath.foreach { __v => val __m = __v _output__.writeString(100, __m) @@ -4158,6 +4507,9 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb def getCommand: _root_.scala.Predef.String = executable.command.getOrElse("") def withCommand(__v: _root_.scala.Predef.String): InferencerConfig = copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Command(__v)) def withInferenceBatchSize(__v: _root_.scala.Int): InferencerConfig = copy(inferenceBatchSize = __v) + def getGraphStoreStorageConfig: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig = graphStoreStorageConfig.getOrElse(snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig.defaultInstance) + def clearGraphStoreStorageConfig: InferencerConfig = copy(graphStoreStorageConfig = _root_.scala.None) + def withGraphStoreStorageConfig(__v: snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig): InferencerConfig = copy(graphStoreStorageConfig = Option(__v)) def clearExecutable: InferencerConfig = copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Empty) def withExecutable(__v: snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable): InferencerConfig = copy(executable = __v) def withUnknownFields(__v: _root_.scalapb.UnknownFieldSet) = copy(unknownFields = __v) @@ -4175,6 +4527,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb val __t = inferenceBatchSize if (__t != 0) __t else null } + case 6 => graphStoreStorageConfig.orNull } } def getField(__field: _root_.scalapb.descriptors.FieldDescriptor): _root_.scalapb.descriptors.PValue = { @@ -4185,6 +4538,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb case 100 => 
executable.clsPath.map(_root_.scalapb.descriptors.PString(_)).getOrElse(_root_.scalapb.descriptors.PEmpty) case 101 => executable.command.map(_root_.scalapb.descriptors.PString(_)).getOrElse(_root_.scalapb.descriptors.PEmpty) case 5 => _root_.scalapb.descriptors.PInt(inferenceBatchSize) + case 6 => graphStoreStorageConfig.map(_.toPMessage).getOrElse(_root_.scalapb.descriptors.PEmpty) } } def toProtoString: _root_.scala.Predef.String = _root_.scalapb.TextFormat.printToUnicodeString(this) @@ -4198,6 +4552,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb val __inferencerArgs: _root_.scala.collection.mutable.Builder[(_root_.scala.Predef.String, _root_.scala.Predef.String), _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String]] = _root_.scala.collection.immutable.Map.newBuilder[_root_.scala.Predef.String, _root_.scala.Predef.String] var __inferencerClsPath: _root_.scala.Predef.String = "" var __inferenceBatchSize: _root_.scala.Int = 0 + var __graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = _root_.scala.None var __executable: snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Empty var `_unknownFields__`: _root_.scalapb.UnknownFieldSet.Builder = null var _done__ = false @@ -4215,6 +4570,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb __executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Command(_input__.readStringRequireUtf8()) case 40 => __inferenceBatchSize = _input__.readUInt32() + case 50 => + __graphStoreStorageConfig = Option(__graphStoreStorageConfig.fold(_root_.scalapb.LiteParser.readMessage[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig](_input__))(_root_.scalapb.LiteParser.readMessage(_input__, _))) case tag => if 
(_unknownFields__ == null) { _unknownFields__ = new _root_.scalapb.UnknownFieldSet.Builder() @@ -4226,6 +4583,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb inferencerArgs = __inferencerArgs.result(), inferencerClsPath = __inferencerClsPath, inferenceBatchSize = __inferenceBatchSize, + graphStoreStorageConfig = __graphStoreStorageConfig, executable = __executable, unknownFields = if (_unknownFields__ == null) _root_.scalapb.UnknownFieldSet.empty else _unknownFields__.result() ) @@ -4237,18 +4595,20 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb inferencerArgs = __fieldsMap.get(scalaDescriptor.findFieldByNumber(1).get).map(_.as[_root_.scala.Seq[snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.InferencerArgsEntry]]).getOrElse(_root_.scala.Seq.empty).iterator.map(snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig._typemapper_inferencerArgs.toCustom(_)).toMap, inferencerClsPath = __fieldsMap.get(scalaDescriptor.findFieldByNumber(2).get).map(_.as[_root_.scala.Predef.String]).getOrElse(""), inferenceBatchSize = __fieldsMap.get(scalaDescriptor.findFieldByNumber(5).get).map(_.as[_root_.scala.Int]).getOrElse(0), + graphStoreStorageConfig = __fieldsMap.get(scalaDescriptor.findFieldByNumber(6).get).flatMap(_.as[_root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig]]), executable = __fieldsMap.get(scalaDescriptor.findFieldByNumber(100).get).flatMap(_.as[_root_.scala.Option[_root_.scala.Predef.String]]).map(snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.ClsPath(_)) .orElse[snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable](__fieldsMap.get(scalaDescriptor.findFieldByNumber(101).get).flatMap(_.as[_root_.scala.Option[_root_.scala.Predef.String]]).map(snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Command(_))) 
.getOrElse(snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Empty) ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(5) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(5) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(6) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(6) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { case 1 => __out = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.InferencerArgsEntry + case 6 => __out = snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig } __out } @@ -4261,6 +4621,7 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb inferencerArgs = _root_.scala.collection.immutable.Map.empty, inferencerClsPath = "", inferenceBatchSize = 0, + graphStoreStorageConfig = _root_.scala.None, executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Empty ) sealed trait Executable extends _root_.scalapb.GeneratedOneof { @@ -4446,6 +4807,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb def clsPath: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = field(_.getClsPath)((c_, f_) => c_.copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.ClsPath(f_))) def command: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Predef.String] = 
field(_.getCommand)((c_, f_) => c_.copy(executable = snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable.Command(f_))) def inferenceBatchSize: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Int] = field(_.inferenceBatchSize)((c_, f_) => c_.copy(inferenceBatchSize = f_)) + def graphStoreStorageConfig: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] = field(_.getGraphStoreStorageConfig)((c_, f_) => c_.copy(graphStoreStorageConfig = Option(f_))) + def optionalGraphStoreStorageConfig: _root_.scalapb.lenses.Lens[UpperPB, _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig]] = field(_.graphStoreStorageConfig)((c_, f_) => c_.copy(graphStoreStorageConfig = f_)) def executable: _root_.scalapb.lenses.Lens[UpperPB, snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable] = field(_.executable)((c_, f_) => c_.copy(executable = f_)) } final val INFERENCER_ARGS_FIELD_NUMBER = 1 @@ -4453,18 +4816,21 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb final val CLS_PATH_FIELD_NUMBER = 100 final val COMMAND_FIELD_NUMBER = 101 final val INFERENCE_BATCH_SIZE_FIELD_NUMBER = 5 + final val GRAPH_STORE_STORAGE_CONFIG_FIELD_NUMBER = 6 @transient private[gbml_config] val _typemapper_inferencerArgs: _root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.InferencerArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)] = implicitly[_root_.scalapb.TypeMapper[snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.InferencerArgsEntry, (_root_.scala.Predef.String, _root_.scala.Predef.String)]] def of( inferencerArgs: _root_.scala.collection.immutable.Map[_root_.scala.Predef.String, _root_.scala.Predef.String], inferencerClsPath: _root_.scala.Predef.String, executable: snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig.Executable, - inferenceBatchSize: 
_root_.scala.Int + inferenceBatchSize: _root_.scala.Int, + graphStoreStorageConfig: _root_.scala.Option[snapchat.research.gbml.gbml_config.GbmlConfig.GraphStoreStorageConfig] ): _root_.snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig = _root_.snapchat.research.gbml.gbml_config.GbmlConfig.InferencerConfig( inferencerArgs, inferencerClsPath, executable, - inferenceBatchSize + inferenceBatchSize, + graphStoreStorageConfig ) // @@protoc_insertion_point(GeneratedMessageCompanion[snapchat.research.gbml.GbmlConfig.InferencerConfig]) } @@ -4582,8 +4948,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(6) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(6) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(7) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(7) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -4875,8 +5241,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(7) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = 
snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(7) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(8) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(8) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -5193,8 +5559,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(8) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(8) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(9) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(9) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = { var __out: _root_.scalapb.GeneratedMessageCompanion[_] = null (__number: @_root_.scala.unchecked) match { @@ -5494,8 +5860,8 @@ object GbmlConfig extends scalapb.GeneratedMessageCompanion[snapchat.research.gb ) case _ => throw new RuntimeException("Expected PMessage") } - def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(9) - def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = 
snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(9) + def javaDescriptor: _root_.com.google.protobuf.Descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.javaDescriptor.getNestedTypes().get(10) + def scalaDescriptor: _root_.scalapb.descriptors.Descriptor = snapchat.research.gbml.gbml_config.GbmlConfig.scalaDescriptor.nestedMessages(10) def messageCompanionForFieldNumber(__number: _root_.scala.Int): _root_.scalapb.GeneratedMessageCompanion[_] = throw new MatchError(__number) lazy val nestedMessagesCompanions: Seq[_root_.scalapb.GeneratedMessageCompanion[_ <: _root_.scalapb.GeneratedMessage]] = Seq.empty def enumCompanionForFieldNumber(__fieldNumber: _root_.scala.Int): _root_.scalapb.GeneratedEnumCompanion[_] = throw new MatchError(__fieldNumber) diff --git a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfigProto.scala b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfigProto.scala index 1109981b0..1b11b1f79 100644 --- a/scala_spark35/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfigProto.scala +++ b/scala_spark35/common/src/main/scala/snapchat/research/gbml/gbml_config/GbmlConfigProto.scala @@ -26,7 +26,7 @@ object GbmlConfigProto extends _root_.scalapb.GeneratedFileObject { GhfbWV0YWRhdGEucHJvdG8aLXNuYXBjaGF0L3Jlc2VhcmNoL2dibWwvZGF0YXNldF9tZXRhZGF0YS5wcm90bxozc25hcGNoYXQvc mVzZWFyY2gvZ2JtbC90cmFpbmVkX21vZGVsX21ldGFkYXRhLnByb3RvGi9zbmFwY2hhdC9yZXNlYXJjaC9nYm1sL2luZmVyZW5jZ V9tZXRhZGF0YS5wcm90bxozc25hcGNoYXQvcmVzZWFyY2gvZ2JtbC9wb3N0cHJvY2Vzc2VkX21ldGFkYXRhLnByb3RvGjdzbmFwY - 2hhdC9yZXNlYXJjaC9nYm1sL3N1YmdyYXBoX3NhbXBsaW5nX3N0cmF0ZWd5LnByb3RvItVHCgpHYm1sQ29uZmlnEmcKDXRhc2tfb + 2hhdC9yZXNlYXJjaC9nYm1sL3N1YmdyYXBoX3NhbXBsaW5nX3N0cmF0ZWd5LnByb3RvIrhMCgpHYm1sQ29uZmlnEmcKDXRhc2tfb WV0YWRhdGEYASABKAsyLy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkdibWxDb25maWcuVGFza01ldGFkYXRhQhHiPw4SDHRhc2tNZ 
XRhZGF0YVIMdGFza01ldGFkYXRhEmAKDmdyYXBoX21ldGFkYXRhGAIgASgLMiUuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HcmFwa E1ldGFkYXRhQhLiPw8SDWdyYXBoTWV0YWRhdGFSDWdyYXBoTWV0YWRhdGESZwoNc2hhcmVkX2NvbmZpZxgDIAEoCzIvLnNuYXBja @@ -121,35 +121,43 @@ object GbmlConfigProto extends _root_.scalapb.GeneratedFileObject { 2tleRgBIAEoCUII4j8FEgNrZXlSA2tleRIgCgV2YWx1ZRgCIAEoCUIK4j8HEgV2YWx1ZVIFdmFsdWU6AjgBGlQKEEdyYXBoRGJBc mdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEab woUR3JhcGhEQlNlcnZpY2VDb25maWcSVwoaZ3JhcGhfZGJfY2xpZW50X2NsYXNzX3BhdGgYASABKAlCG+I/GBIWZ3JhcGhEYkNsa - WVudENsYXNzUGF0aFIWZ3JhcGhEYkNsaWVudENsYXNzUGF0aBrXAwoNVHJhaW5lckNvbmZpZxI9ChB0cmFpbmVyX2Nsc19wYXRoG - AEgASgJQhPiPxASDnRyYWluZXJDbHNQYXRoUg50cmFpbmVyQ2xzUGF0aBJ2Cgx0cmFpbmVyX2FyZ3MYAiADKAsyQS5zbmFwY2hhd - C5yZXNlYXJjaC5nYm1sLkdibWxDb25maWcuVHJhaW5lckNvbmZpZy5UcmFpbmVyQXJnc0VudHJ5QhDiPw0SC3RyYWluZXJBcmdzU - gt0cmFpbmVyQXJncxIpCghjbHNfcGF0aBhkIAEoCUIM4j8JEgdjbHNQYXRoSABSB2Nsc1BhdGgSKAoHY29tbWFuZBhlIAEoCUIM4 - j8JEgdjb21tYW5kSABSB2NvbW1hbmQSVgoZc2hvdWxkX2xvZ190b190ZW5zb3Jib2FyZBgMIAEoCEIb4j8YEhZzaG91bGRMb2dUb - 1RlbnNvcmJvYXJkUhZzaG91bGRMb2dUb1RlbnNvcmJvYXJkGlQKEFRyYWluZXJBcmdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA - 2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAFCDAoKZXhlY3V0YWJsZRrpAwoQSW5mZXJlbmNlc - kNvbmZpZxKFAQoPaW5mZXJlbmNlcl9hcmdzGAEgAygLMkcuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZmlnLkluZmVyZ - W5jZXJDb25maWcuSW5mZXJlbmNlckFyZ3NFbnRyeUIT4j8QEg5pbmZlcmVuY2VyQXJnc1IOaW5mZXJlbmNlckFyZ3MSRgoTaW5mZ - XJlbmNlcl9jbHNfcGF0aBgCIAEoCUIW4j8TEhFpbmZlcmVuY2VyQ2xzUGF0aFIRaW5mZXJlbmNlckNsc1BhdGgSKQoIY2xzX3Bhd - GgYZCABKAlCDOI/CRIHY2xzUGF0aEgAUgdjbHNQYXRoEigKB2NvbW1hbmQYZSABKAlCDOI/CRIHY29tbWFuZEgAUgdjb21tYW5kE - kkKFGluZmVyZW5jZV9iYXRjaF9zaXplGAUgASgNQhfiPxQSEmluZmVyZW5jZUJhdGNoU2l6ZVISaW5mZXJlbmNlQmF0Y2hTaXplG - lcKE0luZmVyZW5jZXJBcmdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhb - 
HVlUgV2YWx1ZToCOAFCDAoKZXhlY3V0YWJsZRrbAgoTUG9zdFByb2Nlc3NvckNvbmZpZxKVAQoTcG9zdF9wcm9jZXNzb3JfYXJnc - xgBIAMoCzJNLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuR2JtbENvbmZpZy5Qb3N0UHJvY2Vzc29yQ29uZmlnLlBvc3RQcm9jZXNzb - 3JBcmdzRW50cnlCFuI/ExIRcG9zdFByb2Nlc3NvckFyZ3NSEXBvc3RQcm9jZXNzb3JBcmdzElAKF3Bvc3RfcHJvY2Vzc29yX2Nsc - 19wYXRoGAIgASgJQhniPxYSFHBvc3RQcm9jZXNzb3JDbHNQYXRoUhRwb3N0UHJvY2Vzc29yQ2xzUGF0aBpaChZQb3N0UHJvY2Vzc - 29yQXJnc0VudHJ5EhoKA2tleRgBIAEoCUII4j8FEgNrZXlSA2tleRIgCgV2YWx1ZRgCIAEoCUIK4j8HEgV2YWx1ZVIFdmFsdWU6A - jgBGpwCCg1NZXRyaWNzQ29uZmlnEj0KEG1ldHJpY3NfY2xzX3BhdGgYASABKAlCE+I/EBIObWV0cmljc0Nsc1BhdGhSDm1ldHJpY - 3NDbHNQYXRoEnYKDG1ldHJpY3NfYXJncxgCIAMoCzJBLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuR2JtbENvbmZpZy5NZXRyaWNzQ - 29uZmlnLk1ldHJpY3NBcmdzRW50cnlCEOI/DRILbWV0cmljc0FyZ3NSC21ldHJpY3NBcmdzGlQKEE1ldHJpY3NBcmdzRW50cnkSG - goDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEa9AIKDlByb2Zpb - GVyQ29uZmlnEk8KFnNob3VsZF9lbmFibGVfcHJvZmlsZXIYASABKAhCGeI/FhIUc2hvdWxkRW5hYmxlUHJvZmlsZXJSFHNob3VsZ - EVuYWJsZVByb2ZpbGVyEj0KEHByb2ZpbGVyX2xvZ19kaXIYAiABKAlCE+I/EBIOcHJvZmlsZXJMb2dEaXJSDnByb2ZpbGVyTG9nR - GlyEnsKDXByb2ZpbGVyX2FyZ3MYAyADKAsyQy5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkdibWxDb25maWcuUHJvZmlsZXJDb25ma - WcuUHJvZmlsZXJBcmdzRW50cnlCEeI/DhIMcHJvZmlsZXJBcmdzUgxwcm9maWxlckFyZ3MaVQoRUHJvZmlsZXJBcmdzRW50cnkSG - goDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEaVQoRRmVhdHVyZ - UZsYWdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCO - AFiBnByb3RvMw==""" + WVudENsYXNzUGF0aFIWZ3JhcGhEYkNsaWVudENsYXNzUGF0aBqwAgoXR3JhcGhTdG9yZVN0b3JhZ2VDb25maWcSPAoPc3RvcmFnZ + V9jb21tYW5kGAEgASgJQhPiPxASDnN0b3JhZ2VDb21tYW5kUg5zdG9yYWdlQ29tbWFuZBKAAQoMc3RvcmFnZV9hcmdzGAIgAygLM + ksuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZmlnLkdyYXBoU3RvcmVTdG9yYWdlQ29uZmlnLlN0b3JhZ2VBcmdzRW50c + nlCEOI/DRILc3RvcmFnZUFyZ3NSC3N0b3JhZ2VBcmdzGlQKEFN0b3JhZ2VBcmdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tle + 
VIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAEa7wQKDVRyYWluZXJDb25maWcSPQoQdHJhaW5lcl9jb + HNfcGF0aBgBIAEoCUIT4j8QEg50cmFpbmVyQ2xzUGF0aFIOdHJhaW5lckNsc1BhdGgSdgoMdHJhaW5lcl9hcmdzGAIgAygLMkEuc + 25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZmlnLlRyYWluZXJDb25maWcuVHJhaW5lckFyZ3NFbnRyeUIQ4j8NEgt0cmFpb + mVyQXJnc1ILdHJhaW5lckFyZ3MSKQoIY2xzX3BhdGgYZCABKAlCDOI/CRIHY2xzUGF0aEgAUgdjbHNQYXRoEigKB2NvbW1hbmQYZ + SABKAlCDOI/CRIHY29tbWFuZEgAUgdjb21tYW5kElYKGXNob3VsZF9sb2dfdG9fdGVuc29yYm9hcmQYDCABKAhCG+I/GBIWc2hvd + WxkTG9nVG9UZW5zb3Jib2FyZFIWc2hvdWxkTG9nVG9UZW5zb3Jib2FyZBKVAQoaZ3JhcGhfc3RvcmVfc3RvcmFnZV9jb25maWcYD + SABKAsyOi5zbmFwY2hhdC5yZXNlYXJjaC5nYm1sLkdibWxDb25maWcuR3JhcGhTdG9yZVN0b3JhZ2VDb25maWdCHOI/GRIXZ3Jhc + GhTdG9yZVN0b3JhZ2VDb25maWdSF2dyYXBoU3RvcmVTdG9yYWdlQ29uZmlnGlQKEFRyYWluZXJBcmdzRW50cnkSGgoDa2V5GAEgA + SgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhbHVlUgV2YWx1ZToCOAFCDAoKZXhlY3V0YWJsZRqBBQoQS + W5mZXJlbmNlckNvbmZpZxKFAQoPaW5mZXJlbmNlcl9hcmdzGAEgAygLMkcuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZ + mlnLkluZmVyZW5jZXJDb25maWcuSW5mZXJlbmNlckFyZ3NFbnRyeUIT4j8QEg5pbmZlcmVuY2VyQXJnc1IOaW5mZXJlbmNlckFyZ + 3MSRgoTaW5mZXJlbmNlcl9jbHNfcGF0aBgCIAEoCUIW4j8TEhFpbmZlcmVuY2VyQ2xzUGF0aFIRaW5mZXJlbmNlckNsc1BhdGgSK + QoIY2xzX3BhdGgYZCABKAlCDOI/CRIHY2xzUGF0aEgAUgdjbHNQYXRoEigKB2NvbW1hbmQYZSABKAlCDOI/CRIHY29tbWFuZEgAU + gdjb21tYW5kEkkKFGluZmVyZW5jZV9iYXRjaF9zaXplGAUgASgNQhfiPxQSEmluZmVyZW5jZUJhdGNoU2l6ZVISaW5mZXJlbmNlQ + mF0Y2hTaXplEpUBChpncmFwaF9zdG9yZV9zdG9yYWdlX2NvbmZpZxgGIAEoCzI6LnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuR2Jtb + ENvbmZpZy5HcmFwaFN0b3JlU3RvcmFnZUNvbmZpZ0Ic4j8ZEhdncmFwaFN0b3JlU3RvcmFnZUNvbmZpZ1IXZ3JhcGhTdG9yZVN0b + 3JhZ2VDb25maWcaVwoTSW5mZXJlbmNlckFyZ3NFbnRyeRIaCgNrZXkYASABKAlCCOI/BRIDa2V5UgNrZXkSIAoFdmFsdWUYAiABK + AlCCuI/BxIFdmFsdWVSBXZhbHVlOgI4AUIMCgpleGVjdXRhYmxlGtsCChNQb3N0UHJvY2Vzc29yQ29uZmlnEpUBChNwb3N0X3Byb + 2Nlc3Nvcl9hcmdzGAEgAygLMk0uc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZmlnLlBvc3RQcm9jZXNzb3JDb25maWcuU + 
G9zdFByb2Nlc3NvckFyZ3NFbnRyeUIW4j8TEhFwb3N0UHJvY2Vzc29yQXJnc1IRcG9zdFByb2Nlc3NvckFyZ3MSUAoXcG9zdF9wc + m9jZXNzb3JfY2xzX3BhdGgYAiABKAlCGeI/FhIUcG9zdFByb2Nlc3NvckNsc1BhdGhSFHBvc3RQcm9jZXNzb3JDbHNQYXRoGloKF + lBvc3RQcm9jZXNzb3JBcmdzRW50cnkSGgoDa2V5GAEgASgJQgjiPwUSA2tleVIDa2V5EiAKBXZhbHVlGAIgASgJQgriPwcSBXZhb + HVlUgV2YWx1ZToCOAEanAIKDU1ldHJpY3NDb25maWcSPQoQbWV0cmljc19jbHNfcGF0aBgBIAEoCUIT4j8QEg5tZXRyaWNzQ2xzU + GF0aFIObWV0cmljc0Nsc1BhdGgSdgoMbWV0cmljc19hcmdzGAIgAygLMkEuc25hcGNoYXQucmVzZWFyY2guZ2JtbC5HYm1sQ29uZ + mlnLk1ldHJpY3NDb25maWcuTWV0cmljc0FyZ3NFbnRyeUIQ4j8NEgttZXRyaWNzQXJnc1ILbWV0cmljc0FyZ3MaVAoQTWV0cmljc + 0FyZ3NFbnRyeRIaCgNrZXkYASABKAlCCOI/BRIDa2V5UgNrZXkSIAoFdmFsdWUYAiABKAlCCuI/BxIFdmFsdWVSBXZhbHVlOgI4A + Rr0AgoOUHJvZmlsZXJDb25maWcSTwoWc2hvdWxkX2VuYWJsZV9wcm9maWxlchgBIAEoCEIZ4j8WEhRzaG91bGRFbmFibGVQcm9ma + WxlclIUc2hvdWxkRW5hYmxlUHJvZmlsZXISPQoQcHJvZmlsZXJfbG9nX2RpchgCIAEoCUIT4j8QEg5wcm9maWxlckxvZ0RpclIOc + HJvZmlsZXJMb2dEaXISewoNcHJvZmlsZXJfYXJncxgDIAMoCzJDLnNuYXBjaGF0LnJlc2VhcmNoLmdibWwuR2JtbENvbmZpZy5Qc + m9maWxlckNvbmZpZy5Qcm9maWxlckFyZ3NFbnRyeUIR4j8OEgxwcm9maWxlckFyZ3NSDHByb2ZpbGVyQXJncxpVChFQcm9maWxlc + kFyZ3NFbnRyeRIaCgNrZXkYASABKAlCCOI/BRIDa2V5UgNrZXkSIAoFdmFsdWUYAiABKAlCCuI/BxIFdmFsdWVSBXZhbHVlOgI4A + RpVChFGZWF0dXJlRmxhZ3NFbnRyeRIaCgNrZXkYASABKAlCCOI/BRIDa2V5UgNrZXkSIAoFdmFsdWUYAiABKAlCCuI/BxIFdmFsd + WVSBXZhbHVlOgI4AWIGcHJvdG8z""" ).mkString) lazy val scalaDescriptor: _root_.scalapb.descriptors.FileDescriptor = { val scalaProto = com.google.protobuf.descriptor.FileDescriptorProto.parseFrom(ProtoBytes) @@ -169,4 +177,4 @@ object GbmlConfigProto extends _root_.scalapb.GeneratedFileObject { } @deprecated("Use javaDescriptor instead. 
In a future version this will refer to scalaDescriptor.", "ScalaPB 0.5.47") def descriptor: com.google.protobuf.Descriptors.FileDescriptor = javaDescriptor -} \ No newline at end of file +} From a25e95a3161b8d392ea48d5e4848948811ce5db7 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 22 Jan 2026 18:40:01 +0000 Subject: [PATCH 2/4] Setup to launch custom storage --- .../e2e_hom_cora_sup_gs_task_config.yaml | 4 + .../graph_store/storage_main.py | 44 ++++++ .../distributed/graph_store/storage_main.py | 2 + python/gigl/src/common/vertex_ai_launcher.py | 11 +- .../gigl/src/inference/v2/glt_inferencer.py | 2 + python/gigl/src/training/v2/glt_trainer.py | 2 + .../graph_store_integration_test.py | 6 +- python/tests/test_assets/distributed/utils.py | 128 +----------------- .../distributed_neighborloader_test.py | 113 ---------------- .../src/common/vertex_ai_launcher_test.py | 2 + 10 files changed, 65 insertions(+), 249 deletions(-) create mode 100644 examples/link_prediction/graph_store/storage_main.py diff --git a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml index 6cf4bdeea..886e12eee 100644 --- a/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml +++ b/examples/link_prediction/graph_store/configs/e2e_hom_cora_sup_gs_task_config.yaml @@ -30,6 +30,10 @@ inferencerConfig: num_neighbors: "[10, 10]" # Fanout per hop, specified as a string representation of a list for the homogeneous use case inferenceBatchSize: 512 command: python -m examples.link_prediction.graph_store.homogeneous_inference + graphStoreStorageConfig: + storageCommand: python -m examples.link_prediction.graph_store.storage_main + storageArgs: + is_inference: "True" sharedConfig: shouldSkipInference: false # Model Evaluation is currently only supported for tabularized SGS GiGL pipelines. This will soon be added for in-mem SGS GiGL pipelines. 
diff --git a/examples/link_prediction/graph_store/storage_main.py b/examples/link_prediction/graph_store/storage_main.py new file mode 100644 index 000000000..12ad4b239 --- /dev/null +++ b/examples/link_prediction/graph_store/storage_main.py @@ -0,0 +1,44 @@ +"""Built-in GiGL Graph Store Server. + +Derivved from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/examples/distributed/server_client_mode/sage_supervised_server.py + +""" +import argparse +import os +from distutils.util import strtobool + +import torch + +from gigl.common import UriFactory +from gigl.common.logger import Logger +from gigl.distributed.graph_store.storage_process import storage_node_process +from gigl.distributed.utils import get_graph_store_info + +logger = Logger() + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--task_config_uri", type=str, required=True) + parser.add_argument("--resource_config_uri", type=str, required=True) + parser.add_argument("--job_name", type=str, required=True) + parser.add_argument("--is_inference", type=str, required=True) + args = parser.parse_args() + logger.info(f"Running storage node with arguments: {args}") + + is_inference = bool(strtobool(args.is_inference)) + torch.distributed.init_process_group(backend="gloo") + cluster_info = get_graph_store_info() + logger.info(f"Cluster info: {cluster_info}") + logger.info( + f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" + ) + logger.info(f"Is inference: {is_inference}") + # Tear down the """"global""" process group so we can have a server-specific process group. 
+ torch.distributed.destroy_process_group() + storage_node_process( + storage_rank=cluster_info.storage_node_rank, + cluster_info=cluster_info, + task_config_uri=UriFactory.create_uri(args.task_config_uri), + is_inference=is_inference, + ) diff --git a/python/gigl/distributed/graph_store/storage_main.py b/python/gigl/distributed/graph_store/storage_main.py index 09a0d79cd..d7141e412 100644 --- a/python/gigl/distributed/graph_store/storage_main.py +++ b/python/gigl/distributed/graph_store/storage_main.py @@ -2,6 +2,8 @@ Derivved from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/examples/distributed/server_client_mode/sage_supervised_server.py +TODO(kmonte): Remove this, and only expose utils. +We keep this around so we can use the utils in tests/integration/distributed/graph_store/graph_store_integration_test.py. """ import argparse import os diff --git a/python/gigl/src/common/vertex_ai_launcher.py b/python/gigl/src/common/vertex_ai_launcher.py index 4d446a7f6..09d6d6a0c 100644 --- a/python/gigl/src/common/vertex_ai_launcher.py +++ b/python/gigl/src/common/vertex_ai_launcher.py @@ -98,6 +98,8 @@ def launch_graph_store_enabled_job( resource_config_uri: Uri, process_command: str, process_runtime_args: Mapping[str, str], + storage_command: str, + storage_args: Mapping[str, str], resource_config_wrapper: GiglResourceConfigWrapper, cpu_docker_uri: Optional[str], cuda_docker_uri: Optional[str], @@ -112,6 +114,8 @@ def launch_graph_store_enabled_job( resource_config_uri: URI to the resource configuration process_command: Command to run in the compute container process_runtime_args: Runtime arguments for the process + storage_command: Command to run in the storage container + storage_args: Arguments to pass to the storage command resource_config_wrapper: Wrapper for the resource configuration cpu_docker_uri: Docker image URI for CPU execution cuda_docker_uri: Docker image URI for GPU execution @@ -173,8 +177,8 @@ def launch_graph_store_enabled_job( 
job_name=job_name, task_config_uri=task_config_uri, resource_config_uri=resource_config_uri, - command_str=f"python -m gigl.distributed.graph_store.storage_main", - args={}, # No extra args for storage pool + command_str=storage_command, + args=storage_args, use_cuda=is_cpu_execution, container_uri=container_uri, vertex_ai_resource_config=storage_pool_config, @@ -272,9 +276,6 @@ def _build_job_config( ) if vertex_ai_resource_config.scheduling_strategy else None, - boot_disk_size_gb=vertex_ai_resource_config.boot_disk_size_gb - if vertex_ai_resource_config.boot_disk_size_gb - else 100, # Default to 100 GB for backward compatibility ) return job_config diff --git a/python/gigl/src/inference/v2/glt_inferencer.py b/python/gigl/src/inference/v2/glt_inferencer.py index 93f47fbf6..317e250f2 100644 --- a/python/gigl/src/inference/v2/glt_inferencer.py +++ b/python/gigl/src/inference/v2/glt_inferencer.py @@ -83,6 +83,8 @@ def __execute_VAI_inference( resource_config_uri=resource_config_uri, process_command=inference_process_command, process_runtime_args=inference_process_runtime_args, + storage_command=gbml_config_pb_wrapper.inferencer_config.graph_store_storage_config.storage_command, + storage_args=gbml_config_pb_wrapper.inferencer_config.graph_store_storage_config.storage_args, resource_config_wrapper=resource_config_wrapper, cpu_docker_uri=cpu_docker_uri, cuda_docker_uri=cuda_docker_uri, diff --git a/python/gigl/src/training/v2/glt_trainer.py b/python/gigl/src/training/v2/glt_trainer.py index a9a33b760..a805936ea 100644 --- a/python/gigl/src/training/v2/glt_trainer.py +++ b/python/gigl/src/training/v2/glt_trainer.py @@ -79,6 +79,8 @@ def __execute_VAI_training( resource_config_uri=resource_config_uri, process_command=training_process_command, process_runtime_args=training_process_runtime_args, + storage_command=gbml_config_pb_wrapper.trainer_config.graph_store_storage_config.storage_command, + 
storage_args=gbml_config_pb_wrapper.trainer_config.graph_store_storage_config.storage_args, resource_config_wrapper=resource_config, cpu_docker_uri=cpu_docker_uri, cuda_docker_uri=cuda_docker_uri, diff --git a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py index 23ecc787b..ab127b7ee 100644 --- a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -38,7 +38,7 @@ def _assert_sampler_input( cluster_info: GraphStoreInfo, - sampler_input: dict[int, torch.Tensor], + sampler_input: list[torch.Tensor], expected_sampler_input: dict[int, list[torch.Tensor]], ) -> None: rank_expected_sampler_input = expected_sampler_input[cluster_info.compute_node_rank] @@ -123,9 +123,7 @@ def _run_compute_tests( torch.distributed.barrier() if node_type is not None: - input_nodes: Union[ - dict[int, torch.Tensor], tuple[NodeType, dict[int, torch.Tensor]] - ] = ( + input_nodes: Union[list[torch.Tensor], tuple[NodeType, list[torch.Tensor]]] = ( node_type, sampler_input, ) diff --git a/python/tests/test_assets/distributed/utils.py b/python/tests/test_assets/distributed/utils.py index b60403271..15a6f3a95 100644 --- a/python/tests/test_assets/distributed/utils.py +++ b/python/tests/test_assets/distributed/utils.py @@ -2,10 +2,7 @@ import torch -from gigl.distributed.graph_store.remote_dist_dataset import RemoteDistDataset -from gigl.distributed.utils import get_free_port, get_free_ports -from gigl.env.distributed import GraphStoreInfo -from gigl.src.common.types.graph_data import EdgeType +from gigl.distributed.utils import get_free_port def assert_tensor_equality( @@ -70,126 +67,3 @@ def create_test_process_group() -> None: world_size=1, init_method=get_process_group_init_method(), ) - - -class MockGraphStoreInfo(GraphStoreInfo): - """ - A mock wrapper around 
GraphStoreInfo that allows overriding the compute_node_rank property. - - The real GraphStoreInfo.compute_node_rank reads from environment variables (RANK), - which makes it difficult to test. This mock allows setting the compute_node_rank - directly during initialization. - - Args: - real_info: The real GraphStoreInfo instance to delegate to for most properties. - compute_node_rank: The compute node rank value to return instead of reading from env. - - Example: - >>> real_info = GraphStoreInfo(num_storage_nodes=2, ...) - >>> mock_info = MockGraphStoreInfo(real_info, compute_node_rank=0) - >>> mock_info.compute_node_rank # Returns 0 instead of reading from env - """ - - def __init__(self, real_info: GraphStoreInfo, compute_node_rank: int): - self._real_info = real_info - self._compute_node_rank = compute_node_rank - - @property - def num_storage_nodes(self) -> int: - return self._real_info.num_storage_nodes - - @property - def num_compute_nodes(self) -> int: - return self._real_info.num_compute_nodes - - @property - def storage_cluster_master_ip(self) -> str: - return self._real_info.storage_cluster_master_ip - - @property - def num_processes_per_compute(self) -> int: - return self._real_info.num_processes_per_compute - - @property - def compute_node_rank(self) -> int: - return self._compute_node_rank - - -class MockRemoteDistDataset(RemoteDistDataset): - """ - A mock RemoteDistDataset for testing that doesn't make remote RPC calls. - - The real RemoteDistDataset makes remote calls to storage nodes via graphlearn_torch's - request_server mechanism. This mock class overrides all remote-calling methods to - return configurable mock values, enabling unit testing of code that depends on - RemoteDistDataset without needing a real distributed cluster. - - Args: - num_storage_nodes: Number of storage nodes in the mock cluster. Defaults to 2. - num_compute_nodes: Number of compute nodes in the mock cluster. Defaults to 1. 
- num_processes_per_compute: Number of processes per compute node. Defaults to 1. - compute_node_rank: The rank of the compute node. Defaults to 0. - edge_types: Optional list of edge types for heterogeneous graphs. Defaults to None. - edge_dir: Edge direction, either "in" or "out". Defaults to "out". - - Example: - >>> mock_dataset = MockRemoteDistDataset( - ... num_storage_nodes=2, - ... edge_types=[EdgeType("user", "knows", "user")], - ... ) - >>> mock_dataset.get_edge_types() # Returns the configured edge_types - >>> mock_dataset.cluster_info.num_storage_nodes # Returns 2 - """ - - def __init__( - self, - num_storage_nodes: int = 2, - num_compute_nodes: int = 1, - num_processes_per_compute: int = 1, - compute_node_rank: int = 0, - edge_types: Optional[list[EdgeType]] = None, - edge_dir: str = "out", - ): - # Create a mock GraphStoreInfo with placeholder values - self._mock_cluster_info = GraphStoreInfo( - num_storage_nodes=num_storage_nodes, - num_compute_nodes=num_compute_nodes, - cluster_master_ip="127.0.0.1", - storage_cluster_master_ip="127.0.0.1", - compute_cluster_master_ip="127.0.0.1", - cluster_master_port=12345, - storage_cluster_master_port=12346, - compute_cluster_master_port=12347, - num_processes_per_compute=num_processes_per_compute, - rpc_master_port=12348, - rpc_wait_port=12349, - ) - self._mock_compute_node_rank = compute_node_rank - self._mock_edge_types = edge_types - self._mock_edge_dir = edge_dir - # Don't call super().__init__() to avoid needing a real cluster connection - - @property - def cluster_info(self) -> GraphStoreInfo: - """Returns a MockGraphStoreInfo with the configured compute_node_rank.""" - return MockGraphStoreInfo(self._mock_cluster_info, self._mock_compute_node_rank) - - def get_node_feature_info(self): - """Returns None (no node features configured).""" - return None - - def get_edge_feature_info(self): - """Returns None (no edge features configured).""" - return None - - def get_edge_dir(self) -> str: - """Returns the 
configured edge direction.""" - return self._mock_edge_dir - - def get_edge_types(self) -> Optional[list[EdgeType]]: - """Returns the configured edge types.""" - return self._mock_edge_types - - def get_free_ports_on_storage_cluster(self, num_ports: int) -> list[int]: - """Returns a list of mock port numbers starting at 20000.""" - return get_free_ports(num_ports=num_ports) diff --git a/python/tests/unit/distributed/distributed_neighborloader_test.py b/python/tests/unit/distributed/distributed_neighborloader_test.py index 156e77069..74ecd13eb 100644 --- a/python/tests/unit/distributed/distributed_neighborloader_test.py +++ b/python/tests/unit/distributed/distributed_neighborloader_test.py @@ -4,7 +4,6 @@ import torch import torch.multiprocessing as mp from graphlearn_torch.distributed import shutdown_rpc -from parameterized import param, parameterized from torch_geometric.data import Data, HeteroData from gigl.distributed.dataset_factory import build_dataset @@ -38,7 +37,6 @@ run_distributed_dataset, ) from tests.test_assets.distributed.utils import ( - MockRemoteDistDataset, assert_tensor_equality, create_test_process_group, ) @@ -575,117 +573,6 @@ def test_isolated_homogeneous_neighbor_loader( args=(dataset, 18), ) - @parameterized.expand( - [ - param( - "input_nodes is None and dataset.node_ids is None", - expected_error=ValueError, - dataset=DistDataset(rank=0, world_size=1, edge_dir="out"), - num_neighbors=[2, 2], - input_nodes=None, - ), - param( - "input_nodes is None for heterogeneous dataset", - expected_error=ValueError, - dataset=DistDataset( - rank=0, - world_size=1, - edge_dir="out", - graph_partition={}, - node_partition_book={}, - node_ids={NodeType("a"): torch.tensor([1, 2, 3])}, - ), - num_neighbors=[2, 2], - input_nodes=None, - ), - param( - "input_nodes is a dict (colocated mode expects Tensor)", - expected_error=ValueError, - dataset=DistDataset(rank=0, world_size=1, edge_dir="out"), - num_neighbors=[2, 2], - input_nodes={0: torch.tensor([10])}, 
- ), - param( - "input_nodes is tuple with dict as second element", - expected_error=ValueError, - dataset=DistDataset(rank=0, world_size=1, edge_dir="out"), - num_neighbors=[2, 2], - input_nodes=(NodeType("a"), {0: torch.tensor([10])}), - ), - param( - "Heterogeneous dataset with tensor input_nodes (not labeled homogeneous)", - expected_error=ValueError, - dataset=DistDataset( - rank=0, - world_size=1, - edge_dir="out", - graph_partition={}, - node_partition_book={}, - node_ids={ - NodeType("a"): torch.tensor([1, 2]), - NodeType("b"): torch.tensor([3, 4]), - }, - ), - num_neighbors=[2, 2], - input_nodes=torch.tensor([10]), - ), - param( - "input_nodes is None (graph store mode)", - expected_error=ValueError, - dataset=MockRemoteDistDataset(num_storage_nodes=2), - num_neighbors=[2, 2], - input_nodes=None, - ), - param( - "input_nodes is a Tensor (graph store mode expects Mapping)", - expected_error=ValueError, - dataset=MockRemoteDistDataset(num_storage_nodes=2), - num_neighbors=[2, 2], - input_nodes=torch.tensor([10, 20]), - ), - param( - "input_nodes is tuple with Tensor (graph store mode expects Mapping)", - expected_error=ValueError, - dataset=MockRemoteDistDataset(num_storage_nodes=2), - num_neighbors=[2, 2], - input_nodes=(NodeType("a"), torch.tensor([10, 20])), - ), - param( - "Heterogeneous input without edge_types", - expected_error=ValueError, - dataset=MockRemoteDistDataset(num_storage_nodes=2, edge_types=None), - num_neighbors=[2, 2], - input_nodes=( - NodeType("a"), - {0: torch.tensor([10]), 1: torch.tensor([20])}, - ), - ), - param( - "Server rank exceeds num_storage_nodes", - expected_error=ValueError, - dataset=MockRemoteDistDataset(num_storage_nodes=2), - num_neighbors=[2, 2], - input_nodes={0: torch.tensor([10]), 5: torch.tensor([20])}, - ), - param( - "Server rank is negative", - expected_error=ValueError, - dataset=MockRemoteDistDataset(num_storage_nodes=2), - num_neighbors=[2, 2], - input_nodes={-1: torch.tensor([10]), 0: torch.tensor([20])}, - 
), - ] - ) - def test_distributed_neighbor_loader_invalid_inputs_colocated( - self, - _: str, - expected_error: type[BaseException], - **kwargs, - ): - create_test_process_group() - with self.assertRaises(expected_error): - DistNeighborLoader(**kwargs) - if __name__ == "__main__": unittest.main() diff --git a/python/tests/unit/src/common/vertex_ai_launcher_test.py b/python/tests/unit/src/common/vertex_ai_launcher_test.py index 80fb7f249..c8e9a72a0 100644 --- a/python/tests/unit/src/common/vertex_ai_launcher_test.py +++ b/python/tests/unit/src/common/vertex_ai_launcher_test.py @@ -147,6 +147,8 @@ def test_launch_training_graph_store_cuda(self, mock_vertex_ai_service_class): process_command=process_command, process_runtime_args=process_runtime_args, resource_config_wrapper=resource_config_wrapper, + storage_command="python -m gigl.distributed.graph_store.storage_main", + storage_args={}, cpu_docker_uri=cpu_docker_uri, cuda_docker_uri=cuda_docker_uri, component=component, From 7c61849b9822a83e67868783d41de26ea17f75b4 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 22 Jan 2026 18:48:43 +0000 Subject: [PATCH 3/4] fixes --- .../graph_store/storage_main.py | 117 +++++++++++++++++- .../graph_store_integration_test.py | 6 +- 2 files changed, 117 insertions(+), 6 deletions(-) diff --git a/examples/link_prediction/graph_store/storage_main.py b/examples/link_prediction/graph_store/storage_main.py index 12ad4b239..379aa37a6 100644 --- a/examples/link_prediction/graph_store/storage_main.py +++ b/examples/link_prediction/graph_store/storage_main.py @@ -2,21 +2,132 @@ Derivved from https://github.com/alibaba/graphlearn-for-pytorch/blob/main/examples/distributed/server_client_mode/sage_supervised_server.py +# TODO(kmonte): Figure out how we should split out common utils from this file. 
+ """ import argparse import os from distutils.util import strtobool +from typing import Optional +import graphlearn_torch as glt import torch -from gigl.common import UriFactory +from gigl.common import Uri, UriFactory from gigl.common.logger import Logger -from gigl.distributed.graph_store.storage_process import storage_node_process +from gigl.distributed import build_dataset_from_task_config_uri +from gigl.distributed.dist_dataset import DistDataset +from gigl.distributed.graph_store.storage_utils import register_dataset from gigl.distributed.utils import get_graph_store_info +from gigl.distributed.utils.networking import get_free_ports_from_master_node +from gigl.env.distributed import GraphStoreInfo logger = Logger() +def _run_storage_process( + storage_rank: int, + cluster_info: GraphStoreInfo, + dataset: DistDataset, + torch_process_port: int, + storage_world_backend: Optional[str], +) -> None: + register_dataset(dataset) + cluster_master_ip = cluster_info.storage_cluster_master_ip + logger.info( + f"Initializing GLT server for storage node process group {storage_rank} / {cluster_info.num_storage_nodes} on {cluster_master_ip}:{cluster_info.rpc_master_port}" + ) + # Initialize the GLT server before starting the Torch Distributed process group. + # Otherwise, we saw intermittent hangs when initializing the server. 
+ glt.distributed.init_server( + num_servers=cluster_info.num_storage_nodes, + server_rank=storage_rank, + dataset=dataset, + master_addr=cluster_master_ip, + master_port=cluster_info.rpc_master_port, + num_clients=cluster_info.compute_cluster_world_size, + ) + + init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{torch_process_port}" + logger.info( + f"Initializing storage node process group {storage_rank} / {cluster_info.num_storage_nodes} with backend {storage_world_backend} on {init_method}" + ) + torch.distributed.init_process_group( + backend=storage_world_backend, + world_size=cluster_info.num_storage_nodes, + rank=storage_rank, + init_method=init_method, + ) + + logger.info( + f"Waiting for storage node {storage_rank} / {cluster_info.num_storage_nodes} to exit" + ) + glt.distributed.wait_and_shutdown_server() + logger.info(f"Storage node {storage_rank} exited") + + +def storage_node_process( + storage_rank: int, + cluster_info: GraphStoreInfo, + task_config_uri: Uri, + is_inference: bool = True, + tf_record_uri_pattern: str = ".*-of-.*\.tfrecord(\.gz)?$", + storage_world_backend: Optional[str] = None, +) -> None: + """Run a storage node process + + Should be called *once* per storage node (machine). + + Args: + storage_rank (int): The rank of the storage node. + cluster_info (GraphStoreInfo): The cluster information. + task_config_uri (Uri): The task config URI. + is_inference (bool): Whether the process is an inference process. Defaults to True. + tf_record_uri_pattern (str): The TF Record URI pattern. + storage_world_backend (Optional[str]): The backend for the storage Torch Distributed process group. + """ + init_method = f"tcp://{cluster_info.storage_cluster_master_ip}:{cluster_info.storage_cluster_master_port}" + logger.info( + f"Initializing storage node {storage_rank} / {cluster_info.num_storage_nodes}. 
OS rank: {os.environ['RANK']}, OS world size: {os.environ['WORLD_SIZE']} init method: {init_method}" + ) + torch.distributed.init_process_group( + backend="gloo", + world_size=cluster_info.num_storage_nodes, + rank=storage_rank, + init_method=init_method, + group_name="gigl_server_comms", + ) + logger.info( + f"Storage node {storage_rank} / {cluster_info.num_storage_nodes} process group initialized" + ) + dataset = build_dataset_from_task_config_uri( + task_config_uri=task_config_uri, + is_inference=is_inference, + _tfrecord_uri_pattern=tf_record_uri_pattern, + ) + torch_process_port = get_free_ports_from_master_node(num_ports=1)[0] + torch.distributed.destroy_process_group() + server_processes = [] + mp_context = torch.multiprocessing.get_context("spawn") + # TODO(kmonte): Enable more than one server process per machine + for i in range(1): + server_process = mp_context.Process( + target=_run_storage_process, + args=( + storage_rank + i, # storage_rank + cluster_info, # cluster_info + dataset, # dataset + torch_process_port, # torch_process_port + storage_world_backend, # storage_world_backend + ), + ) + server_processes.append(server_process) + for server_process in server_processes: + server_process.start() + for server_process in server_processes: + server_process.join() + + if __name__ == "__main__": parser = argparse.ArgumentParser() parser.add_argument("--task_config_uri", type=str, required=True) @@ -25,7 +136,6 @@ parser.add_argument("--is_inference", type=str, required=True) args = parser.parse_args() logger.info(f"Running storage node with arguments: {args}") - is_inference = bool(strtobool(args.is_inference)) torch.distributed.init_process_group(backend="gloo") cluster_info = get_graph_store_info() @@ -33,7 +143,6 @@ logger.info( f"World size: {torch.distributed.get_world_size()}, rank: {torch.distributed.get_rank()}, OS world size: {os.environ['WORLD_SIZE']}, OS rank: {os.environ['RANK']}" ) - logger.info(f"Is inference: {is_inference}") # Tear down 
the """"global""" process group so we can have a server-specific process group. torch.distributed.destroy_process_group() storage_node_process( diff --git a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py index ab127b7ee..23ecc787b 100644 --- a/python/tests/integration/distributed/graph_store/graph_store_integration_test.py +++ b/python/tests/integration/distributed/graph_store/graph_store_integration_test.py @@ -38,7 +38,7 @@ def _assert_sampler_input( cluster_info: GraphStoreInfo, - sampler_input: list[torch.Tensor], + sampler_input: dict[int, torch.Tensor], expected_sampler_input: dict[int, list[torch.Tensor]], ) -> None: rank_expected_sampler_input = expected_sampler_input[cluster_info.compute_node_rank] @@ -123,7 +123,9 @@ def _run_compute_tests( torch.distributed.barrier() if node_type is not None: - input_nodes: Union[list[torch.Tensor], tuple[NodeType, list[torch.Tensor]]] = ( + input_nodes: Union[ + dict[int, torch.Tensor], tuple[NodeType, dict[int, torch.Tensor]] + ] = ( node_type, sampler_input, ) From 2b3791ce0ef648068045511bc5c3d102d7d99281 Mon Sep 17 00:00:00 2001 From: kmontemayor Date: Thu, 22 Jan 2026 19:15:06 +0000 Subject: [PATCH 4/4] Add validation checks for custom storage main --- .../src/validation_check/config_validator.py | 86 ++++ ...nd_resource_config_compatibility_checks.py | 167 ++++++++ .../libs/resource_config_checks.py | 51 +++ ...source_config_compatibility_checks_test.py | 403 ++++++++++++++++++ .../lib/resource_config_checks_test.py | 90 +++- 5 files changed, 796 insertions(+), 1 deletion(-) create mode 100644 python/gigl/src/validation_check/libs/gbml_and_resource_config_compatibility_checks.py create mode 100644 python/tests/unit/src/validation/lib/gbml_and_resource_config_compatibility_checks_test.py diff --git a/python/gigl/src/validation_check/config_validator.py 
b/python/gigl/src/validation_check/config_validator.py index 49af4637f..ec0ca4caf 100644 --- a/python/gigl/src/validation_check/config_validator.py +++ b/python/gigl/src/validation_check/config_validator.py @@ -15,6 +15,10 @@ assert_subgraph_sampler_output_exists, assert_trained_model_exists, ) +from gigl.src.validation_check.libs.gbml_and_resource_config_compatibility_checks import ( + check_inferencer_graph_store_compatibility, + check_trainer_graph_store_compatibility, +) from gigl.src.validation_check.libs.name_checks import ( check_if_kfp_pipeline_job_name_valid, ) @@ -191,6 +195,79 @@ logger = Logger() +# Map of start components to graph store compatibility checks to run +# Only run trainer checks when starting at or before Trainer +# Only run inferencer checks when starting at or before Inferencer +START_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS = { + GiGLComponents.ConfigPopulator.value: [ + check_trainer_graph_store_compatibility, + check_inferencer_graph_store_compatibility, + ], + GiGLComponents.DataPreprocessor.value: [ + check_trainer_graph_store_compatibility, + check_inferencer_graph_store_compatibility, + ], + GiGLComponents.SubgraphSampler.value: [ + check_trainer_graph_store_compatibility, + check_inferencer_graph_store_compatibility, + ], + GiGLComponents.SplitGenerator.value: [ + check_trainer_graph_store_compatibility, + check_inferencer_graph_store_compatibility, + ], + GiGLComponents.Trainer.value: [ + check_trainer_graph_store_compatibility, + check_inferencer_graph_store_compatibility, + ], + GiGLComponents.Inferencer.value: [ + check_inferencer_graph_store_compatibility, + ], + # PostProcessor doesn't need graph store compatibility checks +} + +# Map of (start, stop) component tuples to graph store compatibility checks + +STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP = { + GiGLComponents.Trainer.value: [ + check_inferencer_graph_store_compatibility, + ], +} + + +def _run_gbml_and_resource_config_compatibility_checks( + 
start_at: str, + stop_after: Optional[str], + gbml_config_pb_wrapper: GbmlConfigPbWrapper, + resource_config_wrapper: GiglResourceConfigWrapper, +) -> None: + """ + Run compatibility checks between GbmlConfig and GiglResourceConfig. + + These checks verify that graph store mode configurations are consistent + across both the template config (GbmlConfig) and resource config (GiglResourceConfig). + + Args: + start_at: The component to start at. + stop_after: Optional component to stop after. + gbml_config_pb_wrapper: The GbmlConfig wrapper (template config). + resource_config_wrapper: The GiglResourceConfig wrapper (resource config). + """ + # Get the appropriate compatibility checks based on start/stop components + compatibility_checks = set( + START_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS.get(start_at, []) + ) + if stop_after in STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP: + for skipped_check in STOP_COMPONENT_TO_GRAPH_STORE_COMPATIBILITY_CHECKS_TO_SKIP[ + stop_after + ]: + compatibility_checks.discard(skipped_check) + + for check in compatibility_checks: + check( + gbml_config_pb_wrapper=gbml_config_pb_wrapper, + resource_config_wrapper=resource_config_wrapper, + ) + def kfp_validation_checks( job_name: str, @@ -261,6 +338,15 @@ def kfp_validation_checks( f"Skipping resource config check {resource_config_check.__name__} because we are using live subgraph sampling backend." 
) + # check compatibility between template config and resource config for graph store mode + # These checks ensure that if graph store mode is enabled in one config, it's also enabled in the other + _run_gbml_and_resource_config_compatibility_checks( + start_at=start_at, + stop_after=stop_after, + gbml_config_pb_wrapper=gbml_config_pb_wrapper, + resource_config_wrapper=resource_config_wrapper, + ) + # check if trained model file exist when skipping training if gbml_config_pb.shared_config.should_skip_training == True: assert_trained_model_exists(gbml_config_pb=gbml_config_pb) diff --git a/python/gigl/src/validation_check/libs/gbml_and_resource_config_compatibility_checks.py b/python/gigl/src/validation_check/libs/gbml_and_resource_config_compatibility_checks.py new file mode 100644 index 000000000..10f8bfc72 --- /dev/null +++ b/python/gigl/src/validation_check/libs/gbml_and_resource_config_compatibility_checks.py @@ -0,0 +1,167 @@ +""" +Compatibility checks between GbmlConfig (template config) and GiglResourceConfig (resource config). + +These checks ensure that graph store mode configurations are consistent across both configs. +If graph store mode is set up for trainer or inferencer in one config, it must be set up in the other. +""" + +from gigl.common.logger import Logger +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( + GiglResourceConfigWrapper, +) +from snapchat.research.gbml import gigl_resource_config_pb2 + +logger = Logger() + + +def _gbml_config_has_trainer_graph_store( + gbml_config_pb_wrapper: GbmlConfigPbWrapper, +) -> bool: + """ + Check if the GbmlConfig has graph_store_storage_config set for trainer. + + Args: + gbml_config_pb_wrapper: The GbmlConfig wrapper to check. + + Returns: + True if graph_store_storage_config is set for trainer, False otherwise. 
+ """ + trainer_config = gbml_config_pb_wrapper.gbml_config_pb.trainer_config + return trainer_config.HasField("graph_store_storage_config") + + +def _gbml_config_has_inferencer_graph_store( + gbml_config_pb_wrapper: GbmlConfigPbWrapper, +) -> bool: + """ + Check if the GbmlConfig has graph_store_storage_config set for inferencer. + + Args: + gbml_config_pb_wrapper: The GbmlConfig wrapper to check. + + Returns: + True if graph_store_storage_config is set for inferencer, False otherwise. + """ + inferencer_config = gbml_config_pb_wrapper.gbml_config_pb.inferencer_config + return inferencer_config.HasField("graph_store_storage_config") + + +def _resource_config_has_trainer_graph_store( + resource_config_wrapper: GiglResourceConfigWrapper, +) -> bool: + """ + Check if the GiglResourceConfig has VertexAiGraphStoreConfig set for trainer. + + Args: + resource_config_wrapper: The resource config wrapper to check. + + Returns: + True if VertexAiGraphStoreConfig is set for trainer, False otherwise. + """ + trainer_config = resource_config_wrapper.trainer_config + return isinstance(trainer_config, gigl_resource_config_pb2.VertexAiGraphStoreConfig) + + +def _resource_config_has_inferencer_graph_store( + resource_config_wrapper: GiglResourceConfigWrapper, +) -> bool: + """ + Check if the GiglResourceConfig has VertexAiGraphStoreConfig set for inferencer. + + Args: + resource_config_wrapper: The resource config wrapper to check. + + Returns: + True if VertexAiGraphStoreConfig is set for inferencer, False otherwise. + """ + inferencer_config = resource_config_wrapper.inferencer_config + return isinstance( + inferencer_config, gigl_resource_config_pb2.VertexAiGraphStoreConfig + ) + + +def check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper: GbmlConfigPbWrapper, + resource_config_wrapper: GiglResourceConfigWrapper, +) -> None: + """ + Check that trainer graph store mode is consistently configured across both configs. 
+ + If graph_store_storage_config is set in GbmlConfig.trainer_config, then + VertexAiGraphStoreConfig must be set in GiglResourceConfig.trainer_resource_config, + and vice versa. Also validates that storage_command is set when graph store mode is enabled. + + Args: + gbml_config_pb_wrapper: The GbmlConfig wrapper (template config). + resource_config_wrapper: The GiglResourceConfig wrapper (resource config). + + Raises: + AssertionError: If graph store configurations are not compatible or storage_command is missing. + """ + logger.info( + "Config validation check: trainer graph store compatibility between template and resource configs." + ) + + gbml_has_graph_store = _gbml_config_has_trainer_graph_store(gbml_config_pb_wrapper) + resource_has_graph_store = _resource_config_has_trainer_graph_store( + resource_config_wrapper + ) + + if gbml_has_graph_store and not resource_has_graph_store: + raise AssertionError( + "GbmlConfig.trainer_config.graph_store_storage_config is set, but " + "GiglResourceConfig.trainer_resource_config does not use VertexAiGraphStoreConfig. " + "Both configs must use graph store mode for trainer, or neither should." + ) + + if resource_has_graph_store and not gbml_has_graph_store: + raise AssertionError( + "GiglResourceConfig.trainer_resource_config uses VertexAiGraphStoreConfig, but " + "GbmlConfig.trainer_config.graph_store_storage_config is not set. " + "Both configs must use graph store mode for trainer, or neither should." + ) + + +def check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper: GbmlConfigPbWrapper, + resource_config_wrapper: GiglResourceConfigWrapper, +) -> None: + """ + Check that inferencer graph store mode is consistently configured across both configs. + + If graph_store_storage_config is set in GbmlConfig.inferencer_config, then + VertexAiGraphStoreConfig must be set in GiglResourceConfig.inferencer_resource_config, + and vice versa. 
Also validates that storage_command is set when graph store mode is enabled. + + Args: + gbml_config_pb_wrapper: The GbmlConfig wrapper (template config). + resource_config_wrapper: The GiglResourceConfig wrapper (resource config). + + Raises: + AssertionError: If graph store configurations are not compatible or storage_command is missing. + """ + logger.info( + "Config validation check: inferencer graph store compatibility between template and resource configs." + ) + + gbml_has_graph_store = _gbml_config_has_inferencer_graph_store( + gbml_config_pb_wrapper + ) + resource_has_graph_store = _resource_config_has_inferencer_graph_store( + resource_config_wrapper + ) + + if gbml_has_graph_store and not resource_has_graph_store: + raise AssertionError( + "GbmlConfig.inferencer_config.graph_store_storage_config is set, but " + "GiglResourceConfig.inferencer_resource_config does not use VertexAiGraphStoreConfig. " + "Both configs must use graph store mode for inferencer, or neither should." + ) + + if resource_has_graph_store and not gbml_has_graph_store: + raise AssertionError( + "GiglResourceConfig.inferencer_resource_config uses VertexAiGraphStoreConfig, but " + "GbmlConfig.inferencer_config.graph_store_storage_config is not set. " + "Both configs must use graph store mode for inferencer, or neither should." 
+ ) diff --git a/python/gigl/src/validation_check/libs/resource_config_checks.py b/python/gigl/src/validation_check/libs/resource_config_checks.py index a60a66945..51cc60651 100644 --- a/python/gigl/src/validation_check/libs/resource_config_checks.py +++ b/python/gigl/src/validation_check/libs/resource_config_checks.py @@ -3,6 +3,7 @@ from google.cloud.aiplatform_v1.types.accelerator_type import AcceleratorType from gigl.common.logger import Logger +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( GiglResourceConfigWrapper, ) @@ -254,3 +255,53 @@ def _validate_machine_config( or {gigl_resource_config_pb2.VertexAiGraphStoreConfig.__name__}. Got {type(config)}""" ) + + +def check_if_trainer_graph_store_storage_command_valid( + gbml_config_pb_wrapper: GbmlConfigPbWrapper, +) -> None: + """ + Validates that storage_command is set when graph store mode is enabled for trainer. + + Args: + gbml_config_pb_wrapper: The GbmlConfig wrapper to check. + + Raises: + AssertionError: If graph store mode is enabled but storage_command is missing. + """ + logger.info( + "Config validation check: if trainer graph store storage_command is valid." + ) + trainer_config = gbml_config_pb_wrapper.gbml_config_pb.trainer_config + if trainer_config.HasField("graph_store_storage_config"): + storage_command = trainer_config.graph_store_storage_config.storage_command + if not storage_command: + raise AssertionError( + "GbmlConfig.trainer_config.graph_store_storage_config.storage_command must be set " + "when using graph store mode for trainer." + ) + + +def check_if_inferencer_graph_store_storage_command_valid( + gbml_config_pb_wrapper: GbmlConfigPbWrapper, +) -> None: + """ + Validates that storage_command is set when graph store mode is enabled for inferencer. + + Args: + gbml_config_pb_wrapper: The GbmlConfig wrapper to check. 
+ + Raises: + AssertionError: If graph store mode is enabled but storage_command is missing. + """ + logger.info( + "Config validation check: if inferencer graph store storage_command is valid." + ) + inferencer_config = gbml_config_pb_wrapper.gbml_config_pb.inferencer_config + if inferencer_config.HasField("graph_store_storage_config"): + storage_command = inferencer_config.graph_store_storage_config.storage_command + if not storage_command: + raise AssertionError( + "GbmlConfig.inferencer_config.graph_store_storage_config.storage_command must be set " + "when using graph store mode for inferencer." + ) diff --git a/python/tests/unit/src/validation/lib/gbml_and_resource_config_compatibility_checks_test.py b/python/tests/unit/src/validation/lib/gbml_and_resource_config_compatibility_checks_test.py new file mode 100644 index 000000000..686fab9fb --- /dev/null +++ b/python/tests/unit/src/validation/lib/gbml_and_resource_config_compatibility_checks_test.py @@ -0,0 +1,403 @@ +import unittest + +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper +from gigl.src.common.types.pb_wrappers.gigl_resource_config import ( + GiglResourceConfigWrapper, +) +from gigl.src.validation_check.libs.gbml_and_resource_config_compatibility_checks import ( + check_inferencer_graph_store_compatibility, + check_trainer_graph_store_compatibility, +) +from snapchat.research.gbml import gbml_config_pb2, gigl_resource_config_pb2 + +# Helper functions for creating VertexAiGraphStoreConfig + + +def _create_vertex_ai_graph_store_config() -> ( + gigl_resource_config_pb2.VertexAiGraphStoreConfig +): + """Create a valid VertexAiGraphStoreConfig.""" + config = gigl_resource_config_pb2.VertexAiGraphStoreConfig() + # Graph store pool + config.graph_store_pool.machine_type = "n1-highmem-8" + config.graph_store_pool.gpu_type = "ACCELERATOR_TYPE_UNSPECIFIED" + config.graph_store_pool.gpu_limit = 0 + config.graph_store_pool.num_replicas = 2 + # Compute pool + 
config.compute_pool.machine_type = "n1-standard-16" + config.compute_pool.gpu_type = "NVIDIA_TESLA_T4" + config.compute_pool.gpu_limit = 2 + config.compute_pool.num_replicas = 3 + return config + + +def _create_vertex_ai_resource_config() -> ( + gigl_resource_config_pb2.VertexAiResourceConfig +): + """Create a valid VertexAiResourceConfig (non-graph store).""" + config = gigl_resource_config_pb2.VertexAiResourceConfig() + config.machine_type = "n1-standard-16" + config.gpu_type = "NVIDIA_TESLA_T4" + config.gpu_limit = 2 + config.num_replicas = 3 + return config + + +def _create_shared_resource_config( + config: gigl_resource_config_pb2.GiglResourceConfig, +) -> None: + """Populate shared resource config fields.""" + config.shared_resource_config.common_compute_config.project = "test-project" + config.shared_resource_config.common_compute_config.region = "us-central1" + config.shared_resource_config.common_compute_config.temp_assets_bucket = ( + "gs://test-temp" + ) + config.shared_resource_config.common_compute_config.temp_regional_assets_bucket = ( + "gs://test-temp-regional" + ) + config.shared_resource_config.common_compute_config.perm_assets_bucket = ( + "gs://test-perm" + ) + config.shared_resource_config.common_compute_config.temp_assets_bq_dataset_name = ( + "test_dataset" + ) + config.shared_resource_config.common_compute_config.embedding_bq_dataset_name = ( + "test_embeddings" + ) + config.shared_resource_config.common_compute_config.gcp_service_account_email = ( + "test@test-project.iam.gserviceaccount.com" + ) + config.shared_resource_config.common_compute_config.dataflow_runner = ( + "DataflowRunner" + ) + + +# Helper functions for creating GbmlConfig configurations + + +def _create_gbml_config_with_trainer_graph_store( + storage_command: str = "python -m gigl.distributed.graph_store.storage_main", +) -> GbmlConfigPbWrapper: + """Create a GbmlConfig with graph_store_storage_config set for trainer.""" + gbml_config = gbml_config_pb2.GbmlConfig() + 
gbml_config.trainer_config.graph_store_storage_config.storage_command = ( + storage_command + ) + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +def _create_gbml_config_without_trainer_graph_store() -> GbmlConfigPbWrapper: + """Create a GbmlConfig without graph_store_storage_config for trainer.""" + gbml_config = gbml_config_pb2.GbmlConfig() + gbml_config.trainer_config.trainer_args["some_arg"] = "some_value" + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +def _create_gbml_config_with_inferencer_graph_store( + storage_command: str = "python -m gigl.distributed.graph_store.storage_main", +) -> GbmlConfigPbWrapper: + """Create a GbmlConfig with graph_store_storage_config set for inferencer.""" + gbml_config = gbml_config_pb2.GbmlConfig() + gbml_config.inferencer_config.graph_store_storage_config.storage_command = ( + storage_command + ) + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +def _create_gbml_config_without_inferencer_graph_store() -> GbmlConfigPbWrapper: + """Create a GbmlConfig without graph_store_storage_config for inferencer.""" + gbml_config = gbml_config_pb2.GbmlConfig() + gbml_config.inferencer_config.inferencer_args["some_arg"] = "some_value" + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +def _create_gbml_config_with_both_graph_stores( + storage_command: str = "python -m gigl.distributed.graph_store.storage_main", +) -> GbmlConfigPbWrapper: + """Create a GbmlConfig with graph_store_storage_config for both trainer and inferencer.""" + gbml_config = gbml_config_pb2.GbmlConfig() + gbml_config.trainer_config.graph_store_storage_config.storage_command = ( + storage_command + ) + gbml_config.inferencer_config.graph_store_storage_config.storage_command = ( + storage_command + ) + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +def _create_gbml_config_without_graph_stores() -> GbmlConfigPbWrapper: + """Create a GbmlConfig without graph_store_storage_config for both trainer and inferencer.""" 
+ gbml_config = gbml_config_pb2.GbmlConfig() + gbml_config.trainer_config.trainer_args["some_arg"] = "some_value" + gbml_config.inferencer_config.inferencer_args["some_arg"] = "some_value" + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +# Helper functions for creating GiglResourceConfig configurations + + +def _create_resource_config_with_trainer_graph_store() -> GiglResourceConfigWrapper: + """Create a GiglResourceConfig with VertexAiGraphStoreConfig for trainer.""" + config = gigl_resource_config_pb2.GiglResourceConfig() + _create_shared_resource_config(config) + + # Trainer with VertexAiGraphStoreConfig + config.trainer_resource_config.vertex_ai_graph_store_trainer_config.CopyFrom( + _create_vertex_ai_graph_store_config() + ) + # Inferencer with standard config + config.inferencer_resource_config.vertex_ai_inferencer_config.CopyFrom( + _create_vertex_ai_resource_config() + ) + return GiglResourceConfigWrapper(resource_config=config) + + +def _create_resource_config_without_trainer_graph_store() -> GiglResourceConfigWrapper: + """Create a GiglResourceConfig without VertexAiGraphStoreConfig for trainer.""" + config = gigl_resource_config_pb2.GiglResourceConfig() + _create_shared_resource_config(config) + + # Trainer with standard config + config.trainer_resource_config.vertex_ai_trainer_config.CopyFrom( + _create_vertex_ai_resource_config() + ) + # Inferencer with standard config + config.inferencer_resource_config.vertex_ai_inferencer_config.CopyFrom( + _create_vertex_ai_resource_config() + ) + return GiglResourceConfigWrapper(resource_config=config) + + +def _create_resource_config_with_inferencer_graph_store() -> GiglResourceConfigWrapper: + """Create a GiglResourceConfig with VertexAiGraphStoreConfig for inferencer.""" + config = gigl_resource_config_pb2.GiglResourceConfig() + _create_shared_resource_config(config) + + # Trainer with standard config + config.trainer_resource_config.vertex_ai_trainer_config.CopyFrom( + 
_create_vertex_ai_resource_config() + ) + # Inferencer with VertexAiGraphStoreConfig + config.inferencer_resource_config.vertex_ai_graph_store_inferencer_config.CopyFrom( + _create_vertex_ai_graph_store_config() + ) + return GiglResourceConfigWrapper(resource_config=config) + + +def _create_resource_config_without_inferencer_graph_store() -> ( + GiglResourceConfigWrapper +): + """Create a GiglResourceConfig without VertexAiGraphStoreConfig for inferencer.""" + return _create_resource_config_without_trainer_graph_store() + + +def _create_resource_config_with_both_graph_stores() -> GiglResourceConfigWrapper: + """Create a GiglResourceConfig with VertexAiGraphStoreConfig for both trainer and inferencer.""" + config = gigl_resource_config_pb2.GiglResourceConfig() + _create_shared_resource_config(config) + + # Trainer with VertexAiGraphStoreConfig + config.trainer_resource_config.vertex_ai_graph_store_trainer_config.CopyFrom( + _create_vertex_ai_graph_store_config() + ) + # Inferencer with VertexAiGraphStoreConfig + config.inferencer_resource_config.vertex_ai_graph_store_inferencer_config.CopyFrom( + _create_vertex_ai_graph_store_config() + ) + return GiglResourceConfigWrapper(resource_config=config) + + +def _create_resource_config_without_graph_stores() -> GiglResourceConfigWrapper: + """Create a GiglResourceConfig without VertexAiGraphStoreConfig for both trainer and inferencer.""" + return _create_resource_config_without_trainer_graph_store() + + +# Test Classes + + +class TestTrainerGraphStoreCompatibility(unittest.TestCase): + """Test suite for trainer graph store compatibility checks.""" + + def test_both_have_trainer_graph_store(self): + """Test that both configs having trainer graph store passes validation.""" + gbml_config = _create_gbml_config_with_trainer_graph_store() + resource_config = _create_resource_config_with_trainer_graph_store() + # Should not raise any exception + check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + 
resource_config_wrapper=resource_config, + ) + + def test_neither_has_trainer_graph_store(self): + """Test that neither config having trainer graph store passes validation.""" + gbml_config = _create_gbml_config_without_trainer_graph_store() + resource_config = _create_resource_config_without_trainer_graph_store() + # Should not raise any exception + check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + def test_template_has_trainer_graph_store_resource_does_not(self): + """Test that template having graph store but resource not raises an assertion error.""" + gbml_config = _create_gbml_config_with_trainer_graph_store() + resource_config = _create_resource_config_without_trainer_graph_store() + with self.assertRaises(AssertionError): + check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + def test_resource_has_trainer_graph_store_template_does_not(self): + """Test that resource having graph store but template not raises an assertion error.""" + gbml_config = _create_gbml_config_without_trainer_graph_store() + resource_config = _create_resource_config_with_trainer_graph_store() + with self.assertRaises(AssertionError): + check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + +class TestInferencerGraphStoreCompatibility(unittest.TestCase): + """Test suite for inferencer graph store compatibility checks.""" + + def test_both_have_inferencer_graph_store(self): + """Test that both configs having inferencer graph store passes validation.""" + gbml_config = _create_gbml_config_with_inferencer_graph_store() + resource_config = _create_resource_config_with_inferencer_graph_store() + # Should not raise any exception + check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + def 
test_neither_has_inferencer_graph_store(self): + """Test that neither config having inferencer graph store passes validation.""" + gbml_config = _create_gbml_config_without_inferencer_graph_store() + resource_config = _create_resource_config_without_inferencer_graph_store() + # Should not raise any exception + check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + def test_template_has_inferencer_graph_store_resource_does_not(self): + """Test that template having graph store but resource not raises an assertion error.""" + gbml_config = _create_gbml_config_with_inferencer_graph_store() + resource_config = _create_resource_config_without_inferencer_graph_store() + with self.assertRaises(AssertionError): + check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + def test_resource_has_inferencer_graph_store_template_does_not(self): + """Test that resource having graph store but template not raises an assertion error.""" + gbml_config = _create_gbml_config_without_inferencer_graph_store() + resource_config = _create_resource_config_with_inferencer_graph_store() + with self.assertRaises(AssertionError): + check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + +class TestMixedGraphStoreConfigurations(unittest.TestCase): + """Test suite for mixed graph store configuration scenarios.""" + + def test_both_have_all_graph_stores(self): + """Test that both configs having all graph stores passes validation.""" + gbml_config = _create_gbml_config_with_both_graph_stores() + resource_config = _create_resource_config_with_both_graph_stores() + # Should not raise any exception for trainer + check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + # Should not raise any exception for 
inferencer + check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + def test_neither_has_any_graph_stores(self): + """Test that neither config having any graph stores passes validation.""" + gbml_config = _create_gbml_config_without_graph_stores() + resource_config = _create_resource_config_without_graph_stores() + # Should not raise any exception for trainer + check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + # Should not raise any exception for inferencer + check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + def test_trainer_graph_store_only_compatible(self): + """Test trainer graph store only configuration is compatible.""" + gbml_config = _create_gbml_config_with_trainer_graph_store() + resource_config = _create_resource_config_with_trainer_graph_store() + # Should not raise any exception for trainer + check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + # Should not raise any exception for inferencer (neither has it) + check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + def test_inferencer_graph_store_only_compatible(self): + """Test inferencer graph store only configuration is compatible.""" + gbml_config = _create_gbml_config_with_inferencer_graph_store() + resource_config = _create_resource_config_with_inferencer_graph_store() + # Should not raise any exception for trainer (neither has it) + check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + # Should not raise any exception for inferencer + check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + 
resource_config_wrapper=resource_config, + ) + + def test_template_has_both_resource_has_trainer_only(self): + """Test that template having both but resource having only trainer raises an error for inferencer.""" + gbml_config = _create_gbml_config_with_both_graph_stores() + resource_config = _create_resource_config_with_trainer_graph_store() + # Should not raise any exception for trainer + check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + # Should raise an assertion error for inferencer + with self.assertRaises(AssertionError): + check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + def test_template_has_both_resource_has_inferencer_only(self): + """Test that template having both but resource having only inferencer raises an error for trainer.""" + gbml_config = _create_gbml_config_with_both_graph_stores() + resource_config = _create_resource_config_with_inferencer_graph_store() + # Should raise an assertion error for trainer + with self.assertRaises(AssertionError): + check_trainer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + # Should not raise any exception for inferencer + check_inferencer_graph_store_compatibility( + gbml_config_pb_wrapper=gbml_config, + resource_config_wrapper=resource_config, + ) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/tests/unit/src/validation/lib/resource_config_checks_test.py b/python/tests/unit/src/validation/lib/resource_config_checks_test.py index 20b941f72..3bddf1269 100644 --- a/python/tests/unit/src/validation/lib/resource_config_checks_test.py +++ b/python/tests/unit/src/validation/lib/resource_config_checks_test.py @@ -1,18 +1,21 @@ import unittest +from gigl.src.common.types.pb_wrappers.gbml_config import GbmlConfigPbWrapper from 
gigl.src.validation_check.libs.resource_config_checks import ( _check_if_dataflow_resource_config_valid, _check_if_spark_resource_config_valid, _validate_accelerator_type, _validate_machine_config, + check_if_inferencer_graph_store_storage_command_valid, check_if_inferencer_resource_config_valid, check_if_preprocessor_resource_config_valid, check_if_shared_resource_config_valid, check_if_split_generator_resource_config_valid, check_if_subgraph_sampler_resource_config_valid, + check_if_trainer_graph_store_storage_command_valid, check_if_trainer_resource_config_valid, ) -from snapchat.research.gbml import gigl_resource_config_pb2 +from snapchat.research.gbml import gbml_config_pb2, gigl_resource_config_pb2 # Helper functions for creating valid configurations @@ -750,5 +753,90 @@ def test_valid_vertex_ai_graph_store_config(self): _validate_machine_config(config) +# Helper functions for creating GbmlConfig configurations + + +def _create_gbml_config_with_trainer_graph_store( + storage_command: str = "python -m gigl.distributed.graph_store.storage_main", +) -> GbmlConfigPbWrapper: + """Create a GbmlConfig with graph_store_storage_config set for trainer.""" + gbml_config = gbml_config_pb2.GbmlConfig() + gbml_config.trainer_config.graph_store_storage_config.storage_command = ( + storage_command + ) + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +def _create_gbml_config_without_trainer_graph_store() -> GbmlConfigPbWrapper: + """Create a GbmlConfig without graph_store_storage_config for trainer.""" + gbml_config = gbml_config_pb2.GbmlConfig() + gbml_config.trainer_config.trainer_args["some_arg"] = "some_value" + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +def _create_gbml_config_with_inferencer_graph_store( + storage_command: str = "python -m gigl.distributed.graph_store.storage_main", +) -> GbmlConfigPbWrapper: + """Create a GbmlConfig with graph_store_storage_config set for inferencer.""" + gbml_config = gbml_config_pb2.GbmlConfig() + 
gbml_config.inferencer_config.graph_store_storage_config.storage_command = ( + storage_command + ) + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +def _create_gbml_config_without_inferencer_graph_store() -> GbmlConfigPbWrapper: + """Create a GbmlConfig without graph_store_storage_config for inferencer.""" + gbml_config = gbml_config_pb2.GbmlConfig() + gbml_config.inferencer_config.inferencer_args["some_arg"] = "some_value" + return GbmlConfigPbWrapper(gbml_config_pb=gbml_config) + + +class TestTrainerGraphStoreStorageCommand(unittest.TestCase): + """Test suite for trainer graph store storage_command validation.""" + + def test_valid_storage_command(self): + """Test that a valid storage_command passes validation.""" + gbml_config = _create_gbml_config_with_trainer_graph_store() + # Should not raise any exception + check_if_trainer_graph_store_storage_command_valid(gbml_config) + + def test_missing_storage_command(self): + """Test that missing storage_command raises an assertion error.""" + gbml_config = _create_gbml_config_with_trainer_graph_store(storage_command="") + with self.assertRaises(AssertionError): + check_if_trainer_graph_store_storage_command_valid(gbml_config) + + def test_no_graph_store_config(self): + """Test that no graph store config passes validation (nothing to check).""" + gbml_config = _create_gbml_config_without_trainer_graph_store() + # Should not raise any exception - no graph store means nothing to validate + check_if_trainer_graph_store_storage_command_valid(gbml_config) + + +class TestInferencerGraphStoreStorageCommand(unittest.TestCase): + """Test suite for inferencer graph store storage_command validation.""" + + def test_valid_storage_command(self): + """Test that a valid storage_command passes validation.""" + gbml_config = _create_gbml_config_with_inferencer_graph_store() + # Should not raise any exception + check_if_inferencer_graph_store_storage_command_valid(gbml_config) + + def test_missing_storage_command(self): + 
"""Test that missing storage_command raises an assertion error.""" + gbml_config = _create_gbml_config_with_inferencer_graph_store( + storage_command="" + ) + with self.assertRaises(AssertionError): + check_if_inferencer_graph_store_storage_command_valid(gbml_config) + + def test_no_graph_store_config(self): + """Test that no graph store config passes validation (nothing to check).""" + gbml_config = _create_gbml_config_without_inferencer_graph_store() + # Should not raise any exception - no graph store means nothing to validate + check_if_inferencer_graph_store_storage_command_valid(gbml_config) + + if __name__ == "__main__": unittest.main()