Commit 08daf77
Fix init_weights crash with aliased buffers (e.g. rope.cache/freqs_cis)
When a model registers the same tensor under multiple FQNs (e.g. `rope.cache` and `freqs_cis` in torchtitan's Decoder), PyTorch's `named_buffers()` deduplicates them by default, and this deduplication also happens during AOTAutograd tracing. As a result, the AutoParallelModule was missing the alias FQNs entirely.

This must be fixed in autoparallel (not deep in PyTorch) because autoparallel returns an nn.Module that the user should be able to use like their original model: all buffer FQNs from the original model must be present on the returned module.

The fix has two parts: (1) in api.py, capture buffer alias info before `move_to_fake` destroys aliasing, then re-register the aliases on the parallel module after sharding; (2) in init_weights.py, skip hooking buffer FQNs that don't exist on the parallel model (aliases that were deduplicated).

Authored with Claude.

stack-info: PR: #321, branch: xmfan/stack/24
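For readers unfamiliar with the deduplication behavior the message describes, a minimal sketch (toy module, names chosen to mirror the torchtitan pattern; not code from this PR):

```python
import torch
import torch.nn as nn

class RoPE(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(4), persistent=False)

class Decoder(nn.Module):
    def __init__(self):
        super().__init__()
        self.rope = RoPE()
        # Alias: the same tensor object registered under a second FQN.
        self.register_buffer("freqs_cis", self.rope.cache, persistent=False)

m = Decoder()
print([n for n, _ in m.named_buffers()])
# ['freqs_cis'] -- remove_duplicate=True (the default) drops the alias
print([n for n, _ in m.named_buffers(remove_duplicate=False)])
# ['freqs_cis', 'rope.cache'] -- both FQNs are visible
```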
1 parent b5a020d commit 08daf77

4 files changed: 132 additions & 1 deletion


CLAUDE.md

Lines changed: 31 additions & 0 deletions
@@ -72,3 +72,34 @@ python -m pytest tests/
 - Leverages DTensor for distributed tensor operations
 - Uses linear programming (PuLP) to solve sharding optimization problems
 - Includes fake tensor mode for shape inference without actual computation
+
+# Commit messages
+
+Don't commit unless the user explicitly asks you to.
+
+When writing a commit message, don't make a bullet list of the individual
+changes. Instead, if the PR is large, explain the order to review changes
+(e.g., the logical progression), or if it's short just omit the bullet list
+entirely.
+
+Disclose that the PR was authored with Claude.
+
+# Coding Style Guidelines
+
+Follow these rules for all code changes in this repository:
+
+- Minimize comments; be concise; code should be self-explanatory and self-documenting.
+- Comments should be useful, for example, comments that remind the reader about
+  some global context that is non-obvious and can't be inferred locally.
+- Don't make trivial (1-2 LOC) helper functions that are only used once unless
+  it significantly improves code readability.
+- Prefer clear abstractions. State management should be explicit.
+  For example, if managing state in a Python class: there should be a clear
+  class definition that has all of the members: don't dynamically `setattr`
+  a field on an object and then dynamically `getattr` the field on the object.
+- Match existing code style and architectural patterns.
+- Assume the reader has familiarity with PyTorch. They may not be the expert
+  on the code that is being read, but they should have some experience in the
+  area.
+
+If uncertain, choose the simpler, more concise implementation.
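A minimal illustration of the explicit-state rule above (hypothetical `Registry` class, not from this repository):

```python
class Registry:
    def __init__(self) -> None:
        # Every member is declared here; no field appears out of thin air.
        self.alias_map: dict[str, str] = {}

reg = Registry()
reg.alias_map["freqs_cis"] = "rope.cache"

# Discouraged: conjuring state dynamically -- a reader scanning the class
# definition would never learn that this field exists.
# setattr(reg, "extra_map", {})
```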

autoparallel/api.py

Lines changed: 27 additions & 0 deletions
@@ -251,6 +251,19 @@ def __init__(
         # in dtype casting and move_to_fake
         model = copy.deepcopy(model)
 
+        # Capture buffer alias info before move_to_fake breaks aliasing.
+        # named_buffers() deduplicates by default, so aliases are dropped.
+        # We record alias_fqn -> canonical_fqn so we can re-register them later.
+        self._buffer_alias_map: dict[str, str] = {}
+        canonical_by_id: dict[int, str] = {}
+        canonical_fqns: set[str] = set()
+        for fqn, buf in model.named_buffers():
+            canonical_by_id[id(buf)] = fqn
+            canonical_fqns.add(fqn)
+        for fqn, buf in model.named_buffers(remove_duplicate=False):
+            if fqn not in canonical_fqns and id(buf) in canonical_by_id:
+                self._buffer_alias_map[fqn] = canonical_by_id[id(buf)]
+
         # keep a separate copy of the fake orig model to customize for supporting init_weights
         self.init_weights_model = move_to_fake(
             copy.deepcopy(model), self.fake_mode, device
@@ -579,6 +592,20 @@ def _register_params_and_init_weights(
                attr_kind=_AttrKind.BUFFER,
            )
 
+        # Register aliased buffers that were deduplicated during tracing.
+        # e.g. if the original model has rope.cache and freqs_cis pointing to
+        # the same tensor, only one survives in sharded_buffer_dict. We register
+        # the missing alias so the parallel model mirrors the original structure.
+        for alias_fqn, canonical_fqn in self._buffer_alias_map.items():
+            if canonical_fqn in sharded_buffer_dict:
+                _assign_attr(
+                    self.parallel_model.get_buffer(canonical_fqn),
+                    self.parallel_model,
+                    self.model,
+                    alias_fqn,
+                    attr_kind=_AttrKind.BUFFER,
+                )
+
         # Right now we require a convention that the user model provides an init_weights method,
         # although we could snoop for other methods too.
         hook_params_setters(self.init_weights_model, self.parallel_model)
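The two-pass scan in the first hunk can also be read as a standalone helper; a sketch (the `buffer_alias_map` name and the toy modules are mine, not part of the diff):

```python
import torch
import torch.nn as nn

def buffer_alias_map(model: nn.Module) -> dict[str, str]:
    # Pass 1: canonical FQNs, as seen by the deduplicating default.
    canonical_by_id: dict[int, str] = {}
    canonical_fqns: set[str] = set()
    for fqn, buf in model.named_buffers():
        canonical_by_id[id(buf)] = fqn
        canonical_fqns.add(fqn)
    # Pass 2: any extra FQN pointing at a known tensor is an alias.
    aliases: dict[str, str] = {}
    for fqn, buf in model.named_buffers(remove_duplicate=False):
        if fqn not in canonical_fqns and id(buf) in canonical_by_id:
            aliases[fqn] = canonical_by_id[id(buf)]
    return aliases

class Sub(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(4))

class Top(nn.Module):
    def __init__(self):
        super().__init__()
        self.sub = Sub()
        self.register_buffer("freqs_cis", self.sub.cache)

print(buffer_alias_map(Top()))  # {'sub.cache': 'freqs_cis'}
```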

autoparallel/init_weights.py

Lines changed: 9 additions & 1 deletion
@@ -122,6 +122,8 @@ def hook_params_setters(
     Also wraps init_weights_model.init_weights (if present) with a TorchDispatchMode
     to handle in-place data operations like ``self.weight.data[:] = value``.
     """
+    parallel_buffer_fqns = set(n for n, _ in parallel_model.named_buffers())
+
     for mod_name, mod in sorted(init_weights_model.named_modules()):
         params_dict = dict(mod.named_parameters(recurse=False))
         buffers_dict = dict(mod.named_buffers(recurse=False))
@@ -132,7 +134,13 @@ def hook_params_setters(
             namespace[p_name] = _build_param_property(parallel_model, fqn)
 
         for b_name in buffers_dict:
-            fqn = mod_name + "." + b_name
+            fqn = f"{mod_name}.{b_name}" if mod_name else b_name
+            # Skip buffers not present on the parallel model. This happens when
+            # the original model has aliased buffers (e.g. rope.cache and freqs_cis
+            # point to the same tensor): named_buffers() deduplicates them so only
+            # one FQN is registered on the parallel model.
+            if fqn not in parallel_buffer_fqns:
+                continue
             namespace[b_name] = _build_buffer_property(parallel_model, fqn)
 
         cls = mod.__class__
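The FQN-join change in the second hunk matters because `named_modules()` yields the root module under the empty name, so the old concatenation produced a leading dot for root-level buffers. A quick check of both spellings (toy buffer name is my choice):

```python
import torch
import torch.nn as nn

m = nn.Module()
m.register_buffer("freqs_cis", torch.zeros(4))

for mod_name, mod in m.named_modules():
    for b_name, _ in mod.named_buffers(recurse=False):
        old = mod_name + "." + b_name  # '.freqs_cis' for the root module
        new = f"{mod_name}.{b_name}" if mod_name else b_name  # 'freqs_cis'
        print(repr(old), repr(new))
```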

tests/test_api.py

Lines changed: 65 additions & 0 deletions
@@ -168,6 +168,71 @@ def input_fn():
     )
 
 
+def test_init_aliased_buffers(device_mesh_1d):
+    """Test that init_weights works when a submodule buffer aliases a top-level buffer.
+
+    This mirrors the torchtitan Decoder pattern where rope.cache and freqs_cis
+    are the same tensor. named_buffers(remove_duplicate=True) deduplicates them,
+    so only freqs_cis ends up on the parallel model. The init_weights hook must
+    still correctly propagate values set via the aliased buffer (rope.cache).
+    """
+    dim = 128
+
+    class RoPE(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.register_buffer("cache", torch.zeros(dim), persistent=False)
+
+        def forward(self, x):
+            return x + self.cache
+
+        def init_weights(self):
+            self.cache = torch.arange(dim).float()
+
+    class Model(nn.Module):
+        def __init__(self, dim):
+            super().__init__()
+            self.linear = nn.Linear(dim, dim)
+            self.rope = RoPE(dim)
+            self.register_buffer("freqs_cis", self.rope.cache, persistent=False)
+
+        def forward(self, x):
+            return self.linear(x) + self.freqs_cis
+
+        def init_weights(self):
+            with torch.no_grad():
+                self.linear.weight.fill_(1.0)
+                self.linear.bias.fill_(0.0)
+            self.rope.init_weights()
+            self.freqs_cis = self.rope.cache
+
+    def input_fn():
+        b = 512
+        inputs = (torch.rand(b, dim, device="cuda"),)
+        return inputs
+
+    with torch.device("meta"):
+        model = Model(dim)
+
+    assert model.freqs_cis is model.rope.cache
+
+    with AutoParallel(
+        model,
+        input_fn,
+        device_mesh_1d,
+    ) as autop:
+        x_sharding = (Shard(0),)
+        autop.add_input_constraints([x_sharding])
+        sharding_placement = autop.optimize_placement()
+        parallel_mod = autop.apply_placement(sharding_placement)
+
+    parallel_mod.to_empty(device="cuda")
+    parallel_mod.init_weights()
+
+    expected = torch.arange(dim).float().cuda()
+    assert torch.equal(parallel_mod.get_buffer("freqs_cis").full_tensor(), expected)
+
+
 def test_fx_graph_annotate(device_mesh_1d):
     dim = 128
 
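As a follow-up to the test, the commit message's invariant (every buffer FQN of the original model must be present on the returned module) could be asserted directly; a sketch reusing `model` and `parallel_mod` from the test above, not part of the PR:

```python
for fqn, _ in model.named_buffers(remove_duplicate=False):
    parallel_mod.get_buffer(fqn)  # raises AttributeError if the alias is missing
```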
