Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion dpsynth/domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@
import math
import pathlib

from typing import Any, TypeAlias
from typing import Any, Literal, TypeAlias

import attr
import pandas as pd
Expand All @@ -57,6 +57,8 @@

CategoricalValue: TypeAlias = None | bool | int | str | pd.Interval

IntervalHandling = Literal['midpoint', 'sample', 'interval']


@attr.define(frozen=True)
class CategoricalAttribute:
Expand Down Expand Up @@ -140,13 +142,19 @@ class NumericalAttribute:
False, out-of-domain values will be grouped together and treated as a
single special out-of-domain value.
dtype: The dtype of the data (either 'int' or 'float').
interval_handling: Controls how discretized intervals are converted back to
numerical values. 'midpoint' returns the interval midpoint (or the finite
endpoint if the other is infinite). 'sample' draws uniformly from the
interval (or returns the finite endpoint if the other is infinite).
'interval' keeps the pd.Interval in the output unchanged.
description: An optional semantic description of the attribute.
"""

# Range bounds of the attribute; the converter passes each through float().
min_value: float = attr.field(converter=float)
max_value: float = attr.field(converter=float)
# When False, out-of-domain values are grouped together and treated as a
# single special out-of-domain value (see the class docstring).
clip_to_range: bool = attr.field(default=True)
# Either 'int' or 'float'; checked by the dtype validator.
dtype: str = attr.field(default='float')
# How discretized intervals are converted back to numerical values:
# 'midpoint', 'sample' or 'interval'; checked by its validator.
interval_handling: IntervalHandling = attr.field(default='midpoint')
# Optional semantic description of the attribute.
description: str | None = attr.field(default=None)

@min_value.validator # pytype: disable=attribute-error
Expand All @@ -164,6 +172,14 @@ def _validate_dtype(self, *_):
f'dtype must be either "int" or "float", got {self.dtype}.'
)

@interval_handling.validator  # pytype: disable=attribute-error
def _validate_interval_handling(self, *_):
  """Rejects any interval_handling outside the three supported modes.

  Raises:
    ValueError: If interval_handling is not 'midpoint', 'sample' or
      'interval'.
  """
  supported_modes = ('midpoint', 'sample', 'interval')
  if self.interval_handling in supported_modes:
    return
  raise ValueError(
      'interval_handling must be "midpoint", "sample", or "interval",'
      f' got {self.interval_handling}.'
  )

@property
def exclusive_min_value(self) -> float:
"""Returns the exclusive minimum value for this attribute."""
Expand Down
30 changes: 27 additions & 3 deletions dpsynth/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -177,12 +177,36 @@ def transform(value: Any) -> pd.Interval | None:
return None
return intervals[intervals.get_loc(value)]

def reverse(value: pd.Interval | None) -> float | None:
def _resolve_finite(interval: pd.Interval) -> float:
"""Returns the midpoint, handling infinite endpoints."""
left_finite = math.isfinite(interval.left)
right_finite = math.isfinite(interval.right)
if left_finite and right_finite:
return interval.mid
elif left_finite:
return interval.left
else:
return interval.right

def reverse(value: pd.Interval | None) -> float | pd.Interval | None:
  """Maps a discretized value back into the numerical domain.

  The conversion follows `attribute_domain.interval_handling`:
    * 'interval' returns the pd.Interval unchanged.
    * 'sample' draws uniformly between the endpoints when both are finite.
    * any other mode (i.e. 'midpoint') uses the interval midpoint.
  In the last two modes an interval with an infinite endpoint resolves to its
  finite endpoint, and the result is rounded up with math.ceil when the
  attribute dtype is 'int'.

  Args:
    value: The interval to convert, or None.

  Returns:
    None if `value` is None, the interval itself in 'interval' mode, and a
    number otherwise.
  """
  if value is None:
    return None
  mode = attribute_domain.interval_handling
  if mode == 'interval':
    return value
  can_sample = (
      mode == 'sample'
      and math.isfinite(value.left)
      and math.isfinite(value.right)
  )
  if can_sample:
    result = np.random.uniform(value.left, value.right)
  else:
    # Midpoint mode, or a sample request on a semi-infinite interval: both
    # resolve through the shared helper so the finite endpoint is used.
    result = _resolve_finite(value)
  # Always derive the return value from `result`; returning `value.mid` here
  # would discard the sampled value and yield an infinite midpoint for
  # semi-infinite intervals.
  if attribute_domain.dtype == 'int':
    return math.ceil(result)
  return result

new_domain = domain.CategoricalAttribute(possible_values)
transformation = DiscretizeTransformation(transform, reverse)
Expand Down
15 changes: 15 additions & 0 deletions tests/domain_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,21 @@ def test_to_from_yaml_roundtrip(self):
loaded_domain = domain.from_yaml_file(temp_file.full_path)
self.assertEqual(loaded_domain, original_domain)

def test_interval_handling_yaml_roundtrip(self):
  """A non-default interval_handling survives a YAML save/load cycle."""
  sample_attribute = domain.NumericalAttribute(
      min_value=0, max_value=10, interval_handling='sample'
  )
  expected_domain = {'num': sample_attribute}
  yaml_path = self.create_tempfile('temp.yaml', mode='w+').full_path
  domain.to_yaml_file(expected_domain, yaml_path)
  self.assertEqual(domain.from_yaml_file(yaml_path), expected_domain)

def test_invalid_interval_handling(self):
  """Constructing with an unsupported interval_handling raises ValueError."""
  self.assertRaises(
      ValueError, domain.NumericalAttribute, 0, 10, interval_handling='bad'
  )

def test_standardize_categorical(self):
attribute = domain.CategoricalAttribute(
possible_values=['a', 'b', 'c'], out_of_domain_index=1
Expand Down
46 changes: 46 additions & 0 deletions tests/transformations_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,52 @@ def test_valid_discretization_for_int_attribute(self):
self.assertBetween(transform_fn.inverse(interval1), 0, 5)
self.assertBetween(transform_fn.inverse(interval2), 5, 10)

def test_discretize_interval_handling_sample(self):
  """'sample' mode yields varying in-range values and passes None through."""
  sampled_attribute = domain.NumericalAttribute(
      min_value=0, max_value=100, interval_handling='sample'
  )
  _, transform_fn = transformations.create_discretize_transformation(
      sampled_attribute, [50]
  )
  upper_bin = pd.Interval(50, 100)
  draws = [transform_fn.inverse(upper_bin) for _ in range(50)]
  for draw in draws:
    self.assertBetween(draw, 50, 100)
  # Sample mode should produce non-constant output (unlike midpoint).
  self.assertGreater(len(set(draws)), 1)
  self.assertIsNone(transform_fn.inverse(None))

def test_discretize_interval_handling_interval(self):
  """'interval' mode hands back the pd.Interval itself; None stays None."""
  interval_attribute = domain.NumericalAttribute(
      min_value=0, max_value=10, interval_handling='interval'
  )
  _, transform_fn = transformations.create_discretize_transformation(
      interval_attribute, [5]
  )
  upper_bin = pd.Interval(5, 10)
  restored = transform_fn.inverse(upper_bin)
  self.assertEqual(restored, upper_bin)
  self.assertIsNone(transform_fn.inverse(None))

def test_discretize_reverse_semi_infinite_intervals(self):
  """Semi-infinite intervals reverse to their finite endpoint."""
  unbounded_above = pd.Interval(5, np.inf)
  unbounded_below = pd.Interval(-np.inf, 5)
  # First the default (midpoint) mode, then sample mode: in both, a
  # semi-infinite interval should come back as its finite endpoint.
  for extra_kwargs in ({}, {'interval_handling': 'sample'}):
    attribute = domain.NumericalAttribute(
        min_value=0, max_value=10, **extra_kwargs
    )
    _, transform_fn = transformations.create_discretize_transformation(
        attribute, [5]
    )
    self.assertEqual(transform_fn.inverse(unbounded_above), 5)
    self.assertEqual(transform_fn.inverse(unbounded_below), 5)

def test_rare_value_merging_some_rare_values(self):
rare_mask = np.array([True, False, True, False])
size, transform_fn = (
Expand Down
Loading