Commit c19454c (1 parent: 6867768)

register rotary embedding frequencies as non-persistent buffers

1 file changed: 2 additions, 5 deletions

src/diffusers/models/transformers/transformer_wan.py

```diff
@@ -218,17 +218,14 @@ def __init__(
             freqs_cos.append(freq_cos)
             freqs_sin.append(freq_sin)
 
-        self.freqs_cos = torch.cat(freqs_cos, dim=1)
-        self.freqs_sin = torch.cat(freqs_sin, dim=1)
+        self.register_buffer("freqs_cos", torch.cat(freqs_cos, dim=1), persistent=False)
+        self.register_buffer("freqs_sin", torch.cat(freqs_sin, dim=1), persistent=False)
 
     def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         batch_size, num_channels, num_frames, height, width = hidden_states.shape
         p_t, p_h, p_w = self.patch_size
         ppf, pph, ppw = num_frames // p_t, height // p_h, width // p_w
 
-        self.freqs_cos = self.freqs_cos.to(hidden_states.device)
-        self.freqs_sin = self.freqs_sin.to(hidden_states.device)
-
         split_sizes = [
             self.attention_head_dim - 2 * (self.attention_head_dim // 3),
             self.attention_head_dim // 3,
```
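The change swaps plain tensor attributes for registered buffers. Buffers follow `module.to(...)` / `.cuda()` device and dtype moves automatically, which is why the manual `.to(hidden_states.device)` calls in `forward` can be dropped, and `persistent=False` keeps the frequency tables out of `state_dict()`, so checkpoints stay lean and older checkpoints that lack these keys still load cleanly. A minimal sketch of the pattern, using a hypothetical `RotaryFreqs` module (not the actual diffusers class) with a stand-in frequency computation:

```python
import torch
import torch.nn as nn


class RotaryFreqs(nn.Module):
    """Hypothetical module illustrating non-persistent buffers."""

    def __init__(self, dim: int = 64, max_len: int = 1024):
        super().__init__()
        # Stand-in frequency table, mimicking the concatenated
        # freqs_cos / freqs_sin tables built in the __init__ above.
        inv_freq = 1.0 / (10000.0 ** (torch.arange(0, dim, 2).float() / dim))
        angles = torch.outer(torch.arange(max_len).float(), inv_freq)
        # persistent=False: the buffers move with .to()/.cuda() like
        # parameters, but are excluded from state_dict(), so they are
        # recomputed at construction time rather than serialized.
        self.register_buffer("freqs_cos", angles.cos(), persistent=False)
        self.register_buffer("freqs_sin", angles.sin(), persistent=False)


module = RotaryFreqs()
print("freqs_cos" in module.state_dict())  # False: not saved to checkpoints
module.to(torch.float64)                   # buffers follow dtype/device moves
print(module.freqs_cos.dtype)              # torch.float64
```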
