Skip to content

Commit 58793d7

Browse files
committed
Exclude positive-free entris from metrics
1 parent e973181 commit 58793d7

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

chebai/models/electra.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -332,7 +332,8 @@ def _get_prediction_and_labels(
332332
#print('bla')
333333
missing_labels = loss_kwargs["missing_labels"]
334334
d[missing_labels] = 0
335-
return d, labels.int() if labels is not None else None
335+
has_positive_entries = torch.sum(labels, dim=-1)>0
336+
return d[has_positive_entries], labels[has_positive_entries].int() if labels[has_positive_entries] is not None else None
336337
elif self.model_type == 'regression':
337338
return d, labels
338339
else:

0 commit comments

Comments
 (0)