According to your paper, it should be
g_B_C_T_H_W = ... - warmup_ratio * ( cost_B_1_T_1_1 * sint_B_1_T_1_1 * xt_B_C_T_H_W + t_F_theta_B_C_T_H_W )
instead of
g_B_C_T_H_W = -cost_B_1_T_1_1 * torch.sqrt(1 - warmup_ratio**2 * sint_B_1_T_1_1**2) * (F_theta_B_C_T_H_W_sg - F_teacher_B_C_T_H_W) - (
    warmup_ratio * cost_B_1_T_1_1 * sint_B_1_T_1_1 * xt_B_C_T_H_W + t_F_theta_B_C_T_H_W
)
rcm/rcm/models/t2v_model_distill_rcm.py
Lines 559 to 561 in 64873cf
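
To make the discrepancy concrete, here is a minimal scalar sketch comparing the two placements of `warmup_ratio`. It is not the repository code: the function names and the numeric values are made up for illustration, and the first term (the part written as `...` in the paper form) is left out because it is not in question. Under these assumptions the two forms differ by `(1 - warmup_ratio) * t_F_theta`, so they should coincide only once `warmup_ratio` reaches 1.

```python
import math


def second_term_paper(warmup_ratio, cos_t, sin_t, x_t, t_f_theta):
    # Form quoted from the paper in this issue:
    # warmup_ratio scales the whole bracket.
    return warmup_ratio * (cos_t * sin_t * x_t + t_f_theta)


def second_term_code(warmup_ratio, cos_t, sin_t, x_t, t_f_theta):
    # Form in the referenced lines of the repository:
    # warmup_ratio scales only the cos * sin * x_t part.
    return warmup_ratio * cos_t * sin_t * x_t + t_f_theta


# Arbitrary illustrative values (assumptions, not taken from the model).
t = 0.7
cos_t, sin_t = math.cos(t), math.sin(t)
x_t, t_f_theta = 0.3, -1.2

for warmup_ratio in (0.0, 0.5, 1.0):
    gap = second_term_code(warmup_ratio, cos_t, sin_t, x_t, t_f_theta) - second_term_paper(
        warmup_ratio, cos_t, sin_t, x_t, t_f_theta
    )
    # Algebraically, the gap is (1 - warmup_ratio) * t_f_theta.
    assert math.isclose(gap, (1 - warmup_ratio) * t_f_theta, abs_tol=1e-12)
    print(f"warmup_ratio={warmup_ratio}: code - paper = {gap:.6f}")
```

If this reading is right, the difference only matters during warmup, when `warmup_ratio < 1`.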