-
Notifications
You must be signed in to change notification settings - Fork 8
Expand file tree
/
Copy pathtestCLI.py
More file actions
36 lines (30 loc) · 1.22 KB
/
testCLI.py
File metadata and controls
36 lines (30 loc) · 1.22 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import unittest
from chebai.callbacks.save_config import CustomSaveConfigCallback
from chebai.cli import ChebaiCLI
class TestChebaiCLI(unittest.TestCase):
    """Smoke test: a minimal ``fit`` run through ``ChebaiCLI`` completes without raising."""

    def setUp(self):
        """Build the command-line arguments for a tiny, one-epoch ``fit`` run.

        All config paths below are relative, so they resolve against the
        current working directory.
        NOTE(review): presumably the test is meant to be run from the
        repository root — confirm against the CI invocation.
        """
        self.cli_args = [
            "fit",
            "--trainer=configs/training/default_trainer.yml",
            # Model config from ffn.yml, overridden to hidden_layers=[10]
            # (presumably a single hidden layer of size 10) to keep it small.
            "--model=configs/model/ffn.yml",
            "--model.init_args.hidden_layers=[10]",
            "--model.train_metrics=configs/metrics/micro-macro-f1.yml",
            # Test-local data module config kept next to this test file.
            "--data=tests/unit/cli/mock_dm_config.yml",
            "--model.pass_loss_kwargs=false",
            # Pin training to exactly one epoch so the smoke test stays fast.
            "--trainer.min_epochs=1",
            "--trainer.max_epochs=1",
            "--model.criterion=tests/unit/cli/bce_loss.yml",
        ]

    def test_mlp_on_chebai_cli(self):
        """Constructing ``ChebaiCLI`` with ``self.cli_args`` must not raise.

        Exceptions are deliberately not caught here. unittest already treats
        any exception escaping a test as a non-pass, and reports it as an error
        with the full original traceback and exception type. Wrapping the call
        in ``except Exception: self.fail(str(e))`` would only lose that
        information and relabel the error as an assertion failure.
        """
        # NOTE(review): assumes ChebaiCLI runs the "fit" subcommand during
        # construction (LightningCLI's default run-on-init behaviour), which
        # is why the returned instance is not used — confirm.
        ChebaiCLI(
            args=self.cli_args,
            save_config_callback=CustomSaveConfigCallback,
            save_config_kwargs={"config_filename": "lightning_config.yaml"},
            parser_kwargs={"parser_mode": "omegaconf"},
        )
if __name__ == "__main__":
    # Only when this file is executed directly as a script: hand control to
    # unittest's command-line entry point for the tests in this module.
    unittest.main(module="__main__")