Commit dc03638

Update DPO_pLM.py
1 parent a2c2381 commit dc03638

1 file changed: DPO_ProtGPT2/DPO_pLM.py

Lines changed: 6 additions & 4 deletions
```diff
@@ -253,6 +253,7 @@ def dpo_weighted_loss(pi_log_likelihood, ref_log_likelihood, weights, beta=0.1):
     return loss
 
 
+
 def dpo_ranked_loss(pi_log_likelihood, pi_ref_loglikelihood, weights, beta=0.1):
     """
     Calculates the Directed Policy Optimization (DPO) ranked loss.
@@ -272,16 +273,17 @@ def dpo_ranked_loss(pi_log_likelihood, pi_ref_loglikelihood, weights, beta=0.1):
         pi_ratio = beta * pi_log_likelihood
     else:
         pi_ratio = beta * (pi_log_likelihood - pi_ref_loglikelihood)
-
-    uniform_weights = torch.ones_like(pi_ratio)
-    print(f"pi ratios: {pi_ratio}")
-
+
+    uniform_weights = torch.arange(pi_ratio.size(0) - 1, -1, -1, device=advantages.device, dtype=advantages.dtype)
+    uniform_weights = torch.softmax(uniform_weights, dim=0)
 
     loss = F.cross_entropy(pi_ratio, uniform_weights)
     return loss
 
 
 
+
+
 # ---------------------------
 # Training and Evaluation
 # ---------------------------
```
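The functional change swaps the uniform target weights (and a leftover debugging print) for a rank-aware target: a descending score vector is softmaxed so that the top-ranked sequence receives the largest target probability. As committed, however, `advantages` is not defined in the scope shown; from context it presumably should be `pi_ratio`. Below is a minimal runnable sketch of the updated function under that assumption (the guard on the if/else is not visible in the hunk, so `pi_ref_loglikelihood is None` is also an assumption):

```python
import torch
import torch.nn.functional as F

def dpo_ranked_loss(pi_log_likelihood, pi_ref_loglikelihood, weights, beta=0.1):
    """
    Direct Preference Optimization (DPO) ranked loss. Sequences are assumed
    to be sorted best-to-worst along dim 0. (`weights` is unused here; the
    signature mirrors the repo's.)
    """
    if pi_ref_loglikelihood is None:
        pi_ratio = beta * pi_log_likelihood
    else:
        pi_ratio = beta * (pi_log_likelihood - pi_ref_loglikelihood)

    # Descending scores [N-1, ..., 1, 0], softmaxed into a soft target
    # distribution: the top-ranked sequence gets the largest probability.
    # `pi_ratio.device`/`pi_ratio.dtype` replace the undefined `advantages`.
    ranks = torch.arange(pi_ratio.size(0) - 1, -1, -1,
                         device=pi_ratio.device, dtype=pi_ratio.dtype)
    target = torch.softmax(ranks, dim=0)

    # Cross-entropy with soft (probability) targets, treating each sequence
    # in the batch as a class; the unbatched call needs a recent PyTorch.
    return F.cross_entropy(pi_ratio, target)
```

For four ranked sequences the target is softmax([3, 2, 1, 0]) ≈ [0.644, 0.237, 0.087, 0.032], so the loss pushes the policy to concentrate likelihood on the higher-ranked sequences in that order, rather than weighting all sequences equally as before.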
