Skip to content

Commit 066fe2a

Browse files
committed
Fixed the test_mlapply.py on pytest programs
1 parent e388229 commit 066fe2a

2 files changed

Lines changed: 8 additions & 1 deletion

File tree

setup.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
'odpy>=1.1.0',
2727
'scikit-learn>=0.24.2',
2828
'psutil>=5.7.0',
29+
'tensorboard==2.19.0',
2930
'torch>=1.9.0',
3031
'fastprogress>=1.0.0',
3132
'onnxruntime-gpu>=1.0.0',

tests/test_mlapply.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import dgbpy.keystr as dbk
77
import dgbpy.mlapply as dgbml
88
import dgbpy.dgbkeras as dgbkeras
9+
import dgbpy.dgbtorch as dgbtorch
910
from init_data import *
1011
if dgbkeras.hasKeras():
1112
from test_dgkeras import default_pars as keras_params
@@ -70,11 +71,13 @@ def test_doTrain_invalid_platform(examplefilenm, capsys):
7071
captured = capsys.readouterr()
7172
assert 'Unsupported machine learning platform' in captured.out
7273

74+
@pytest.mark.skipif(not dgbkeras.hasKeras(), reason="Keras is not available")
7375
@pytest.mark.parametrize('examplefilenm', examples)
7476
def test_doTrain_keras_new_trainingtype(examplefilenm):
7577
kwargs = keras_test_cases()
7678
assert dgbml.doTrain(examplefilenm, **kwargs) == True
7779

80+
@pytest.mark.skipif(not dgbkeras.hasKeras(), reason="Keras is not available")
7881
@pytest.mark.parametrize('examplefilenm', examples)
7982
def test_doTrain_keras_resume_trainingtype(examplefilenm):
8083
kwargs = keras_test_cases(examplefilenm)
@@ -84,6 +87,7 @@ def test_doTrain_keras_resume_trainingtype(examplefilenm):
8487
kwargs['type'] = dgbml.TrainType.Resume
8588
assert dgbml.doTrain(examplefilenm, **kwargs) == True
8689

90+
@pytest.mark.skipif(not dgbkeras.hasKeras(), reason="Keras is not available")
8791
@pytest.mark.parametrize('examplefilenm', examples)
8892
def test_doTrain_keras_transfer_trainingtype(examplefilenm):
8993
kwargs = keras_test_cases(examplefilenm)
@@ -122,13 +126,14 @@ def get_pretrained_modelfilenm(examplefilenm):
122126
return None
123127

124128

125-
129+
@pytest.mark.skipif(not dgbtorch.hasTorch(), reason="Torch is not available")
126130
@pytest.mark.parametrize('examplefilenm', examples)
127131
def test_doTrain_torch_new_trainingtype(examplefilenm):
128132
kwargs = torch_test_cases(examplefilenm)
129133
assert dgbml.doTrain(examplefilenm, **kwargs) == True
130134
models.append(kwargs['outnm'])
131135

136+
@pytest.mark.skipif(not dgbtorch.hasTorch(), reason="Torch is not available")
132137
@pytest.mark.parametrize('examplefilenm', examples)
133138
def test_doTrain_torch_resume_trainingtype(examplefilenm):
134139
kwargs = torch_test_cases(examplefilenm)
@@ -143,6 +148,7 @@ def test_doTrain_torch_resume_trainingtype(examplefilenm):
143148
kwargs['outnm'] = f'torch_test_{get_filenm_from_path(examplefilenm)}_resume.h5'
144149
assert dgbml.doTrain(examplefilenm, **kwargs) == True
145150

151+
@pytest.mark.skipif(not dgbtorch.hasTorch(), reason="Torch is not available")
146152
@pytest.mark.parametrize('examplefilenm', examples)
147153
def test_doTrain_torch_transfer_trainingtype(examplefilenm):
148154
kwargs = torch_test_cases(examplefilenm)

0 commit comments

Comments
 (0)