From 4cf23381ae4abaf5533bf806cbb8a6b0aa831971 Mon Sep 17 00:00:00 2001 From: Ryan McKenna Date: Fri, 12 Jun 2026 13:29:38 -0700 Subject: [PATCH] Add vectorized transformations for local_mode Adds `vectorized_transformations.py` to `dpsynth/local_mode/`, a vectorized fork of `dpsynth/transformations.py` that replaces per-element Python loops with bulk numpy operations for discrete encoding, discretization, and rare-value merging. PiperOrigin-RevId: 931305064 --- .../local_mode/vectorized_transformations.py | 325 ++++++++++++++++++ .../vectorized_transformations_test.py | 310 +++++++++++++++++ 2 files changed, 635 insertions(+) create mode 100644 dpsynth/local_mode/vectorized_transformations.py create mode 100644 tests/local_mode/vectorized_transformations_test.py diff --git a/dpsynth/local_mode/vectorized_transformations.py b/dpsynth/local_mode/vectorized_transformations.py new file mode 100644 index 0000000..ee8a9b5 --- /dev/null +++ b/dpsynth/local_mode/vectorized_transformations.py @@ -0,0 +1,325 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Vectorized transformations for the local_mode package. + +This is a vectorized fork of ``dpsynth/transformations.py``, optimized for +single-machine (numpy-based) environments. Functions operate on 1-D numpy +arrays rather than scalar values, yielding significant speedups by replacing +per-element Python loops with bulk numpy operations. + +Covers: + * Discrete encoding / decoding (categorical <-> integer index). + * Discretization / undiscretization (numerical <-> bin index). + * Rare-value merging / unmerging (domain compression). +""" + +from __future__ import annotations + +from collections.abc import Sequence + +from dpsynth import domain +import numpy as np + +# --------------------------------------------------------------------------- +# Discrete encoding / decoding +# --------------------------------------------------------------------------- + + +def discrete_encode( + data: np.typing.ArrayLike, + attribute_domain: domain.CategoricalAttribute, +) -> np.ndarray: + """Maps categorical values to integer indices in ``attribute_domain``. + + Out-of-domain values are mapped to ``attribute_domain.out_of_domain_index``. + + Args: + data: 1-D array of categorical values (any type that can appear in + ``attribute_domain.possible_values``). + attribute_domain: The categorical attribute defining the encoding. + + Returns: + A 1-D integer array of indices into ``attribute_domain.possible_values``. + + Raises: + ValueError: If *data* is not 1-D. + """ + data = np.asarray(data) + if data.ndim != 1: + raise ValueError(f'data must be 1-D, got shape {data.shape}.') + lookup = {v: i for i, v in enumerate(attribute_domain.possible_values)} + default = attribute_domain.out_of_domain_index + encoder = np.vectorize(lambda v: lookup.get(v, default), otypes=[int]) + return encoder(data) + + +def discrete_decode( + encoded: np.typing.ArrayLike, + attribute_domain: domain.CategoricalAttribute, +) -> np.ndarray: + """Maps integer indices back to categorical values. + + Args: + encoded: 1-D integer array of indices into + ``attribute_domain.possible_values``. + attribute_domain: The categorical attribute defining the decoding. + + Returns: + A 1-D object-dtype array of categorical values. + + Raises: + ValueError: If *encoded* is not 1-D. + """ + encoded = np.asarray(encoded) + if encoded.ndim != 1: + raise ValueError(f'encoded must be 1-D, got shape {encoded.shape}.') + values = np.array(attribute_domain.possible_values, dtype=object) + return values[encoded] + + +# --------------------------------------------------------------------------- +# Discretization / undiscretization +# --------------------------------------------------------------------------- + + +def _validate_bin_edges(bin_edges, attribute_domain): + """Validates bin_edges against the attribute domain.""" + if bin_edges.size == 0: + raise ValueError(f'bin_edges must not be empty, got {bin_edges}.') + if ( + bin_edges[0] < attribute_domain.min_value + or bin_edges[-1] >= attribute_domain.max_value + ): + raise ValueError( + 'bin_edges must be within the range' + f' [{attribute_domain.min_value}, {attribute_domain.max_value}),' + f' got {list(bin_edges)}.' + ) + if np.any(np.diff(bin_edges) <= 0): + raise ValueError( + f'bin_edges must be monotonically increasing, got {list(bin_edges)}.' + ) + + +def discretize( + data: np.typing.ArrayLike, + bin_edges: Sequence[int | float], + attribute_domain: domain.NumericalAttribute, +) -> np.ndarray: + """Maps numerical values to bin indices via ``np.searchsorted``. + + Mirrors the semantics of ``transformations.create_discretize_transformation`` + but operates on entire arrays at once. Bin intervals are right-closed: + ``(left, right]``, matching the ``pd.IntervalIndex`` convention used in the + scalar implementation. + + Args: + data: 1-D array of numerical values. + bin_edges: Sorted inner bin edges (same convention as + ``create_discretize_transformation``). Must be monotonically increasing + and within ``[min_value, max_value)``. + attribute_domain: The ``NumericalAttribute`` describing the data domain. + + Returns: + A 1-D integer array of 0-based bin indices. When + ``attribute_domain.clip_to_range`` is ``False``, index 0 represents the + out-of-domain (``None``) bin and in-domain bins start at 1. + + Raises: + ValueError: If *data* is not 1-D or *bin_edges* are invalid. + """ + data = np.asarray(data, dtype=float) + if data.ndim != 1: + raise ValueError(f'data must be 1-D, got shape {data.shape}.') + bin_edges = np.asarray(bin_edges, dtype=float) + _validate_bin_edges(bin_edges, attribute_domain) + + if attribute_domain.clip_to_range: + standardized = np.clip( + data, attribute_domain.min_value, attribute_domain.max_value + ) + # NaN values (from non-numeric inputs) become min_value after clip, but + # np.clip does not handle NaN, so fix them explicitly. + standardized = np.where( + np.isnan(standardized), attribute_domain.min_value, standardized + ) + # side='left' gives right-closed intervals: value == edge -> left bin. + return np.searchsorted(bin_edges, standardized, side='left') + else: + standardized = data.copy() + ood_mask = ( + np.isnan(standardized) + | (standardized < attribute_domain.min_value) + | (standardized > attribute_domain.max_value) + ) + standardized[ood_mask] = attribute_domain.min_value + indices = np.searchsorted(bin_edges, standardized, side='left') + # Shift in-domain indices by 1 to leave room for the None bin at 0. + indices += 1 + indices[ood_mask] = 0 + return indices + + +def undiscretize( + bin_indices: np.typing.ArrayLike, + bin_edges: Sequence[int | float], + attribute_domain: domain.NumericalAttribute, +) -> np.ndarray: + """Maps bin indices back to bin midpoints. + + This is the inverse of :func:`discretize`. + + Args: + bin_indices: 1-D integer array of bin indices (as produced by + :func:`discretize`). + bin_edges: The same sorted inner bin edges used during discretization. + attribute_domain: The ``NumericalAttribute`` describing the data domain. + + Returns: + A 1-D float array of midpoints. When + ``attribute_domain.clip_to_range`` is ``False``, index 0 maps to ``NaN`` + (representing ``None`` / out-of-domain). If ``attribute_domain.dtype`` + is ``'int'``, midpoints are rounded up via ``np.ceil`` and cast to int. + + Raises: + ValueError: If *bin_indices* is not 1-D or *bin_edges* are invalid. + """ + bin_indices = np.asarray(bin_indices, dtype=int) + if bin_indices.ndim != 1: + raise ValueError(f'bin_indices must be 1-D, got shape {bin_indices.shape}.') + bin_edges = np.asarray(bin_edges, dtype=float) + _validate_bin_edges(bin_edges, attribute_domain) + + full_edges = np.r_[ + attribute_domain.exclusive_min_value, + bin_edges, + attribute_domain.max_value, + ] + midpoints = (full_edges[:-1] + full_edges[1:]) / 2.0 + + if attribute_domain.clip_to_range: + result = midpoints[bin_indices] + else: + # Index 0 -> NaN (None bin); in-domain indices are shifted by 1. + result = np.full(bin_indices.shape, np.nan) + in_domain = bin_indices > 0 + result[in_domain] = midpoints[bin_indices[in_domain] - 1] + + if attribute_domain.dtype == 'int': + valid = ~np.isnan(result) + result[valid] = np.ceil(result[valid]) + if np.all(valid): + result = result.astype(int) + return result + + +# --------------------------------------------------------------------------- +# Rare-value merging / unmerging +# --------------------------------------------------------------------------- + + +def merge_rare_values( + data: np.typing.ArrayLike, + rare_value_mask: np.typing.ArrayLike, +) -> tuple[int, np.ndarray]: + """Maps integer-encoded data to a compressed domain, merging rare values. + + Non-rare values are renumbered contiguously starting from 0; all rare values + are mapped to the last index in the compressed domain. + + Args: + data: 1-D integer array in the original domain. + rare_value_mask: 1-D boolean array indicating which original-domain values + are rare (``True`` means rare). + + Returns: + A tuple ``(compressed_size, compressed_data)`` where *compressed_size* is + the number of bins in the compressed domain and *compressed_data* is a 1-D + integer array of the same length as *data*. + + Raises: + ValueError: If inputs have incorrect shapes or dtypes. + """ + data = np.asarray(data, dtype=int) + rare_value_mask = np.asarray(rare_value_mask, dtype=bool) + if data.ndim != 1: + raise ValueError(f'data must be 1-D, got shape {data.shape}.') + if rare_value_mask.ndim != 1: + raise ValueError( + f'rare_value_mask must be 1-D, got shape {rare_value_mask.shape}.' + ) + + num_rare = int(rare_value_mask.sum()) + num_common = rare_value_mask.size - num_rare + compressed_size = num_common + (1 if num_rare >= 1 else 0) + + # Common values get contiguous indices; rare values all map to the last bin. + mapping = np.empty(rare_value_mask.size, dtype=int) + mapping[rare_value_mask] = compressed_size - 1 + mapping[~rare_value_mask] = np.arange(num_common) + + return compressed_size, mapping[data] + + +def unmerge_rare_values( + data: np.typing.ArrayLike, + rare_value_mask: np.typing.ArrayLike, + rng: np.random.Generator, +) -> np.ndarray: + """Maps compressed-domain integers back, randomly restoring rare values. + + This is the inverse of :func:`merge_rare_values`. For the merged-rare bin, + each element is randomly assigned to one of the original rare values. + + Args: + data: 1-D integer array in the compressed domain. + rare_value_mask: 1-D boolean array (same as used in + :func:`merge_rare_values`). + rng: A numpy random number generator used for sampling rare values. + + Returns: + A 1-D integer array in the original domain. + + Raises: + ValueError: If inputs have incorrect shapes or dtypes. + """ + data = np.asarray(data, dtype=int) + rare_value_mask = np.asarray(rare_value_mask, dtype=bool) + if data.ndim != 1: + raise ValueError(f'data must be 1-D, got shape {data.shape}.') + if rare_value_mask.ndim != 1: + raise ValueError( + f'rare_value_mask must be 1-D, got shape {rare_value_mask.shape}.' + ) + + num_rare = int(rare_value_mask.sum()) + num_common = rare_value_mask.size - num_rare + compressed_size = num_common + (1 if num_rare >= 1 else 0) + rare_bin = compressed_size - 1 + + common_indices = np.where(~rare_value_mask)[0] + inv_mapping = np.empty(compressed_size, dtype=int) + inv_mapping[:num_common] = common_indices + if num_rare >= 1: + inv_mapping[rare_bin] = -1 # Placeholder; overwritten below. + + result = inv_mapping[data] + + if num_rare >= 1: + rare_indices = np.where(rare_value_mask)[0] + rare_mask = data == rare_bin + result[rare_mask] = rng.choice(rare_indices, size=int(rare_mask.sum())) + + return result diff --git a/tests/local_mode/vectorized_transformations_test.py b/tests/local_mode/vectorized_transformations_test.py new file mode 100644 index 0000000..498f3dc --- /dev/null +++ b/tests/local_mode/vectorized_transformations_test.py @@ -0,0 +1,310 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from absl.testing import absltest +from absl.testing import parameterized +from dpsynth import domain +from dpsynth.local_mode import vectorized_transformations +import numpy as np + + +class DiscreteEncodeTest(absltest.TestCase): + + def test_basic_encoding(self): + attr = domain.CategoricalAttribute([None, 'a', 'b', 'c']) + data = np.array([None, 'a', 'b', 'c', 'a'], dtype=object) + result = vectorized_transformations.discrete_encode(data, attr) + np.testing.assert_array_equal(result, [0, 1, 2, 3, 1]) + + def test_out_of_domain_maps_to_default(self): + attr = domain.CategoricalAttribute([None, 'a', 'b'], out_of_domain_index=0) + data = np.array(['a', 'x', 'b', 'z'], dtype=object) + result = vectorized_transformations.discrete_encode(data, attr) + np.testing.assert_array_equal(result, [1, 0, 2, 0]) + + def test_integer_possible_values(self): + attr = domain.CategoricalAttribute([10, 20, 30]) + data = np.array([20, 30, 10, 10]) + result = vectorized_transformations.discrete_encode(data, attr) + np.testing.assert_array_equal(result, [1, 2, 0, 0]) + + def test_empty_data(self): + attr = domain.CategoricalAttribute(['a', 'b']) + data = np.array([], dtype=object) + result = vectorized_transformations.discrete_encode(data, attr) + self.assertEqual(result.shape, (0,)) + + def test_2d_raises(self): + attr = domain.CategoricalAttribute(['a', 'b']) + with self.assertRaises(ValueError): + vectorized_transformations.discrete_encode(np.array([['a']]), attr) + + +class DiscreteDecodeTest(absltest.TestCase): + + def test_basic_decoding(self): + attr = domain.CategoricalAttribute([None, 'a', 'b', 'c']) + encoded = np.array([0, 1, 2, 3, 1]) + result = vectorized_transformations.discrete_decode(encoded, attr) + expected = np.array([None, 'a', 'b', 'c', 'a'], dtype=object) + np.testing.assert_array_equal(result, expected) + + def test_roundtrip(self): + attr = domain.CategoricalAttribute(['x', 'y', 'z']) + data = np.array(['x', 'y', 'z', 'y', 'x'], dtype=object) + encoded = vectorized_transformations.discrete_encode(data, attr) + decoded = vectorized_transformations.discrete_decode(encoded, attr) + np.testing.assert_array_equal(decoded, data) + + def test_empty(self): + attr = domain.CategoricalAttribute(['a']) + result = vectorized_transformations.discrete_decode( + np.array([], dtype=int), attr + ) + self.assertEqual(result.shape, (0,)) + + def test_2d_raises(self): + attr = domain.CategoricalAttribute(['a', 'b']) + with self.assertRaises(ValueError): + vectorized_transformations.discrete_decode(np.array([[0]]), attr) + + +class DiscretizeTest(parameterized.TestCase): + + def test_clip_to_range_basic(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, clip_to_range=True + ) + data = np.array([1.0, 5.0, 5.00001, 8.0, -1.0, 11.0]) + result = vectorized_transformations.discretize(data, [5], attr) + # (exclusive_min, 5] -> bin 0, (5, 10] -> bin 1 + np.testing.assert_array_equal(result, [0, 0, 1, 1, 0, 1]) + + def test_no_clip_to_range_ood(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, clip_to_range=False + ) + data = np.array([5.0, 8.0, -1.0, 11.0, np.nan]) + result = vectorized_transformations.discretize(data, [5], attr) + # OOD -> 0, (exclusive_min, 5] -> 1, (5, 10] -> 2 + np.testing.assert_array_equal(result, [1, 2, 0, 0, 0]) + + def test_multiple_bins(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=100, clip_to_range=True + ) + data = np.array([10.0, 25.0, 50.0, 75.0, 90.0]) + result = vectorized_transformations.discretize(data, [25, 50, 75], attr) + np.testing.assert_array_equal(result, [0, 0, 1, 2, 3]) + + def test_integer_attribute(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, dtype='int', clip_to_range=True + ) + data = np.array([0.0, 5.0, 6.0, 10.0]) + result = vectorized_transformations.discretize(data, [5], attr) + # For int dtype with bin_edge 5: (-1, 5] -> 0, (5, 10] -> 1 + np.testing.assert_array_equal(result, [0, 0, 1, 1]) + + def test_empty_data(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, clip_to_range=True + ) + data = np.array([], dtype=float) + result = vectorized_transformations.discretize(data, [5], attr) + self.assertEqual(result.shape, (0,)) + + def test_invalid_bin_edges_raises(self): + attr = domain.NumericalAttribute(min_value=0, max_value=10) + with self.assertRaises(ValueError): + vectorized_transformations.discretize(np.array([1.0]), [], attr) + with self.assertRaises(ValueError): + vectorized_transformations.discretize(np.array([1.0]), [-1, 5], attr) + with self.assertRaises(ValueError): + vectorized_transformations.discretize(np.array([1.0]), [5, 3], attr) + + def test_2d_data_raises(self): + attr = domain.NumericalAttribute(min_value=0, max_value=10) + with self.assertRaises(ValueError): + vectorized_transformations.discretize(np.array([[1.0]]), [5], attr) + + +class UndiscretizeTest(parameterized.TestCase): + + def test_clip_to_range_midpoints(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, clip_to_range=True + ) + bin_indices = np.array([0, 1]) + result = vectorized_transformations.undiscretize(bin_indices, [5], attr) + # Full edges: [exclusive_min, 5, 10]. Midpoints of intervals. + self.assertEqual(result.shape, (2,)) + self.assertBetween(result[0], 0, 5) + self.assertBetween(result[1], 5, 10) + + def test_no_clip_to_range_ood_nan(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, clip_to_range=False + ) + bin_indices = np.array([0, 1, 2]) + result = vectorized_transformations.undiscretize(bin_indices, [5], attr) + self.assertTrue(np.isnan(result[0])) + self.assertBetween(result[1], 0, 5) + self.assertBetween(result[2], 5, 10) + + def test_integer_dtype_ceils(self): + attr = domain.NumericalAttribute( + min_value=1, max_value=5, dtype='int', clip_to_range=True + ) + bin_indices = np.array([0, 1]) + result = vectorized_transformations.undiscretize(bin_indices, [3], attr) + # All results should be integers (ceiled). + for v in result: + self.assertEqual(v, int(v)) + + def test_roundtrip_clip(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=100, clip_to_range=True + ) + edges = [25, 50, 75] + data = np.array([10.0, 40.0, 60.0, 90.0]) + indices = vectorized_transformations.discretize(data, edges, attr) + midpoints = vectorized_transformations.undiscretize(indices, edges, attr) + # Each midpoint should be within the correct bin. + full_edges = np.r_[attr.exclusive_min_value, edges, attr.max_value] + for i, idx in enumerate(indices): + self.assertBetween(midpoints[i], full_edges[idx], full_edges[idx + 1]) + + def test_empty(self): + attr = domain.NumericalAttribute( + min_value=0, max_value=10, clip_to_range=True + ) + result = vectorized_transformations.undiscretize( + np.array([], dtype=int), [5], attr + ) + self.assertEqual(result.shape, (0,)) + + def test_2d_raises(self): + attr = domain.NumericalAttribute(min_value=0, max_value=10) + with self.assertRaises(ValueError): + vectorized_transformations.undiscretize(np.array([[0]]), [5], attr) + + +class MergeRareValuesTest(absltest.TestCase): + + def test_some_rare_values(self): + rare_mask = np.array([True, False, True, False]) + data = np.array([0, 1, 2, 3]) + size, compressed = vectorized_transformations.merge_rare_values( + data, rare_mask + ) + self.assertEqual(size, 3) + np.testing.assert_array_equal(compressed, [2, 0, 2, 1]) + + def test_no_rare_values(self): + rare_mask = np.array([False, False, False, False]) + data = np.array([0, 1, 2, 3]) + size, compressed = vectorized_transformations.merge_rare_values( + data, rare_mask + ) + self.assertEqual(size, 4) + np.testing.assert_array_equal(compressed, [0, 1, 2, 3]) + + def test_all_rare_values(self): + rare_mask = np.array([True, True, True, True]) + data = np.array([0, 1, 2, 3]) + size, compressed = vectorized_transformations.merge_rare_values( + data, rare_mask + ) + self.assertEqual(size, 1) + np.testing.assert_array_equal(compressed, [0, 0, 0, 0]) + + def test_empty_data(self): + rare_mask = np.array([True, False]) + data = np.array([], dtype=int) + size, compressed = vectorized_transformations.merge_rare_values( + data, rare_mask + ) + self.assertEqual(size, 2) + self.assertEqual(compressed.shape, (0,)) + + def test_2d_data_raises(self): + with self.assertRaises(ValueError): + vectorized_transformations.merge_rare_values( + np.array([[0]]), np.array([False]) + ) + + def test_2d_mask_raises(self): + with self.assertRaises(ValueError): + vectorized_transformations.merge_rare_values( + np.array([0]), np.array([[False]]) + ) + + +class UnmergeRareValuesTest(absltest.TestCase): + + def setUp(self): + super().setUp() + self.rng = np.random.default_rng(42) + + def test_roundtrip_some_rare(self): + rare_mask = np.array([True, False, True, False]) + data = np.array([0, 1, 2, 3, 1, 0]) + _, compressed = vectorized_transformations.merge_rare_values( + data, rare_mask + ) + unmerged = vectorized_transformations.unmerge_rare_values( + compressed, rare_mask, self.rng + ) + # Common values should round-trip exactly. + for i, val in enumerate(data): + if not rare_mask[val]: + self.assertEqual(unmerged[i], val) + else: + # Rare values should map to one of the original rare indices. + self.assertIn(unmerged[i], [0, 2]) + + def test_roundtrip_no_rare(self): + rare_mask = np.array([False, False, False]) + data = np.array([0, 1, 2, 1]) + _, compressed = vectorized_transformations.merge_rare_values( + data, rare_mask + ) + unmerged = vectorized_transformations.unmerge_rare_values( + compressed, rare_mask, self.rng + ) + np.testing.assert_array_equal(unmerged, data) + + def test_roundtrip_all_rare(self): + rare_mask = np.array([True, True, True]) + data = np.array([0, 1, 2, 0]) + _, compressed = vectorized_transformations.merge_rare_values( + data, rare_mask + ) + unmerged = vectorized_transformations.unmerge_rare_values( + compressed, rare_mask, self.rng + ) + for v in unmerged: + self.assertIn(v, [0, 1, 2]) + + def test_empty_data(self): + rare_mask = np.array([True, False]) + result = vectorized_transformations.unmerge_rare_values( + np.array([], dtype=int), rare_mask, self.rng + ) + self.assertEqual(result.shape, (0,)) + + +if __name__ == '__main__': + absltest.main()