From 22aa7e41b3a06a3b10bfb73f9c8f47ca07aaa36a Mon Sep 17 00:00:00 2001 From: Ryan McKenna Date: Fri, 12 Jun 2026 11:26:15 -0700 Subject: [PATCH] Add interval_handling field to NumericalAttribute Adds a configurable `interval_handling` field to `NumericalAttribute` that controls how discretized intervals are converted back to numerical values in the reverse transformation. Three modes are supported: 'midpoint' (default, preserving existing behavior), 'sample' (uniform draw from the interval), and 'interval' (keep the pd.Interval as-is). PiperOrigin-RevId: 931246705 --- dpsynth/domain.py | 18 +++++++++++++- dpsynth/transformations.py | 30 ++++++++++++++++++++--- tests/domain_test.py | 15 ++++++++++++ tests/transformations_test.py | 46 +++++++++++++++++++++++++++++++++++ 4 files changed, 105 insertions(+), 4 deletions(-) diff --git a/dpsynth/domain.py b/dpsynth/domain.py index 14bef48..67458df 100644 --- a/dpsynth/domain.py +++ b/dpsynth/domain.py @@ -47,7 +47,7 @@ import math import pathlib -from typing import Any, TypeAlias +from typing import Any, Literal, TypeAlias import attr import pandas as pd @@ -57,6 +57,8 @@ CategoricalValue: TypeAlias = None | bool | int | str | pd.Interval +IntervalHandling = Literal['midpoint', 'sample', 'interval'] + @attr.define(frozen=True) class CategoricalAttribute: @@ -140,6 +142,11 @@ class NumericalAttribute: False, out-of-domain values will be grouped together and treated as a single special out-of-domain value. dtype: The dtype of the data (either 'int' or 'float'). + interval_handling: Controls how discretized intervals are converted back to + numerical values. 'midpoint' returns the interval midpoint (or the finite + endpoint if the other is infinite). 'sample' draws uniformly from the + interval (or returns the finite endpoint if the other is infinite). + 'interval' keeps the pd.Interval in the output unchanged. description: An optional semantic description of the attribute. """ @@ -147,6 +154,7 @@ class NumericalAttribute: max_value: float = attr.field(converter=float) clip_to_range: bool = attr.field(default=True) dtype: str = attr.field(default='float') + interval_handling: str = attr.field(default='midpoint') description: str | None = attr.field(default=None) @min_value.validator # pytype: disable=attribute-error @@ -164,6 +172,14 @@ def _validate_dtype(self, *_): f'dtype must be either "int" or "float", got {self.dtype}.' ) + @interval_handling.validator # pytype: disable=attribute-error + def _validate_interval_handling(self, *_): + if self.interval_handling not in ['midpoint', 'sample', 'interval']: + raise ValueError( + 'interval_handling must be "midpoint", "sample", or "interval",' + f' got {self.interval_handling}.' + ) + @property def exclusive_min_value(self) -> float: """Returns the exclusive minimum value for this attribute.""" diff --git a/dpsynth/transformations.py b/dpsynth/transformations.py index 5844955..3f53cab 100644 --- a/dpsynth/transformations.py +++ b/dpsynth/transformations.py @@ -177,12 +177,36 @@ def transform(value: Any) -> pd.Interval | None: return None return intervals[intervals.get_loc(value)] - def reverse(value: pd.Interval | None) -> float | None: + def _resolve_finite(interval: pd.Interval) -> float: + """Returns the midpoint, handling infinite endpoints.""" + left_finite = math.isfinite(interval.left) + right_finite = math.isfinite(interval.right) + if left_finite and right_finite: + return interval.mid + elif left_finite: + return interval.left + else: + return interval.right + + def reverse(value: pd.Interval | None) -> float | pd.Interval | None: if value is None: return None + if attribute_domain.interval_handling == 'interval': + return value + if attribute_domain.interval_handling == 'sample': + left_finite = math.isfinite(value.left) + right_finite = math.isfinite(value.right) + if left_finite and right_finite: + result = np.random.uniform(value.left, value.right) + elif left_finite: + result = value.left + else: + result = value.right + else: + result = _resolve_finite(value) if attribute_domain.dtype == 'int': - return math.ceil(value.mid) - return value.mid + return math.ceil(result) + return result new_domain = domain.CategoricalAttribute(possible_values) transformation = DiscretizeTransformation(transform, reverse) diff --git a/tests/domain_test.py b/tests/domain_test.py index 892780e..5cae77a 100644 --- a/tests/domain_test.py +++ b/tests/domain_test.py @@ -58,6 +58,21 @@ def test_to_from_yaml_roundtrip(self): loaded_domain = domain.from_yaml_file(temp_file.full_path) self.assertEqual(loaded_domain, original_domain) + def test_interval_handling_yaml_roundtrip(self): + original_domain = { + 'num': domain.NumericalAttribute( + min_value=0, max_value=10, interval_handling='sample' + ), + } + temp_file = self.create_tempfile('temp.yaml', mode='w+') + domain.to_yaml_file(original_domain, temp_file.full_path) + loaded_domain = domain.from_yaml_file(temp_file.full_path) + self.assertEqual(loaded_domain, original_domain) + + def test_invalid_interval_handling(self): + with self.assertRaises(ValueError): + domain.NumericalAttribute(0, 10, interval_handling='bad') + def test_standardize_categorical(self): attribute = domain.CategoricalAttribute( possible_values=['a', 'b', 'c'], out_of_domain_index=1 diff --git a/tests/transformations_test.py b/tests/transformations_test.py index 9e09a5b..81cc434 100644 --- a/tests/transformations_test.py +++ b/tests/transformations_test.py @@ -165,6 +165,52 @@ def test_valid_discretization_for_int_attribute(self): self.assertBetween(transform_fn.inverse(interval1), 0, 5) self.assertBetween(transform_fn.inverse(interval2), 5, 10) + def test_discretize_interval_handling_sample(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=100, interval_handling='sample' + ) + _, transform_fn = transformations.create_discretize_transformation( + attr, [50] + ) + interval = pd.Interval(50, 100) + values = set() + for _ in range(50): + value = transform_fn.inverse(interval) + self.assertBetween(value, 50, 100) + values.add(value) + # Sample mode should produce non-constant output (unlike midpoint). + self.assertGreater(len(values), 1) + self.assertIsNone(transform_fn.inverse(None)) + + def test_discretize_interval_handling_interval(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, interval_handling='interval' + ) + _, transform_fn = transformations.create_discretize_transformation( + attr, [5] + ) + interval = pd.Interval(5, 10) + self.assertEqual(transform_fn.inverse(interval), interval) + self.assertIsNone(transform_fn.inverse(None)) + + def test_discretize_reverse_semi_infinite_intervals(self): + # Midpoint mode: semi-infinite intervals should return the finite endpoint. + attr = domain.NumericalAttribute(min_value=0, max_value=10) + _, transform_fn = transformations.create_discretize_transformation( + attr, [5] + ) + self.assertEqual(transform_fn.inverse(pd.Interval(5, np.inf)), 5) + self.assertEqual(transform_fn.inverse(pd.Interval(-np.inf, 5)), 5) + # Sample mode: semi-infinite intervals should also return the finite end. + attr_sample = domain.NumericalAttribute( + min_value=0, max_value=10, interval_handling='sample' + ) + _, transform_fn_sample = transformations.create_discretize_transformation( + attr_sample, [5] + ) + self.assertEqual(transform_fn_sample.inverse(pd.Interval(5, np.inf)), 5) + self.assertEqual(transform_fn_sample.inverse(pd.Interval(-np.inf, 5)), 5) + def test_rare_value_merging_some_rare_values(self): rare_mask = np.array([True, False, True, False]) size, transform_fn = (