From 8cefb7d9319be7c9790d697774c7431d77aa1647 Mon Sep 17 00:00:00 2001
From: Your Name
Date: Tue, 9 Dec 2025 22:31:34 -0800
Subject: [PATCH] Torch compilable RoPE

---
 llada/model/modeling_llada.py | 23 +++++------------------
 1 file changed, 5 insertions(+), 18 deletions(-)

diff --git a/llada/model/modeling_llada.py b/llada/model/modeling_llada.py
index b9a22cb..96a82bc 100644
--- a/llada/model/modeling_llada.py
+++ b/llada/model/modeling_llada.py
@@ -398,29 +398,16 @@ def __init__(self, config: ModelConfig, cache: BufferCache):
         self.get_rotary_embedding(config.max_sequence_length, _non_meta_init_device(config))
 
     def get_rotary_embedding(self, seq_len: int, device: torch.device) -> Tuple[torch.Tensor, torch.Tensor]:
-        if (
-            (pos_sin := self.__cache.get("rope_pos_sin")) is not None
-            and (pos_cos := self.__cache.get("rope_pos_cos")) is not None
-            and pos_sin.shape[-2] >= seq_len
-            and pos_cos.shape[-2] >= seq_len
-        ):
-            if pos_sin.device != device:
-                pos_sin = pos_sin.to(device)
-                self.__cache["rope_pos_sin"] = pos_sin
-            if pos_cos.device != device:
-                pos_cos = pos_cos.to(device)
-                self.__cache["rope_pos_cos"] = pos_cos
-            return pos_sin[:, :, :seq_len, :], pos_cos[:, :, :seq_len, :]
-
         with torch.autocast(device.type, enabled=False):
             dim = self.config.d_model // self.config.n_heads
             inv_freq = 1.0 / (self.rope_theta ** (torch.arange(0, dim, 2, device=device, dtype=torch.float) / dim))
             seq = torch.arange(seq_len, device=device, dtype=torch.float)
             freqs = einsum("i , j -> i j", seq, inv_freq)
-            positions = torch.cat((freqs, freqs), dim=-1)
-            pos_sin, pos_cos = positions.sin()[None, None, :, :], positions.cos()[None, None, :, :]
-            self.__cache["rope_pos_sin"] = pos_sin
-            self.__cache["rope_pos_cos"] = pos_cos
+            positions = torch.cat([freqs, freqs], dim=-1)
+
+            # Computing sin/cos directly (no cache mutation) keeps this traceable by torch.compile.
+            pos_sin = torch.sin(positions).unsqueeze(0).unsqueeze(0)
+            pos_cos = torch.cos(positions).unsqueeze(0).unsqueeze(0)
         return pos_sin, pos_cos
 
     def rotate_half(self, x: torch.Tensor) -> torch.Tensor:
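
Reviewer note, not part of the patch: a minimal standalone sketch of why the
change matters. Per the patch's motivation, the deleted branch mutated
self.__cache and returned early based on cached tensor shapes and devices,
which torch.compile could not trace as a single graph; the pure computation
compiles whole. The function signature below (d_head, rope_theta arguments)
is illustrative, not the repo's API.

    import torch

    def get_rotary_embedding(seq_len, d_head, device, rope_theta=10000.0):
        # Same math as the patched method, but free of module state.
        with torch.autocast(device.type, enabled=False):
            inv_freq = 1.0 / (rope_theta ** (torch.arange(0, d_head, 2, device=device, dtype=torch.float) / d_head))
            seq = torch.arange(seq_len, device=device, dtype=torch.float)
            freqs = torch.einsum("i,j->ij", seq, inv_freq)
            positions = torch.cat([freqs, freqs], dim=-1)
            # Shape (1, 1, seq_len, d_head) so the tables broadcast over batch and heads.
            pos_sin = torch.sin(positions).unsqueeze(0).unsqueeze(0)
            pos_cos = torch.cos(positions).unsqueeze(0).unsqueeze(0)
        return pos_sin, pos_cos

    # fullgraph=True raises on any graph break, so this doubles as a check
    # that the function is fully traceable.
    compiled = torch.compile(get_rotary_embedding, fullgraph=True)
    pos_sin, pos_cos = compiled(128, 64, torch.device("cpu"))
    assert pos_sin.shape == (1, 1, 128, 64)

The trade-off is that the sin/cos tables are recomputed on every call instead
of read from the cache; for typical head dimensions that cost is small next to
attention itself, and the recomputation fuses into the compiled graph.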