66import dgbpy .keystr as dbk
77import dgbpy .mlapply as dgbml
88import dgbpy .dgbkeras as dgbkeras
9+ import dgbpy .dgbtorch as dgbtorch
910from init_data import *
1011if dgbkeras .hasKeras ():
1112 from test_dgkeras import default_pars as keras_params
@@ -70,11 +71,13 @@ def test_doTrain_invalid_platform(examplefilenm, capsys):
7071 captured = capsys .readouterr ()
7172 assert 'Unsupported machine learning platform' in captured .out
7273
74+ @pytest .mark .skipif (not dgbkeras .hasKeras (), reason = "Keras is not available" )
7375@pytest .mark .parametrize ('examplefilenm' , examples )
7476def test_doTrain_keras_new_trainingtype (examplefilenm ):
7577 kwargs = keras_test_cases ()
7678 assert dgbml .doTrain (examplefilenm , ** kwargs ) == True
7779
80+ @pytest .mark .skipif (not dgbkeras .hasKeras (), reason = "Keras is not available" )
7881@pytest .mark .parametrize ('examplefilenm' , examples )
7982def test_doTrain_keras_resume_trainingtype (examplefilenm ):
8083 kwargs = keras_test_cases (examplefilenm )
@@ -84,6 +87,7 @@ def test_doTrain_keras_resume_trainingtype(examplefilenm):
8487 kwargs ['type' ] = dgbml .TrainType .Resume
8588 assert dgbml .doTrain (examplefilenm , ** kwargs ) == True
8689
90+ @pytest .mark .skipif (not dgbkeras .hasKeras (), reason = "Keras is not available" )
8791@pytest .mark .parametrize ('examplefilenm' , examples )
8892def test_doTrain_keras_transfer_trainingtype (examplefilenm ):
8993 kwargs = keras_test_cases (examplefilenm )
@@ -122,13 +126,14 @@ def get_pretrained_modelfilenm(examplefilenm):
122126 return None
123127
124128
125-
129+ @ pytest . mark . skipif ( not dgbtorch . hasTorch (), reason = "Torch is not available" )
126130@pytest .mark .parametrize ('examplefilenm' , examples )
127131def test_doTrain_torch_new_trainingtype (examplefilenm ):
128132 kwargs = torch_test_cases (examplefilenm )
129133 assert dgbml .doTrain (examplefilenm , ** kwargs ) == True
130134 models .append (kwargs ['outnm' ])
131135
136+ @pytest .mark .skipif (not dgbtorch .hasTorch (), reason = "Torch is not available" )
132137@pytest .mark .parametrize ('examplefilenm' , examples )
133138def test_doTrain_torch_resume_trainingtype (examplefilenm ):
134139 kwargs = torch_test_cases (examplefilenm )
@@ -143,6 +148,7 @@ def test_doTrain_torch_resume_trainingtype(examplefilenm):
143148 kwargs ['outnm' ] = f'torch_test_{ get_filenm_from_path (examplefilenm )} _resume.h5'
144149 assert dgbml .doTrain (examplefilenm , ** kwargs ) == True
145150
151+ @pytest .mark .skipif (not dgbtorch .hasTorch (), reason = "Torch is not available" )
146152@pytest .mark .parametrize ('examplefilenm' , examples )
147153def test_doTrain_torch_transfer_trainingtype (examplefilenm ):
148154 kwargs = torch_test_cases (examplefilenm )
0 commit comments