def register_atomic_subgraph(
-    is_aten: bool = False,
+    is_core_aten: bool = False,
) -> Callable[[torch.nn.Module], torch.nn.Module]:

    def decorator(subgraph: torch.nn.Module) -> torch.nn.Module:
-        ATOMIC_SUBGRAPHS.append((subgraph, is_aten))
+        ATOMIC_SUBGRAPHS.append((subgraph, is_core_aten))
        return subgraph

    return decorator


-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
class ConvBNReLU(torch.nn.Module):  # type: ignore[misc]
    def __init__(self) -> None:
        super().__init__()
@@ -60,7 +60,7 @@ def forward(
        return x


-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
class ConvReLU(torch.nn.Module):  # type: ignore[misc]
    def __init__(self) -> None:
        super().__init__()
@@ -92,7 +92,7 @@ def forward(
        return x


-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
class ConvGelu(torch.nn.Module):  # type: ignore[misc]
    def __init__(self) -> None:
        super().__init__()
@@ -124,7 +124,7 @@ def forward(
        return x


-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
class ConvSilu(torch.nn.Module):  # type: ignore[misc]
    def __init__(self) -> None:
        super().__init__()
@@ -139,7 +139,7 @@ def forward(
        return x


-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
class MulAdd(torch.nn.Module):  # type: ignore[misc]
    def __init__(self) -> None:
        super().__init__()
@@ -152,7 +152,7 @@ def forward(
        return x


-@register_atomic_subgraph(is_aten=True)
+@register_atomic_subgraph(is_core_aten=True)
class MulMul(torch.nn.Module):  # type: ignore[misc]
    def __init__(self) -> None:
        super().__init__()
@@ -192,19 +192,30 @@ def get_node_in_fusion_pattern(
    return fusion_nodes


-@lru_cache(maxsize=None)
def get_compiled_atomic_subgraphs() -> List[torch.fx.GraphModule]:
    """
    This function gets the compiled atomic subgraphs from the graph.
    LRU cache the result to avoid recompiling the same pattern multiple times.
    """
    compiled_atomic_subgraphs = []
-    for pattern, is_aten in ATOMIC_SUBGRAPHS:
-        pattern_graph = torch.fx.symbolic_trace(pattern())
-        if not is_aten:
-            # TODO: Add decomposition and lowering if is_aten is False
+    for pattern, is_core_aten in ATOMIC_SUBGRAPHS:
+        pattern_graph = trace_atomic_graph(pattern, is_core_aten)
+        if not is_core_aten:
+            # TODO: Add decomposition and lowering if is_core_aten is False
            raise NotImplementedError(
                "Atomic subgraphs are not supported for non-aten subgraphs yet."
            )
        compiled_atomic_subgraphs.append(pattern_graph)
    return compiled_atomic_subgraphs
+
+
+@lru_cache(maxsize=None)
+def trace_atomic_graph(
+    graph: torch.nn.Module, is_core_aten: bool = True
+) -> torch.fx.GraphModule:
+    if is_core_aten:
+        return torch.fx.symbolic_trace(graph())
+    else:
+        raise NotImplementedError(
+            "Resource partitioner currently does not support unlowered atomic subgraphs"
+        )
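
For reference, a minimal usage sketch of the API after the rename, assuming register_atomic_subgraph and get_compiled_atomic_subgraphs are imported from the module modified in this diff; the AddRelu pattern is hypothetical and not part of this commit. A pattern is registered with is_core_aten=True and each registered pattern is then traced once through the cached trace_atomic_graph helper.

# Assumes these come from the module modified above (import path omitted):
# from ... import register_atomic_subgraph, get_compiled_atomic_subgraphs
import torch

@register_atomic_subgraph(is_core_aten=True)
class AddRelu(torch.nn.Module):  # hypothetical example pattern, not in this commit
    def forward(self, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
        # Core-ATen ops, so the pattern can be symbolically traced as-is
        return torch.ops.aten.relu.default(torch.ops.aten.add.Tensor(x, y))

# Each registered (pattern, is_core_aten) pair is traced once; results are cached
for gm in get_compiled_atomic_subgraphs():
    print(gm.graph)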