Skip to content

Commit 49573ce

Browse files
committed
add some unit tests
1 parent b72b843 commit 49573ce

2 files changed

Lines changed: 83 additions & 11 deletions

File tree

conftest.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
import pytest
2+
import numpy as np
3+
4+
from SpectrumCore import Spectrum
5+
6+
7+
@pytest.fixture
8+
def simple_data():
9+
data = np.array([
10+
[5000., 0.4],
11+
[5100., 0.6],
12+
[5200., 0.5],
13+
[5300., 0.4],
14+
[5400., 0.3],
15+
])
16+
17+
return data
18+
19+
20+
@pytest.fixture
21+
def simple_data_error():
22+
data = np.array([
23+
[5000., 0.4, 0.004],
24+
[5100., 0.6, 0.006],
25+
[5200., 0.5, 0.005],
26+
[5300., 0.4, 0.004],
27+
[5400., 0.3, 0.003],
28+
])
29+
30+
return data
31+
32+
33+
@pytest.fixture
34+
def spectrum(simple_data):
35+
return Spectrum(simple_data)
36+
37+
38+
@pytest.fixture
39+
def spectrum_error(simple_data_error):
40+
return Spectrum(simple_data_error)

tests/test_Spectrum.py

Lines changed: 43 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,49 @@
11
import numpy as np
2+
import pytest
23

34
from SpectrumCore import Spectrum
45

56

6-
def test_instantiate():
7-
data = np.array([
8-
[5000., 0.4],
9-
[5100., 0.6],
10-
[5200., 0.5],
11-
[5300., 0.4],
12-
[5400., 0.3],
13-
])
7+
def test_initialization(simple_data):
8+
spec = Spectrum(simple_data)
9+
assert spec is not None
1410

15-
spec = Spectrum(data)
16-
# spec.normalize()
17-
print(spec)
11+
12+
def test_empty_initialization():
13+
with pytest.raises(TypeError):
14+
Spectrum()
15+
16+
17+
def test_slicing(simple_data, spectrum):
18+
assert np.all(spectrum[1:3].data == simple_data[1:3])
19+
20+
21+
def test_length(simple_data, spectrum):
22+
assert len(spectrum) == len(simple_data)
23+
24+
25+
def test_properties(simple_data_error, spectrum_error):
26+
assert spectrum_error.wave_start == simple_data_error[0, 0]
27+
assert spectrum_error.wave_end == simple_data_error[-1, 0]
28+
assert np.all(spectrum_error.wave == simple_data_error[:, 0])
29+
assert np.all(spectrum_error.flux == simple_data_error[:, 1])
30+
assert np.all(spectrum_error.error == simple_data_error[:, 2])
31+
32+
33+
def test_normalize_flux(spectrum):
34+
spectrum.normalize_flux()
35+
assert spectrum.flux.max() == 1.
36+
37+
38+
def test_normalize_wave(spectrum):
39+
spectrum.normalize_wave()
40+
assert spectrum.wave_start == 0.
41+
assert spectrum.wave_end == 1.
42+
43+
44+
def test_add_flux(spectrum):
45+
old_flux = spectrum.flux.copy()
46+
added_flux = np.ones(len(spectrum))
47+
spectrum.add_flux(added_flux)
48+
49+
assert np.all(spectrum.flux == old_flux + added_flux)

0 commit comments

Comments
 (0)