Skip to content

Commit 9f2704e

Browse files
Merge pull request #22 from niklashoelter/extensive-test-suite-10598339394247217714
Extensive Test Suite Implementation
2 parents 1ad598d + b7e5c14 commit 9f2704e

5 files changed

Lines changed: 210 additions & 8 deletions

File tree

tests/test_cli.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
from unittest.mock import patch
2+
import pytest
23

34
from gpuma.cli import main
45

@@ -119,3 +120,10 @@ def test_cli_verbose(tmp_path, caplog):
119120
# Check for debug logs if any (depends on what's logged)
120121
# The config logic sets logging level to DEBUG
121122
pass
123+
124+
def test_cli_help(capsys):
125+
with patch("sys.argv", ["gpuma", "--help"]):
126+
with pytest.raises(SystemExit):
127+
main()
128+
captured = capsys.readouterr()
129+
assert "GPUMA" in captured.out or "usage:" in captured.out

tests/test_config.py

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,24 @@ def test_validate_config():
7777
with pytest.raises(ValueError, match="Device must be"):
7878
Config({"optimization": {"device": "invalid_device"}})
7979

80+
def test_validate_config_convergence():
81+
# Test invalid convergence criteria
82+
with pytest.raises(ValueError, match="force_convergence_criterion must be a positive float"):
83+
Config({"optimization": {"force_convergence_criterion": -0.01}})
84+
85+
with pytest.raises(ValueError, match="force_convergence_criterion must be a positive float"):
86+
Config({"optimization": {"force_convergence_criterion": "invalid"}})
87+
88+
with pytest.raises(ValueError, match="energy_convergence_criterion must be a positive float"):
89+
Config({"optimization": {"energy_convergence_criterion": 0.0}})
90+
91+
def test_validate_config_device_empty():
92+
with pytest.raises(ValueError, match="Device string in config cannot be empty"):
93+
Config({"optimization": {"device": ""}})
94+
95+
with pytest.raises(ValueError, match="Device string in config cannot be empty"):
96+
Config({"optimization": {"device": " "}})
97+
8098
def test_huggingface_token(tmp_path, monkeypatch):
8199
# Case 1: Token in config
82100
cfg = Config({"optimization": {"huggingface_token": "token_in_config"}})

tests/test_io.py

Lines changed: 47 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import pytest
2+
from unittest.mock import patch
23

