Skip to content

Commit e973181

Browse files
committed
Make missing lables handling more direct
1 parent f74964c commit e973181

File tree

3 files changed

+3
-4
lines changed

3 files changed

+3
-4
lines changed

chebai/models/electra.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -299,8 +299,7 @@ def _process_for_loss(
299299
labels = labels.float()
300300
if "missing_labels" in kwargs_copy:
301301
missing_labels = kwargs_copy.pop("missing_labels")
302-
output = output * (~missing_labels).int() - 10000 * missing_labels.int()
303-
labels = labels * (~missing_labels).int()
302+
output[missing_labels] = -1e8
304303
if self.model_type == "classification":
305304
assert ((labels <= torch.tensor(1.0)) & (labels >= torch.tensor(0.0))).all()
306305
return output, labels, kwargs_copy
@@ -332,7 +331,7 @@ def _get_prediction_and_labels(
332331
if "missing_labels" in loss_kwargs:
333332
#print('bla')
334333
missing_labels = loss_kwargs["missing_labels"]
335-
d = d * (~missing_labels).int().to(device=d.device)
334+
d[missing_labels] = 0
336335
return d, labels.int() if labels is not None else None
337336
elif self.model_type == 'regression':
338337
return d, labels

configs/metrics/micro-macro-f1-roc-auc.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ init_args:
55
class_path: torchmetrics.classification.MultilabelF1Score
66
init_args:
77
average: micro
8+
zero_division: 1
89
macro-f1:
910
class_path: chebai.callbacks.epoch_metrics.MacroF1
1011
roc-auc:

configs/model/electra.yml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
class_path: chebai.models.Electra
22
init_args:
3-
model_type: regression
43
optimizer_kwargs:
54
lr: 1e-4
65
config:

0 commit comments

Comments
 (0)