Skip to content

Commit 16d890f

Browse files
committed
fix(aero_realtime): fix NaN loss and duplicate system prompts
- Make time_embedding.inv_freq persistent=True so it survives FSDP2 state dict save/load cycle (was corrupted → NaN loss) - Remove auto-injected system prompt from chat template that was added before every turn due to per-message apply_chat_template
1 parent 7c6dd02 commit 16d890f

2 files changed

Lines changed: 1 addition & 4 deletions

File tree

src/lmms_engine/datasets/processor/aero_realtime_processor.py

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -615,9 +615,6 @@ def chat_template(self):
615615
"{% set image_count = namespace(value=0) %}"
616616
"{% set video_count = namespace(value=0) %}"
617617
"{% for message in messages %}"
618-
"{% if loop.first and message['role'] != 'system' %}"
619-
"<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n"
620-
"{% endif %}"
621618
"<|im_start|>{{ message['role'] }}\n"
622619
"{% if message['content'] is string %}"
623620
"{{ message['content'] }}<|im_end|>\n"

src/lmms_engine/models/aero_realtime/modeling_aero_realtime.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def __init__(self, dim: int, theta: float = 10000.0) -> None:
138138
inv_freq = torch.exp(
139139
-math.log(self.theta) * torch.arange(self.dim // 2).float() / (self.dim // 2)
140140
)
141-
self.register_buffer("inv_freq", inv_freq, persistent=False)
141+
self.register_buffer("inv_freq", inv_freq, persistent=True)
142142

143143
def forward(self, time_tensor: torch.Tensor) -> torch.Tensor:
144144
inv_freq = self.inv_freq.to(device=time_tensor.device, dtype=time_tensor.dtype)

0 commit comments

Comments
 (0)