Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions CLAUDE.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,3 +72,34 @@ python -m pytest tests/
- Leverages DTensor for distributed tensor operations
- Uses linear programming (PuLP) to solve sharding optimization problems
- Includes fake tensor mode for shape inference without actual computation

# Commit messages

Don't commit unless the user explicitly asks you to.

When writing a commit message, don't make a bullet list of the individual
changes. Instead, if the PR is large, explain the order to review changes
(e.g., the logical progression), or if it's short just omit the bullet list
entirely.

Disclose that the PR was authored with Claude.

# Coding Style Guidelines

Follow these rules for all code changes in this repository:

- Minimize comments; be concise; code should be self-explanatory and self-documenting.
- Comments should be useful, for example, comments that remind the reader about
some global context that is non-obvious and can't be inferred locally.
- Don't make trivial (1-2 LOC) helper functions that are only used once unless
it significantly improves code readability.
- Prefer clear abstractions. State management should be explicit.
For example, if managing state in a Python class: there should be a clear
class definition that has all of the members: don't dynamically `setattr`
a field on an object and then dynamically `getattr` the field on the object.
- Match existing code style and architectural patterns.
- Assume the reader has familiarity with PyTorch. They may not be the expert
on the code that is being read, but they should have some experience in the
area.

If uncertain, choose the simpler, more concise implementation.
51 changes: 51 additions & 0 deletions autoparallel/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,27 @@
_APPLY_VIEW_MM_VIEW_PATTERN = False


def _build_alias_map(
named_iter_fn: Callable[..., Any],
) -> dict[str, str]:
"""Build a mapping from alias FQNs to canonical FQNs.

named_parameters()/named_buffers() deduplicate by default, so when a model
registers the same tensor under multiple FQNs only one survives. This
function detects the aliases so they can be re-registered later.
"""
canonical_by_id: dict[int, str] = {}
canonical_fqns: set[str] = set()
for fqn, tensor in named_iter_fn():
canonical_by_id[id(tensor)] = fqn
canonical_fqns.add(fqn)
alias_map: dict[str, str] = {}
for fqn, tensor in named_iter_fn(remove_duplicate=False):
if fqn not in canonical_fqns and id(tensor) in canonical_by_id:
alias_map[fqn] = canonical_by_id[id(tensor)]
return alias_map


