diff --git a/.github/workflows/pytest.yml b/.github/workflows/pytest.yml index 773338d..046b7cb 100644 --- a/.github/workflows/pytest.yml +++ b/.github/workflows/pytest.yml @@ -39,4 +39,4 @@ jobs: - run: pip freeze - name: Run core tests - run: pytest -vv -n auto + run: pytest -vv -n auto --import-mode=importlib diff --git a/conftest.py b/conftest.py new file mode 100644 index 0000000..71a4b8b --- /dev/null +++ b/conftest.py @@ -0,0 +1,25 @@ +# Copyright 2026 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Pytest configuration and global fixtures.""" + +import sys +from absl import flags + +# Parse absl flags once before any tests run to avoid UnparsedFlagAccessError +# when tests use absltest.TestCase.create_tempfile() or create_tempdir(). +try: + flags.FLAGS(sys.argv, known_only=True) +except flags.Error: + pass diff --git a/dpsynth/bin/derive_domain.py b/dpsynth/bin/derive_domain.py index 20e5df6..fe5a834 100644 --- a/dpsynth/bin/derive_domain.py +++ b/dpsynth/bin/derive_domain.py @@ -21,7 +21,6 @@ from absl import logging from dpsynth import domain from dpsynth.bin import _read_csv_args -import fancyflags as ff import numpy as np import pandas as pd @@ -40,10 +39,20 @@ 'Path to the output directory to write the domain to.', ) -_CSV_READ_ARGS = ff.DEFINE_auto( - 'csv_read_args', - _read_csv_args.ReadCsvArgs, - _read_csv_args.FLAG_HELP, +_CSV_READ_ARGS_FIELD_SEPARATOR = flags.DEFINE_string( + 'csv_read_args_field_separator', + None, + 'Field separator for reading CSV files.', +) +_CSV_READ_ARGS_COLUMN_NAMES = flags.DEFINE_list( + 'csv_read_args_column_names', + None, + 'Column names for reading CSV files.', +) +_CSV_READ_ARGS_COLUMN_COUNT = flags.DEFINE_integer( + 'csv_read_args_column_count', + None, + 'Column count for reading CSV files.', ) @@ -120,18 +129,18 @@ def derive_domain_from_data( for col in df.columns: logging.info('Deriving domain for column: %s', col) match df[col].dtype: - case 'object': + case 'object' | 'str' | 'string' | 'category': result[col] = domain.CategoricalAttribute( possible_values=sorted( df[col].unique(), key=lambda x: (isinstance(x, str), x), # sort ints before strs. ) ) - case 'int': + case 'int' | 'int64' | 'int32': result[col] = _create_numerical_attribute( df[col], 'int', numerical_sentinel_value ) - case 'float': + case 'float' | 'float64' | 'float32': result[col] = _create_numerical_attribute( df[col], 'float', numerical_sentinel_value ) @@ -145,7 +154,12 @@ def _get_yaml_filename(dataset_path: PathType) -> str: def main(_) -> None: - read_csv_kwargs = _CSV_READ_ARGS.value().to_read_csv_kwargs() + csv_read_args = _read_csv_args.ReadCsvArgs( + field_separator=_CSV_READ_ARGS_FIELD_SEPARATOR.value, + column_names=_CSV_READ_ARGS_COLUMN_NAMES.value, + column_count=_CSV_READ_ARGS_COLUMN_COUNT.value, + ) + read_csv_kwargs = csv_read_args.to_read_csv_kwargs() # If output_dir is not set, use the parent directory of the dataset path. output_dir = _OUTPUT_DIR.value diff --git a/dpsynth/local_mode/primitives.py b/dpsynth/local_mode/primitives.py index 644181f..10896ae 100644 --- a/dpsynth/local_mode/primitives.py +++ b/dpsynth/local_mode/primitives.py @@ -277,6 +277,12 @@ def select_partitions_sips( rem_user_ids = rem_user_ids[mask] rem_partitions = rem_partitions[mask] + if not selected_partitions: + return ( + np.empty(0, dtype=data.dtype), + np.empty(0, dtype=float), + max_sigma, + ) selected_partitions = np.concatenate(selected_partitions) selected_counts = np.concatenate(selected_counts) return selected_partitions, selected_counts, max_sigma diff --git a/dpsynth/pipeline_transformations/diagnostic_info.py b/dpsynth/pipeline_transformations/diagnostic_info.py index 8d26c04..bdc3c06 100644 --- a/dpsynth/pipeline_transformations/diagnostic_info.py +++ b/dpsynth/pipeline_transformations/diagnostic_info.py @@ -15,6 +15,9 @@ """Module for updating diagnostic information.""" import copy +from dpsynth.pipeline_transformations import types +import pipeline_dp + from dataclasses import dataclass, field from typing import Optional diff --git a/pyproject.toml b/pyproject.toml index 8fa36bc..fd78125 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,7 +24,7 @@ dependencies = [ "pipeline-dp", "tensorflow", "apache-beam", - "dp-accounting", + "dp_accounting @ git+https://github.com/google/differential-privacy.git#subdirectory=python/dp_accounting", "jax[cpu]", "more-itertools", "networkx",