Hello,
I'd like to report two issues regarding classification tasks in modnet:
First, the loss function passed to `ModnetModel().fit()` is overwritten with `"categorical_crossentropy"` whenever `val_data` is not `None` and `self.multi_label` is `False` (`modnet/modnet/models/vanilla.py`, lines 400 to 411 at commit e14188d):
```python
if self.num_classes[prop[0]] >= 2:  # Classification
    targ = prop[0]
    if self.multi_label:
        y_inner = np.stack(val_data.df_targets[targ].values)
        if loss is None:
            loss = "binary_crossentropy"
    else:
        y_inner = tf.keras.utils.to_categorical(
            val_data.df_targets[targ].values,
            num_classes=self.num_classes[targ],
        )
        loss = "categorical_crossentropy"
```
As the `loss=None` case is already handled earlier (L352-L360), when the training data are preprocessed, could this `loss` assignment be removed from the validation-data preprocessing?
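For illustration, a minimal sketch of the suggested separation, assuming a default loss only needs to be chosen once from `loss` and `multi_label`, and that validation preprocessing should only encode targets. `resolve_loss` and `encode_val_targets` are hypothetical helpers, not part of modnet's API, and NumPy one-hot encoding stands in for `tf.keras.utils.to_categorical` to keep the example dependency-light:

```python
import numpy as np


def resolve_loss(loss, multi_label):
    """Hypothetical helper: pick a default loss only if the caller gave none."""
    if loss is not None:
        return loss  # a user-supplied loss is never overwritten
    return "binary_crossentropy" if multi_label else "categorical_crossentropy"


def encode_val_targets(values, num_classes, multi_label):
    """Hypothetical helper: encode validation targets without touching the loss."""
    if multi_label:
        # Each entry is already a binary indicator vector; stack into a 2-D array
        return np.stack(values)
    # One-hot encoding of integer labels, like tf.keras.utils.to_categorical
    return np.eye(num_classes)[np.asarray(values, dtype=int)]


print(resolve_loss("my_custom_loss", multi_label=False))  # my_custom_loss
print(encode_val_targets([0, 2, 1], num_classes=3, multi_label=False))
```

Keeping the loss decision in one place means that passing `val_data` can no longer change which loss is used for training.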
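Second, if `nested=False`, both `FitGenetic` and `ModnetModel.fit_preset()` perform a train/test split that is not stratified (`modnet/modnet/models/vanilla.py`, line 580, and `modnet/modnet/hyper_opt/fit_genetic.py`, lines 458 to 462, at commit e14188d):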
```python
train_test_split(range(len(data.df_featurized)), test_size=val_fraction)
```

```python
splits = [
    train_test_split(
        range(len(self.train_data.df_featurized)), test_size=val_fraction
    )
]
```
This is a problem for imbalanced datasets: a purely random split can leave minority classes under-represented, or even absent, in the validation set. It would be helpful if the split were stratified for classification tasks.
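A minimal sketch of a stratified split, assuming single-label class targets are available as a 1-D array (in modnet they would come from something like `data.df_targets`). `split_indices` is a hypothetical helper, not modnet code; it relies on the `stratify` argument of scikit-learn's `train_test_split`:

```python
import numpy as np
from sklearn.model_selection import train_test_split


def split_indices(targets, val_fraction, classification, random_state=None):
    """Hypothetical helper: return (train_idx, val_idx), stratified for classification."""
    indices = np.arange(len(targets))
    # Stratifying preserves class proportions in both subsets;
    # regression targets are still split purely at random.
    stratify = targets if classification else None
    return train_test_split(
        indices,
        test_size=val_fraction,
        stratify=stratify,
        random_state=random_state,
    )


# Imbalanced toy labels: 90 samples of class 0, 10 of class 1
y = np.array([0] * 90 + [1] * 10)
train_idx, val_idx = split_indices(y, val_fraction=0.2, classification=True, random_state=0)
print(len(val_idx), int(y[val_idx].sum()))  # 20 2
```

With `val_fraction=0.2`, the validation set keeps the 9:1 ratio (18 and 2 samples). Note that scikit-learn raises a `ValueError` when a class has fewer than two members, and multi-label targets would need a different strategy.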
If you are interested, I'm happy to raise a PR with fixes.