Commit 08daf77
committed
Fix init_weights crash with aliased buffers (e.g. rope.cache/freqs_cis)
When a model registers the same tensor under multiple FQNs (e.g.
`rope.cache` and `freqs_cis` in torchtitan's Decoder), PyTorch's
`named_buffers()` deduplicates them by default, and this deduplication
also happens during AOTAutograd tracing. As a result, the
AutoParallelModule was missing the alias FQNs entirely.
This must be fixed in autoparallel (not deep in PyTorch) because
autoparallel returns an nn.Module that the user should be able to use
like their original model -- all buffer FQNs from the original model
must be present on the returned module.
The fix has two parts: (1) in api.py, capture buffer alias info before
`move_to_fake` destroys aliasing, then re-register aliases on the
parallel module after sharding; (2) in init_weights.py, skip hooking
buffer FQNs that don't exist on the parallel model (aliases that were
deduplicated).
Authored with Claude.
stack-info: PR: #321, branch: xmfan/stack/241 parent b5a020d commit 08daf77
4 files changed
Lines changed: 132 additions & 1 deletion
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
72 | 72 | | |
73 | 73 | | |
74 | 74 | | |
| 75 | + | |
| 76 | + | |
| 77 | + | |
| 78 | + | |
| 79 | + | |
| 80 | + | |
| 81 | + | |
| 82 | + | |
| 83 | + | |
| 84 | + | |
| 85 | + | |
| 86 | + | |
| 87 | + | |
| 88 | + | |
| 89 | + | |
| 90 | + | |
| 91 | + | |
| 92 | + | |
| 93 | + | |
| 94 | + | |
| 95 | + | |
| 96 | + | |
| 97 | + | |
| 98 | + | |
| 99 | + | |
| 100 | + | |
| 101 | + | |
| 102 | + | |
| 103 | + | |
| 104 | + | |
| 105 | + | |
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
251 | 251 | | |
252 | 252 | | |
253 | 253 | | |
| 254 | + | |
| 255 | + | |
| 256 | + | |
| 257 | + | |
| 258 | + | |
| 259 | + | |
| 260 | + | |
| 261 | + | |
| 262 | + | |
| 263 | + | |
| 264 | + | |
| 265 | + | |
| 266 | + | |
254 | 267 | | |
255 | 268 | | |
256 | 269 | | |
| |||
579 | 592 | | |
580 | 593 | | |
581 | 594 | | |
| 595 | + | |
| 596 | + | |
| 597 | + | |
| 598 | + | |
| 599 | + | |
| 600 | + | |
| 601 | + | |
| 602 | + | |
| 603 | + | |
| 604 | + | |
| 605 | + | |
| 606 | + | |
| 607 | + | |
| 608 | + | |
582 | 609 | | |
583 | 610 | | |
584 | 611 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
122 | 122 | | |
123 | 123 | | |
124 | 124 | | |
| 125 | + | |
| 126 | + | |
125 | 127 | | |
126 | 128 | | |
127 | 129 | | |
| |||
132 | 134 | | |
133 | 135 | | |
134 | 136 | | |
135 | | - | |
| 137 | + | |
| 138 | + | |
| 139 | + | |
| 140 | + | |
| 141 | + | |
| 142 | + | |
| 143 | + | |
136 | 144 | | |
137 | 145 | | |
138 | 146 | | |
| |||
| Original file line number | Diff line number | Diff line change | |
|---|---|---|---|
| |||
168 | 168 | | |
169 | 169 | | |
170 | 170 | | |
| 171 | + | |
| 172 | + | |
| 173 | + | |
| 174 | + | |
| 175 | + | |
| 176 | + | |
| 177 | + | |
| 178 | + | |
| 179 | + | |
| 180 | + | |
| 181 | + | |
| 182 | + | |
| 183 | + | |
| 184 | + | |
| 185 | + | |
| 186 | + | |
| 187 | + | |
| 188 | + | |
| 189 | + | |
| 190 | + | |
| 191 | + | |
| 192 | + | |
| 193 | + | |
| 194 | + | |
| 195 | + | |
| 196 | + | |
| 197 | + | |
| 198 | + | |
| 199 | + | |
| 200 | + | |
| 201 | + | |
| 202 | + | |
| 203 | + | |
| 204 | + | |
| 205 | + | |
| 206 | + | |
| 207 | + | |
| 208 | + | |
| 209 | + | |
| 210 | + | |
| 211 | + | |
| 212 | + | |
| 213 | + | |
| 214 | + | |
| 215 | + | |
| 216 | + | |
| 217 | + | |
| 218 | + | |
| 219 | + | |
| 220 | + | |
| 221 | + | |
| 222 | + | |
| 223 | + | |
| 224 | + | |
| 225 | + | |
| 226 | + | |
| 227 | + | |
| 228 | + | |
| 229 | + | |
| 230 | + | |
| 231 | + | |
| 232 | + | |
| 233 | + | |
| 234 | + | |
| 235 | + | |
171 | 236 | | |
172 | 237 | | |
173 | 238 | | |
| |||
0 commit comments