Skip to content

Commit 2364907

Browse files
committed
Fix handling of batches without positives
1 parent 58793d7 commit 2364907

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

chebai/models/base.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -303,7 +303,7 @@ def _execute(
303303
logger=True,
304304
sync_dist=sync_dist,
305305
)
306-
if metrics and labels is not None:
306+
if metrics and labels is not None and pr is not None:
307307
for metric_name, metric in metrics.items():
308308
metric.update(pr, tar)
309309
self._log_metrics(prefix, metrics, len(batch))

chebai/models/electra.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -333,7 +333,10 @@ def _get_prediction_and_labels(
333333
missing_labels = loss_kwargs["missing_labels"]
334334
d[missing_labels] = 0
335335
has_positive_entries = torch.sum(labels, dim=-1)>0
336-
return d[has_positive_entries], labels[has_positive_entries].int() if labels[has_positive_entries] is not None else None
336+
if torch.sum(has_positive_entries):
337+
return d[has_positive_entries], labels[has_positive_entries].int() if labels[has_positive_entries] is not None else None
338+
else:
339+
return None, None
337340
elif self.model_type == 'regression':
338341
return d, labels
339342
else:

0 commit comments

Comments
 (0)