def _assign_attr(
attr: Any,
target_module: torch.nn.Module,
Expand Down Expand Up @@ -251,6 +272,13 @@ def __init__(
# in dtype casting and move_to_fake
model = copy.deepcopy(model)

# Capture parameter and buffer alias info before move_to_fake breaks
# aliasing. named_parameters()/named_buffers() deduplicate by default,
# so aliases are dropped. We record alias_fqn -> canonical_fqn so we
# can re-register them later.
self._param_alias_map = _build_alias_map(model.named_parameters)
self._buffer_alias_map = _build_alias_map(model.named_buffers)

# keep a separate copy of the fake orig model to customize for supporting init_weights
self.init_weights_model = move_to_fake(
copy.deepcopy(model), self.fake_mode, device
Expand Down Expand Up @@ -579,6 +607,29 @@ def _register_params_and_init_weights(
attr_kind=_AttrKind.BUFFER,
)

# Register aliased params/buffers that were deduplicated during tracing.
# e.g. if the original model has rope.cache and freqs_cis pointing to
# the same tensor, only one survives in the sharded dict. We register
# the missing alias so the parallel model mirrors the original structure.
for alias_fqn, canonical_fqn in self._param_alias_map.items():
if canonical_fqn in sharded_param_dict:
_assign_attr(
self.parallel_model.get_parameter(canonical_fqn),
self.parallel_model,
self.model,
alias_fqn,
attr_kind=_AttrKind.PARAMETER,
)
for alias_fqn, canonical_fqn in self._buffer_alias_map.items():
if canonical_fqn in sharded_buffer_dict:
_assign_attr(
self.parallel_model.get_buffer(canonical_fqn),
self.parallel_model,
self.model,
alias_fqn,
attr_kind=_AttrKind.BUFFER,
)

# Right now we require a convention that the user model provides an init_weights method,
# although we could snoop for other methods too.
hook_params_setters(self.init_weights_model, self.parallel_model)
Expand Down
13 changes: 11 additions & 2 deletions autoparallel/init_weights.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,17 +122,26 @@ def hook_params_setters(
Also wraps init_weights_model.init_weights (if present) with a TorchDispatchMode
to handle in-place data operations like ``self.weight.data[:] = value``.
"""
parallel_param_fqns = set(n for n, _ in parallel_model.named_parameters())
parallel_buffer_fqns = set(n for n, _ in parallel_model.named_buffers())

for mod_name, mod in sorted(init_weights_model.named_modules()):
params_dict = dict(mod.named_parameters(recurse=False))
buffers_dict = dict(mod.named_buffers(recurse=False))

namespace = {}
for p_name in params_dict:
fqn = mod_name + "." + p_name
fqn = f"{mod_name}.{p_name}" if mod_name else p_name
# Skip aliased parameters not present on the parallel model.
if fqn not in parallel_param_fqns:
continue
namespace[p_name] = _build_param_property(parallel_model, fqn)

for b_name in buffers_dict:
fqn = mod_name + "." + b_name
fqn = f"{mod_name}.{b_name}" if mod_name else b_name
# Skip aliased buffers not present on the parallel model.
if fqn not in parallel_buffer_fqns:
continue
namespace[b_name] = _build_buffer_property(parallel_model, fqn)

cls = mod.__class__
Expand Down
121 changes: 121 additions & 0 deletions tests/test_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,127 @@ def input_fn():
)


def test_init_aliased_buffers(device_mesh_1d):
    """Test that init_weights works when a submodule buffer aliases a top-level buffer.

    This mirrors the torchtitan Decoder pattern where rope.cache and freqs_cis
    are the same tensor. named_buffers(remove_duplicate=True) deduplicates them,
    so only freqs_cis ends up on the parallel model. The init_weights hook must
    still correctly propagate values set via the aliased buffer (rope.cache).
    """
    hidden = 128

    class Rotary(nn.Module):
        def __init__(self, size):
            super().__init__()
            self.register_buffer("cache", torch.zeros(size), persistent=False)

        def forward(self, x):
            return x + self.cache

        def init_weights(self):
            self.cache = torch.arange(hidden).float()

    class TiedBufferModel(nn.Module):
        def __init__(self, size):
            super().__init__()
            self.linear = nn.Linear(size, size)
            self.rope = Rotary(size)
            # freqs_cis shares storage with rope.cache, mirroring torchtitan.
            self.register_buffer("freqs_cis", self.rope.cache, persistent=False)

        def forward(self, x):
            return self.linear(x) + self.freqs_cis

        def init_weights(self):
            with torch.no_grad():
                self.linear.weight.fill_(1.0)
                self.linear.bias.fill_(0.0)
                self.rope.init_weights()
                self.freqs_cis = self.rope.cache

    def input_fn():
        return (torch.rand(512, hidden, device="cuda"),)

    with torch.device("meta"):
        model = TiedBufferModel(hidden)

    assert model.freqs_cis is model.rope.cache

    with AutoParallel(
        model,
        input_fn,
        device_mesh_1d,
    ) as autop:
        autop.add_input_constraints([(Shard(0),)])
        sharding_placement = autop.optimize_placement()
        parallel_mod = autop.apply_placement(sharding_placement)

    parallel_mod.to_empty(device="cuda")
    parallel_mod.init_weights()

    # Values written through the aliased rope.cache must show up on freqs_cis.
    assert torch.equal(
        parallel_mod.get_buffer("freqs_cis").full_tensor(),
        torch.arange(hidden).float().cuda(),
    )


def test_init_aliased_parameters(device_mesh_1d):
    """Test that init_weights works when a parameter is registered under two FQNs.

    This mirrors weight tying in LLMs where embed.weight and lm_head.weight
    are the same parameter. named_parameters() deduplicates them, so the alias
    FQN is missing from the parallel model. The init_weights hook must not
    crash on the missing alias.
    """
    hidden = 128

    class TiedWeightModel(nn.Module):
        def __init__(self, size):
            super().__init__()
            self.embed = nn.Linear(size, size, bias=False)
            # Weight tying: lm_head.weight aliases embed.weight.
            # named_parameters() yields embed.weight first (canonical),
            # lm_head.weight is the alias. Forward only uses embed.
            self.lm_head = nn.Linear(size, size, bias=False)
            self.lm_head.weight = self.embed.weight

        def forward(self, x):
            return self.embed(x)

        def init_weights(self):
            with torch.no_grad():
                self.embed.weight.fill_(1.0)

    def input_fn():
        return (torch.rand(512, hidden, device="cuda"),)

    with torch.device("meta"):
        model = TiedWeightModel(hidden)

    assert model.lm_head.weight is model.embed.weight

    with AutoParallel(
        model,
        input_fn,
        device_mesh_1d,
    ) as autop:
        autop.add_input_constraints([(Shard(0),)])
        sharding_placement = autop.optimize_placement()
        parallel_mod = autop.apply_placement(sharding_placement)

    parallel_mod.to_empty(device="cuda")
    parallel_mod.init_weights()

    # The canonical (surviving) parameter must receive the init value.
    assert torch.equal(
        parallel_mod.get_parameter("embed.weight").full_tensor(),
        torch.ones(hidden, hidden, device="cuda"),
    )


def test_fx_graph_annotate(device_mesh_1d):
dim = 128

Expand Down
Loading