diff --git a/fs_mol/modules/gnn.py b/fs_mol/modules/gnn.py
index 060643e8..858e3d75 100644
--- a/fs_mol/modules/gnn.py
+++ b/fs_mol/modules/gnn.py
@@ -61,6 +61,11 @@ def add_gnn_model_arguments(parser: argparse.ArgumentParser):
     parser.add_argument(
         "--num_gnn_layers", type=int, default=10, help="Number of GNN layers to use."
     )
+    parser.add_argument(
+        "--skip_node_embedding",
+        action="store_true",
+        help="Skip learning a per-node embedding. Input dim then equals model dim.",
+    )
 
 
 def make_gnn_config_from_args(args: argparse.Namespace) -> GNNConfig:
diff --git a/fs_mol/modules/graph_feature_extractor.py b/fs_mol/modules/graph_feature_extractor.py
index d3f1d402..689cfadc 100644
--- a/fs_mol/modules/graph_feature_extractor.py
+++ b/fs_mol/modules/graph_feature_extractor.py
@@ -23,6 +23,7 @@ class GraphFeatureExtractorConfig:
     gnn_config: GNNConfig = GNNConfig()
     readout_config: GraphReadoutConfig = GraphReadoutConfig()
     output_norm: Literal["off", "layer", "batch"] = "off"
+    skip_node_proj: bool = False
 
 
 def add_graph_feature_extractor_arguments(parser: argparse.ArgumentParser):
@@ -37,6 +38,7 @@ def make_graph_feature_extractor_config_from_args(
         initial_node_feature_dim=initial_node_feature_dim,
         gnn_config=make_gnn_config_from_args(args),
         readout_config=make_graph_readout_config_from_args(args),
+        skip_node_proj=args.skip_node_embedding,
     )
 
 
@@ -45,10 +47,20 @@ def __init__(self, config: GraphFeatureExtractorConfig):
         super().__init__()
         self.config = config
 
-        # Initial (per-node) layers:
-        self.init_node_proj = nn.Linear(
-            config.initial_node_feature_dim, config.gnn_config.hidden_dim, bias=False
-        )
+        # Per-node input layer: either learn a linear projection into the GNN
+        # hidden dimension, or pass the raw node features through unchanged.
+        if config.skip_node_proj:
+            # Identity passes features through as-is, so the dims must already match.
+            if config.initial_node_feature_dim != config.gnn_config.hidden_dim:
+                raise ValueError(
+                    "skip_node_proj requires initial_node_feature_dim to equal "
+                    "the GNN hidden_dim."
+                )
+            self.init_node_proj = nn.Identity()
+        else:
+            self.init_node_proj = nn.Linear(
+                config.initial_node_feature_dim, config.gnn_config.hidden_dim, bias=False
+            )
 
         self.gnn = GNN(self.config.gnn_config)