34
from gpuma.io_handler import (
45
file_exists,
@@ -28,6 +29,25 @@ def test_read_xyz_not_found():
2829
with pytest.raises(FileNotFoundError):
2930
read_xyz("non_existent.xyz")
3031

32+
def test_read_xyz_empty_file(tmp_path):
33+
f = tmp_path / "empty.xyz"
34+
f.touch()
35+
36+
with pytest.raises(ValueError):
37+
read_xyz(str(f))
38+
39+
def test_read_xyz_malformed(tmp_path):
40+
f = tmp_path / "malformed.xyz"
41+
# First line not integer
42+
f.write_text("invalid\ncomment\nC 0 0 0")
43+
with pytest.raises(ValueError, match="First line must contain the number of atoms"):
44+
read_xyz(str(f))
45+
46+
# Missing coordinates
47+
f.write_text("1\ncomment\nC 0 0")
48+
with pytest.raises(ValueError, match="must contain at least 4 elements"):
49+
read_xyz(str(f))
50+
3151
def test_read_multi_xyz(tmp_path, sample_multi_xyz_content):
3252
f = tmp_path / "multi.xyz"
3353
f.write_text(sample_multi_xyz_content)
@@ -39,6 +59,25 @@ def test_read_multi_xyz(tmp_path, sample_multi_xyz_content):
3959
assert structs[1].comment == "Methane"
4060
assert len(structs[1].symbols) == 5
4161

62+
def test_read_multi_xyz_malformed(tmp_path):
63+
f = tmp_path / "multi_malformed.xyz"
64+
content = """1
65+
Good
66+
H 0 0 0
67+
1
68+
Bad
69+
H 0 0
70+
1
71+
Good2
72+
H 1 1 1
73+
"""
74+
f.write_text(content)
75+
structs = read_multi_xyz(str(f))
76+
# It should skip the middle one
77+
assert len(structs) == 2
78+
assert structs[0].comment == "Good"
79+
assert structs[1].comment == "Good2"
80+
4281
def test_read_xyz_directory(tmp_path, sample_xyz_content):
4382
d = tmp_path / "xyz_dir"
4483
d.mkdir()
@@ -58,17 +97,21 @@ def test_save_xyz_file(tmp_path, sample_structure):
5897
assert "Methane | Energy: -10.500000 eV | Charge: 0 | Multiplicity: 1" in content
5998
assert "C 0.000000" in content
6099

100+
def test_save_xyz_permission_error(tmp_path, sample_structure):
101+
f = tmp_path / "protected.xyz"
102+
103+
# Mock open to raise PermissionError
104+
with patch("builtins.open", side_effect=PermissionError("Denied")):
105+
with pytest.raises(PermissionError):
106+
save_xyz_file(sample_structure, str(f))
107+
61108
def test_save_multi_xyz(tmp_path, sample_structure):
62109
f = tmp_path / "output_multi.xyz"
63110
s1 = sample_structure
64111
s2 = sample_structure.with_energy(-20.0)
65112
save_multi_xyz([s1, s2], str(f), comments=["First", "Second"])
66113

67114
content = f.read_text()
68-
# Energy might persist on s1 if mutated
69-
assert "First | Energy: -20.000000" in content or "First | Energy: -10.500000" not in content
70-
# wait, s2 = sample_structure.with_energy(-20.0) mutates sample_structure!
71-
# Because s2 is s1.
72115
assert "Second | Energy: -20.000000" in content
73116

74117
def test_smiles_to_xyz():

tests/test_logging.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import logging
2+
from unittest.mock import patch
3+
4+
from gpuma.logging_utils import configure_logging
5+
6+
7+
def test_configure_logging_root():
8+
# Reset root logger handlers
9+
logging.getLogger().handlers = []
10+
11+
configure_logging(level=logging.DEBUG)
12+
13+
logger = logging.getLogger()
14+
assert logger.level == logging.DEBUG
15+
assert len(logger.handlers) >= 1
16+
assert isinstance(logger.handlers[0], logging.StreamHandler)
17+
18+
def test_configure_logging_named():
19+
logger_name = "test_logger"
20+
# Ensure fresh start
21+
logging.getLogger(logger_name).handlers = []
22+
23+
configure_logging(level=logging.WARNING, logger_name=logger_name)
24+
25+
logger = logging.getLogger(logger_name)
26+
assert logger.level == logging.WARNING
27+
assert len(logger.handlers) >= 1
28+
29+
def test_configure_logging_idempotent():
30+
# Test that calling it twice doesn't add multiple handlers if not needed
31+
# (The implementation checks `if not logger.handlers`)
32+
logging.getLogger().handlers = []
33+
34+
configure_logging(level=logging.INFO)
35+
handlers_count = len(logging.getLogger().handlers)
36+
37+
configure_logging(level=logging.INFO)
38+
assert len(logging.getLogger().handlers) == handlers_count

tests/test_models.py

Lines changed: 99 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
from unittest.mock import MagicMock, patch
1+
from unittest.mock import MagicMock, patch, ANY
22

33
import pytest
44
import torch
@@ -37,12 +37,17 @@ def test_device_for_torch():
3737
assert dev.type == "cuda"
3838
assert dev.index == 0
3939

40+
# Test invalid device fallback
41+
with patch("torch.cuda.is_available", return_value=False):
42+
# Even if we ask for cuda, if not available it might raise or fallback depending on implementation
43+
# _device_for_torch calls _parse_device_string which checks availability.
44+
# If _parse_device_string returns cpu, _device_for_torch returns cpu device.
45+
dev = _device_for_torch("cuda")
46+
assert dev.type == "cpu"
47+
4048
def test_load_model_fairchem_logic(mock_hf_token):
4149
config = Config({"optimization": {"model_name": "test_model"}})
4250

43-
# We need to patch fairchem.core inside the function or pre-import it
44-
# Since it is a local import, patching the module where it lives (fairchem.core) works
45-
# if we can import it first.
4651
try:
4752
import fairchem.core as _ # noqa: F401
4853
except ImportError:
@@ -75,3 +80,93 @@ def test_load_model_torchsim_logic(mock_hf_token):
7580

7681
mock_model_cls.assert_called()
7782
assert model is mock_model_cls.return_value
83+
84+
def test_load_model_fairchem_path(mock_hf_token, tmp_path):
85+
model_file = tmp_path / "model.pt"
86+
model_file.touch()
87+
88+
config = Config({"optimization": {"model_path": str(model_file)}})
89+
90+
try:
91+
import fairchem.core as _
92+
except ImportError:
93+
pytest.skip("fairchem.core not installed")
94+
95+
with patch("fairchem.core.pretrained_mlip.load_predict_unit") as mock_load_unit, \
96+
patch("fairchem.core.FAIRChemCalculator") as mock_calc_cls:
97+
98+
mock_load_unit.return_value = MagicMock()
99+
mock_calc_cls.return_value = MagicMock()
100+
101+
calc = load_model_fairchem(config)
102+
103+
mock_load_unit.assert_called_with(path=model_file, device=ANY)
104+
mock_calc_cls.assert_called()
105+
106+
def test_load_model_torchsim_path(mock_hf_token, tmp_path):
107+
model_file = tmp_path / "model.pt"
108+
model_file.touch()
109+
110+
config = Config({"optimization": {"model_path": str(model_file)}})
111+
112+
try:
113+
import torch_sim.models.fairchem as _
114+
except ImportError:
115+
pytest.skip("torch_sim not installed")
116+
117+
with patch("torch_sim.models.fairchem.FairChemModel") as mock_model_cls:
118+
mock_model_cls.return_value = MagicMock()
119+
120+
model = load_model_torchsim(config)
121+
122+
call_kwargs = mock_model_cls.call_args.kwargs
123+
assert call_kwargs["model"] == model_file
124+
125+
def test_model_cache_creation_failure(mock_hf_token, caplog):
126+
# Test fails to create cache dir
127+
config = Config({"optimization": {"model_cache_dir": "/invalid/path/cache", "model_name": "uma"}})
128+
129+
try:
130+
import fairchem.core as _
131+
except ImportError:
132+
pytest.skip("fairchem.core not installed")
133+
134+
with patch("os.makedirs", side_effect=OSError("Permission denied")), \
135+
patch("fairchem.core.pretrained_mlip.get_predict_unit") as mock_get_unit, \
136+
patch("fairchem.core.FAIRChemCalculator"):
137+
138+
load_model_fairchem(config)
139+
140+
# Verify get_predict_unit called without cache_dir
141+
call_kwargs = mock_get_unit.call_args.kwargs
142+
assert "cache_dir" not in call_kwargs
143+
144+
# Check that warning was logged
145+
assert "Could not create model cache directory" in caplog.text
146+
147+
def test_model_path_not_exists(mock_hf_token):
148+
# Path doesn't exist, should fall back to name
149+
config = Config({"optimization": {"model_path": "/non/existent/path", "model_name": "fallback"}})
150+
151+
try:
152+
import fairchem.core as _
153+
except ImportError:
154+
pytest.skip("fairchem.core not installed")
155+
156+
with patch("fairchem.core.pretrained_mlip.get_predict_unit") as mock_get_unit, \
157+
patch("fairchem.core.FAIRChemCalculator"):
158+
159+
load_model_fairchem(config)
160+
161+
# Should call get_predict_unit (name based)
162+
mock_get_unit.assert_called()
163+
args, _ = mock_get_unit.call_args
164+
assert args[0] == "fallback"
165+
166+
def test_missing_model_name(mock_hf_token):
167+
# Config without model_name (and no valid path) should raise ValueError
168+
# Config default has model_name, so we must explicitly set it to empty
169+
config = Config({"optimization": {"model_name": ""}})
170+
171+
with pytest.raises(ValueError, match="Model name must be specified"):
172+
load_model_fairchem(config)

0 commit comments

Comments
 (0)