Skip to content

Commit 7de7ffe

Browse files
committed
test(families): tests for exponential family
1 parent c19925f commit 7de7ffe

File tree

1 file changed

+261
-0
lines changed

1 file changed

+261
-0
lines changed
Lines changed: 261 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,261 @@
1+
"""
2+
Tests for Exponential Distribution Family
3+
4+
This module tests the functionality of the exponential distribution family,
5+
including parameterizations, characteristics, and sampling.
6+
"""
7+
8+
__author__ = "Fedor Myznikov"
9+
__copyright__ = "Copyright (c) 2025 PySATL project"
10+
__license__ = "SPDX-License-Identifier: MIT"
11+
12+
13+
import numpy as np
14+
import pytest
15+
from scipy.stats import expon
16+
17+
from pysatl_core.distributions.support import ContinuousSupport
18+
from pysatl_core.families.configuration import configure_families_register
19+
from pysatl_core.types import (
20+
CharacteristicName,
21+
ContinuousSupportShape1D,
22+
FamilyName,
23+
UnivariateContinuous,
24+
)
25+
26+
from .base import BaseDistributionTest
27+
28+
29+
class TestExponentialFamily(BaseDistributionTest):
30+
"""Test suite for Exponential distribution family."""
31+
32+
def setup_method(self):
33+
"""Setup before each test method."""
34+
registry = configure_families_register()
35+
self.exponential_family = registry.get(FamilyName.EXPONENTIAL)
36+
self.exponential_dist_example = self.exponential_family(lambda_=0.5)
37+
38+
def test_family_properties(self):
39+
"""Test basic properties of exponential family."""
40+
assert self.exponential_family.name == FamilyName.EXPONENTIAL
41+
42+
# Check parameterizations
43+
expected_parametrizations = {"rate", "scale"}
44+
assert set(self.exponential_family.parametrization_names) == expected_parametrizations
45+
assert self.exponential_family.base_parametrization_name == "rate"
46+
47+
def test_rate_parametrization_creation(self):
48+
"""Test creation of distribution with rate parametrization."""
49+
dist = self.exponential_family(lambda_=0.5)
50+
51+
assert dist.family_name == FamilyName.EXPONENTIAL
52+
assert dist.distribution_type == UnivariateContinuous
53+
assert dist.parameters == {"lambda_": 0.5}
54+
assert dist.parametrization_name == "rate"
55+
56+
def test_scale_parametrization_creation(self):
57+
"""Test creation of distribution with scale parametrization."""
58+
dist = self.exponential_family(beta=2.0, parametrization_name="scale")
59+
60+
assert dist.parameters == {"beta": 2.0}
61+
assert dist.parametrization_name == "scale"
62+
63+
def test_parametrization_constraints(self):
64+
"""Test parameter constraints validation."""
65+
# lambda_ must be positive
66+
with pytest.raises(ValueError, match="lambda_ > 0"):
67+
self.exponential_family(lambda_=-1.0)
68+
69+
# beta must be positive
70+
with pytest.raises(ValueError, match="beta > 0"):
71+
self.exponential_family(beta=0.0, parametrization_name="scale")
72+
73+
def test_moments(self):
74+
"""Test moment calculations."""
75+
# Mean
76+
mean_func = self.exponential_dist_example.query_method(CharacteristicName.MEAN)
77+
assert abs(mean_func(None) - 2.0) < self.CALCULATION_PRECISION
78+
79+
# Variance
80+
var_func = self.exponential_dist_example.query_method(CharacteristicName.VAR)
81+
assert abs(var_func(None) - 4.0) < self.CALCULATION_PRECISION
82+
83+
# Skewness
84+
skew_func = self.exponential_dist_example.query_method(CharacteristicName.SKEW)
85+
assert abs(skew_func(None) - 2.0) < self.CALCULATION_PRECISION
86+
87+
def test_kurtosis_calculation(self):
88+
"""Test kurtosis calculation with excess parameter."""
89+
kurt_func = self.exponential_dist_example.query_method(CharacteristicName.KURT)
90+
91+
raw_kurt = kurt_func(None)
92+
assert abs(raw_kurt - 9.0) < self.CALCULATION_PRECISION
93+
94+
excess_kurt = kurt_func(None, excess=True)
95+
assert abs(excess_kurt - 6.0) < self.CALCULATION_PRECISION
96+
97+
raw_kurt_explicit = kurt_func(None, excess=False)
98+
assert abs(raw_kurt_explicit - 9.0) < self.CALCULATION_PRECISION
99+
100+
@pytest.mark.parametrize(
101+
"parametrization_name, params, expected_lambda",
102+
[
103+
("rate", {"lambda_": 0.5}, 0.5),
104+
("scale", {"beta": 2.0}, 0.5), # lambda = 1/beta = 0.5
105+
],
106+
)
107+
def test_parametrization_conversions(self, parametrization_name, params, expected_lambda):
108+
"""Test conversions between different parameterizations."""
109+
base_params = self.exponential_family.to_base(
110+
self.exponential_family.get_parametrization(parametrization_name)(**params)
111+
)
112+
113+
assert abs(base_params.parameters["lambda_"] - expected_lambda) < self.CALCULATION_PRECISION
114+
115+
def test_analytical_computations_availability(self):
116+
"""Test that analytical computations are available for exponential distribution."""
117+
comp = self.exponential_family(lambda_=1.0).analytical_computations
118+
119+
expected_chars = {
120+
CharacteristicName.PDF,
121+
CharacteristicName.CDF,
122+
CharacteristicName.PPF,
123+
CharacteristicName.CF,
124+
CharacteristicName.MEAN,
125+
CharacteristicName.VAR,
126+
CharacteristicName.SKEW,
127+
CharacteristicName.KURT,
128+
}
129+
assert set(comp.keys()) == expected_chars
130+
131+
def test_pdf_array_input(self):
132+
"""Test PDF calculation with array input."""
133+
pdf = self.exponential_dist_example.query_method(CharacteristicName.PDF)
134+
x_array = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
135+
136+
pdf_array = pdf(x_array)
137+
assert pdf_array.shape == x_array.shape
138+
scipy_pdf = expon.pdf(x_array, scale=2.0) # scale = 1/lambda = 2.0
139+
140+
self.assert_arrays_almost_equal(pdf_array, scipy_pdf)
141+
142+
def test_cdf_array_input(self):
143+
"""Test CDF calculation with array input."""
144+
cdf = self.exponential_dist_example.query_method(CharacteristicName.CDF)
145+
x_array = np.array([-1.0, 0.0, 1.0, 2.0, 3.0, 4.0])
146+
147+
cdf_array = cdf(x_array)
148+
assert cdf_array.shape == x_array.shape
149+
scipy_cdf = expon.cdf(x_array, scale=2.0) # scale = 1/lambda = 2.0
150+
151+
self.assert_arrays_almost_equal(cdf_array, scipy_cdf)
152+
153+
def test_ppf_array_input(self):
154+
"""Test PPF calculation with array input."""
155+
ppf = self.exponential_dist_example.query_method(CharacteristicName.PPF)
156+
p_array = np.array([0.001, 0.01, 0.1, 0.25, 0.5, 0.75, 0.9, 0.99, 0.999])
157+
158+
ppf_array = ppf(p_array)
159+
assert ppf_array.shape == p_array.shape
160+
scipy_ppf = expon.ppf(p_array, scale=2.0) # scale = 1/lambda = 2.0
161+
162+
self.assert_arrays_almost_equal(ppf_array, scipy_ppf)
163+
164+
def test_characteristic_function_array_input(self):
165+
"""Test characteristic function calculation with array input."""
166+
char_func = self.exponential_dist_example.query_method(CharacteristicName.CF)
167+
t_array = np.array([-2.0, -1.0, 0.0, 1.0, 2.0])
168+
169+
cf_array = char_func(t_array)
170+
assert cf_array.shape == t_array.shape
171+
172+
lambda_ = 0.5
173+
denominator = lambda_**2 + t_array**2
174+
expected_real = lambda_**2 / denominator
175+
expected_imag = lambda_ * t_array / denominator
176+
177+
expected_real = np.where(np.abs(t_array) < self.CALCULATION_PRECISION, 1.0, expected_real)
178+
expected_imag = np.where(np.abs(t_array) < self.CALCULATION_PRECISION, 0.0, expected_imag)
179+
180+
expected = expected_real + 1j * expected_imag
181+
182+
self.assert_arrays_almost_equal(cf_array.real, expected.real)
183+
self.assert_arrays_almost_equal(cf_array.imag, expected.imag)
184+
185+
def test_exponential_support(self):
186+
"""Test that exponential distribution has correct support [0, ∞)."""
187+
dist = self.exponential_dist_example
188+
189+
assert dist.support is not None
190+
assert isinstance(dist.support, ContinuousSupport)
191+
192+
assert dist.support.left == 0.0
193+
assert dist.support.right == float("inf")
194+
assert dist.support.left_closed
195+
assert not dist.support.right_closed
196+
197+
# Test containment
198+
assert dist.support.contains(0.0) is True
199+
assert dist.support.contains(1.0) is True
200+
assert dist.support.contains(-0.1) is False
201+
assert dist.support.contains(float("inf")) is False
202+
203+
# Test array
204+
test_points = np.array([-0.1, 0.0, 1.0, 10.0])
205+
expected = np.array([False, True, True, True])
206+
results = dist.support.contains(test_points)
207+
np.testing.assert_array_equal(results, expected)
208+
209+
assert dist.support.shape == ContinuousSupportShape1D.RAY_RIGHT
210+
211+
212+
class TestExponentialFamilyEdgeCases(BaseDistributionTest):
213+
"""Test edge cases and error conditions for exponential distribution."""
214+
215+
def setup_method(self):
216+
"""Setup before each test method."""
217+
registry = configure_families_register()
218+
self.exponential_family = registry.get(FamilyName.EXPONENTIAL)
219+
220+
def test_invalid_parameterization(self):
221+
"""Test error for invalid parameterization name."""
222+
with pytest.raises(KeyError):
223+
self.exponential_family.distribution(parametrization_name="invalid_name", lambda_=1.0)
224+
225+
def test_missing_parameters(self):
226+
"""Test error for missing required parameters."""
227+
with pytest.raises(TypeError):
228+
self.exponential_family.distribution() # Missing lambda_
229+
230+
def test_invalid_probability_ppf(self):
231+
"""Test PPF with invalid probability values."""
232+
dist = self.exponential_family(lambda_=1.0)
233+
ppf = dist.query_method(CharacteristicName.PPF)
234+
235+
# Test boundaries
236+
assert ppf(0.0) == 0.0
237+
assert ppf(1.0) == float("inf")
238+
239+
# Test invalid probabilities
240+
with pytest.raises(ValueError):
241+
ppf(-0.1)
242+
with pytest.raises(ValueError):
243+
ppf(1.1)
244+
245+
def test_characteristic_function_at_zero(self):
246+
"""Test characteristic function at zero returns 1."""
247+
dist = self.exponential_family(lambda_=1.0)
248+
char_func = dist.query_method(CharacteristicName.CF)
249+
250+
cf_value_zero = char_func(0.0)
251+
assert abs(cf_value_zero.real - 1.0) < self.CALCULATION_PRECISION
252+
assert abs(cf_value_zero.imag) < self.CALCULATION_PRECISION
253+
254+
def test_characteristic_function_large_t(self):
255+
"""Test characteristic function with large t."""
256+
dist = self.exponential_family(lambda_=1.0)
257+
char_func = dist.query_method(CharacteristicName.CF)
258+
259+
cf_value_large = char_func(1000.0)
260+
assert np.iscomplexobj(cf_value_large)
261+
assert abs(cf_value_large) <= 1.0

0 commit comments

Comments
 (0)