File tree Expand file tree Collapse file tree 3 files changed +3
-4
lines changed
Expand file tree Collapse file tree 3 files changed +3
-4
lines changed Original file line number Diff line number Diff line change @@ -299,8 +299,7 @@ def _process_for_loss(
299299 labels = labels .float ()
300300 if "missing_labels" in kwargs_copy :
301301 missing_labels = kwargs_copy .pop ("missing_labels" )
302- output = output * (~ missing_labels ).int () - 10000 * missing_labels .int ()
303- labels = labels * (~ missing_labels ).int ()
302+ output [missing_labels ] = - 1e8
304303 if self .model_type == "classification" :
305304 assert ((labels <= torch .tensor (1.0 )) & (labels >= torch .tensor (0.0 ))).all ()
306305 return output , labels , kwargs_copy
@@ -332,7 +331,7 @@ def _get_prediction_and_labels(
332331 if "missing_labels" in loss_kwargs :
333332 #print('bla')
334333 missing_labels = loss_kwargs ["missing_labels" ]
335- d = d * ( ~ missing_labels ). int (). to ( device = d . device )
334+ d [ missing_labels ] = 0
336335 return d , labels .int () if labels is not None else None
337336 elif self .model_type == 'regression' :
338337 return d , labels
Original file line number Diff line number Diff line change @@ -5,6 +5,7 @@ init_args:
55 class_path : torchmetrics.classification.MultilabelF1Score
66 init_args :
77 average : micro
8+ zero_division : 1
89 macro-f1 :
910 class_path : chebai.callbacks.epoch_metrics.MacroF1
1011 roc-auc :
Original file line number Diff line number Diff line change 11class_path : chebai.models.Electra
22init_args :
3- model_type : regression
43 optimizer_kwargs :
54 lr : 1e-4
65 config :
You can’t perform that action at this time.
0 commit comments