From 1ab86996745ad40c56b2a955b366cd1cfd2121a6 Mon Sep 17 00:00:00 2001 From: samwaltonnorwood Date: Tue, 27 Jun 2023 17:23:58 +0200 Subject: [PATCH] Identity skip + args for edge gated block --- mace/blocks.py | 8 ++++++++ mace/irreps_tools.py | 14 ++++++++++++++ 2 files changed, 22 insertions(+) diff --git a/mace/blocks.py b/mace/blocks.py index 5bab5f8..2aeca2a 100644 --- a/mace/blocks.py +++ b/mace/blocks.py @@ -90,6 +90,10 @@ def __init__( hidden_irreps: o3.Irreps, avg_num_neighbors: float, rbf_hidden_channels: int = 64, + edge_gates_irreps: o3.Irreps = o3.Irreps("0e"), + num_gates: int = 4, + multi_conv: bool = False, + exponential: bool = True, ) -> None: super().__init__() self.node_attrs_irreps = node_attrs_irreps @@ -100,6 +104,10 @@ def __init__( self.hidden_irreps = hidden_irreps self.avg_num_neighbors = avg_num_neighbors self.rbf_hidden_channels = rbf_hidden_channels + self.edge_gates_irreps = edge_gates_irreps + self.num_gates = num_gates + self.multi_conv = multi_conv + self.exponential = exponential self._setup() diff --git a/mace/irreps_tools.py b/mace/irreps_tools.py index ff020e6..5bcd453 100644 --- a/mace/irreps_tools.py +++ b/mace/irreps_tools.py @@ -84,3 +84,17 @@ def forward(self, tensor: torch.Tensor) -> torch.Tensor: field = field.reshape(batch, mul, d) out.append(field) return torch.cat(out, dim=-1) + + +@compile_mode("script") +class lifted_skip(torch.nn.Module): + def __init__(self, irreps_in: o3.Irreps, irreps_out: o3.Irreps) -> None: + super().__init__() + self.irreps_in = o3.Irreps(irreps_in) + self.irreps_out = o3.Irreps(irreps_out) + + def forward(self, tensor: torch.Tensor) -> torch.Tensor: + batch, _ = tensor.shape + template = torch.zeros(batch, self.irreps_out.dim, device=tensor.device) + template[:, 0 : self.irreps_in.dim] = tensor + return template