Skip to content

Commit 0a66ef4

Browse files
authored
Merge pull request #143 from schnamo/dev
tidy up config files for loss, fix missing labels issue, etc
2 parents b32e6c5 + 203b2b3 commit 0a66ef4

File tree

8 files changed

+5
-7
lines changed

8 files changed

+5
-7
lines changed

chebai/models/base.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -278,6 +278,7 @@ def _execute(
278278
loss_kwargs = dict()
279279
if self.pass_loss_kwargs:
280280
loss_kwargs = loss_kwargs_candidates
281+
loss_kwargs["current_epoch"] = self.trainer.current_epoch
281282
loss = self.criterion(loss_data, loss_labels, **loss_kwargs)
282283
if isinstance(loss, tuple):
283284
unnamed_loss_index = 1

chebai/models/electra.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -241,7 +241,6 @@ def __init__(
241241
self.config = ElectraConfig(**config, output_attentions=True)
242242
self.word_dropout = nn.Dropout(config.get("word_dropout", 0))
243243
self.model_type = model_type
244-
self.pass_loss_kwargs = True
245244

246245
in_d = self.config.hidden_size
247246
self.output = nn.Sequential(

configs/loss/bce_new.yml

Lines changed: 0 additions & 1 deletion
This file was deleted.

configs/loss/bce_try.yml

Lines changed: 0 additions & 1 deletion
This file was deleted.

configs/loss/bce_unweighted.yml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
class_path: torch.nn.BCEWithLogitsLoss
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,3 @@
11
class_path: chebai.loss.bce_weighted.BCEWeighted
22
init_args:
3-
beta: 1000
3+
beta: 0.99
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
11
class_path: chebai.loss.focal_loss.FocalLoss
22
init_args:
33
task_type: multi-label
4-
num_classes: 12

configs/model/electra.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
class_path: chebai.models.Electra
22
init_args:
3-
model_type: regression
3+
model_type: classification
44
optimizer_kwargs:
55
lr: 1e-4
66
config:
@@ -9,4 +9,4 @@ init_args:
99
num_attention_heads: 8
1010
num_hidden_layers: 6
1111
type_vocab_size: 1
12-
hidden_size: 256
12+
hidden_size: 256

0 commit comments

Comments
 (0)