
Commit d1ac27a

fix(merger): handle non-sharded tensors in FSDP2 checkpoint merging (#155)
Non-sharded buffers (e.g. inv_freq) are stored as plain Tensors rather than DTensors, causing an AttributeError when _local_tensor is accessed. The merger now falls back to using the tensor directly, and deduplicates identical copies across ranks instead of concatenating them.
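For context, a minimal repro of the failure mode. This is a standalone sketch, not lmms_engine code; only torch is assumed:

    import torch

    buf = torch.arange(4.0)  # a non-sharded buffer is saved as a plain Tensor
    try:
        buf._local_tensor.bfloat16()  # what the merger previously did unconditionally
    except AttributeError as e:
        print(e)  # 'Tensor' object has no attribute '_local_tensor'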
1 parent: 8260c77

1 file changed: src/lmms_engine/merger/fsdp2.py (9 additions & 2 deletions)
@@ -98,13 +98,20 @@ def consolidate(self, shard_state_dicts: list[dict]) -> dict:
             state_dict[key] = []
             for model_state_shard in shard_state_dicts:
                 tensor = model_state_shard.pop(key)
-                state_dict[key].append(tensor._local_tensor.bfloat16())
+                # Non-sharded tensors (e.g. buffers like inv_freq) are plain Tensors,
+                # while FSDP-sharded parameters are DTensors with _local_tensor.
+                local = tensor._local_tensor if hasattr(tensor, "_local_tensor") else tensor
+                state_dict[key].append(local.bfloat16())
 
         # Merge tensors along dim=0 (data parallel dimension)
         for key in sorted(state_dict):
             if not isinstance(state_dict[key], list):
                 continue
-            state_dict[key] = torch.cat(state_dict[key], dim=0)
+            # Non-sharded tensors are duplicated across ranks; just take the first one
+            if all(t.shape == state_dict[key][0].shape and torch.equal(t, state_dict[key][0]) for t in state_dict[key][1:]):
+                state_dict[key] = state_dict[key][0]
+            else:
+                state_dict[key] = torch.cat(state_dict[key], dim=0)
 
         return state_dict
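To see both paths in action, here is a minimal, self-contained sketch of the consolidation logic after this patch. FakeDTensor is a hypothetical stand-in for torch.distributed's DTensor (constructing real DTensors requires an initialized process group), and the function mirrors, but is not, the actual consolidate method:

    import torch

    class FakeDTensor:
        # Hypothetical stand-in for torch.distributed.tensor.DTensor, which
        # exposes its per-rank shard via the _local_tensor attribute.
        def __init__(self, local):
            self._local_tensor = local

    def consolidate(shard_state_dicts):
        # Mirror of the patched logic: unwrap DTensors, keep plain Tensors
        # as-is, then dedupe replicated buffers and concatenate true shards.
        state_dict = {}
        for key in list(shard_state_dicts[0]):
            state_dict[key] = []
            for shard in shard_state_dicts:
                tensor = shard.pop(key)
                local = tensor._local_tensor if hasattr(tensor, "_local_tensor") else tensor
                state_dict[key].append(local.bfloat16())
        for key in sorted(state_dict):
            parts = state_dict[key]
            if all(t.shape == parts[0].shape and torch.equal(t, parts[0]) for t in parts[1:]):
                state_dict[key] = parts[0]  # replicated buffer: keep a single copy
            else:
                state_dict[key] = torch.cat(parts, dim=0)  # sharded param: stitch along dim 0
        return state_dict

    # Two "ranks": a weight sharded along dim 0 plus a replicated inv_freq buffer.
    rank0 = {"w": FakeDTensor(torch.ones(2, 4)), "inv_freq": torch.arange(4.0)}
    rank1 = {"w": FakeDTensor(torch.zeros(2, 4)), "inv_freq": torch.arange(4.0)}
    merged = consolidate([rank0, rank1])
    print(merged["w"].shape)         # torch.Size([4, 4])  (shards concatenated)
    print(merged["inv_freq"].shape)  # torch.Size([4])     (duplicate collapsed)

The equality check costs one torch.equal per rank per key, which is negligible for small buffers like inv_freq. One caveat: shards of a genuinely sharded parameter that happened to be byte-identical across ranks would also be collapsed, though in practice trained FSDP shards differ.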
