We read every piece of feedback, and take your input very seriously.
To see all available qualifiers, see our documentation.
1 parent e973181 commit 58793d7Copy full SHA for 58793d7
chebai/models/electra.py
@@ -332,7 +332,8 @@ def _get_prediction_and_labels(
332
#print('bla')
333
missing_labels = loss_kwargs["missing_labels"]
334
d[missing_labels] = 0
335
- return d, labels.int() if labels is not None else None
+ has_positive_entries = torch.sum(labels, dim=-1)>0
336
+ return d[has_positive_entries], labels[has_positive_entries].int() if labels[has_positive_entries] is not None else None
337
elif self.model_type == 'regression':
338
return d, labels
339
else:
0 commit comments