Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pytest.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,4 +39,4 @@ jobs:
- run: pip freeze

- name: Run core tests
run: pytest -vv -n auto
run: pytest -vv -n auto --import-mode=importlib
25 changes: 25 additions & 0 deletions conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
# Copyright 2026 Google LLC
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Pytest configuration and global fixtures."""

import sys
from absl import flags

# Parse absl flags once before any tests run to avoid UnparsedFlagAccessError
# when tests use absltest.TestCase.create_tempfile() or create_tempdir().
# sys.argv here holds pytest's command line, so known_only=True is passed to
# let absl skip options it does not define instead of rejecting them.
try:
    flags.FLAGS(sys.argv, known_only=True)
except flags.Error:
    # Parsing can still fail (e.g. a pytest option that collides with a
    # defined absl flag but carries an invalid value). Swallowing the error
    # alone would leave FLAGS unparsed and bring back the very
    # UnparsedFlagAccessError this file exists to prevent, so explicitly mark
    # the flags as parsed; every flag then keeps its default value.
    flags.FLAGS.mark_as_parsed()
32 changes: 23 additions & 9 deletions dpsynth/bin/derive_domain.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
from absl import logging
from dpsynth import domain
from dpsynth.bin import _read_csv_args
import fancyflags as ff
import numpy as np
import pandas as pd

Expand All @@ -40,10 +39,20 @@
'Path to the output directory to write the domain to.',
)

_CSV_READ_ARGS = ff.DEFINE_auto(
'csv_read_args',
_read_csv_args.ReadCsvArgs,
_read_csv_args.FLAG_HELP,
_CSV_READ_ARGS_FIELD_SEPARATOR = flags.DEFINE_string(
'csv_read_args_field_separator',
None,
'Field separator for reading CSV files.',
)
_CSV_READ_ARGS_COLUMN_NAMES = flags.DEFINE_list(
'csv_read_args_column_names',
None,
'Column names for reading CSV files.',
)
_CSV_READ_ARGS_COLUMN_COUNT = flags.DEFINE_integer(
'csv_read_args_column_count',
None,
'Column count for reading CSV files.',
)


Expand Down Expand Up @@ -120,18 +129,18 @@ def derive_domain_from_data(
for col in df.columns:
logging.info('Deriving domain for column: %s', col)
match df[col].dtype:
case 'object':
case 'object' | 'str' | 'string' | 'category':
result[col] = domain.CategoricalAttribute(
possible_values=sorted(
df[col].unique(),
key=lambda x: (isinstance(x, str), x), # sort ints before strs.
)
)
case 'int':
case 'int' | 'int64' | 'int32':
result[col] = _create_numerical_attribute(
df[col], 'int', numerical_sentinel_value
)
case 'float':
case 'float' | 'float64' | 'float32':
result[col] = _create_numerical_attribute(
df[col], 'float', numerical_sentinel_value
)
Expand All @@ -145,7 +154,12 @@ def _get_yaml_filename(dataset_path: PathType) -> str:


def main(_) -> None:
read_csv_kwargs = _CSV_READ_ARGS.value().to_read_csv_kwargs()
csv_read_args = _read_csv_args.ReadCsvArgs(
field_separator=_CSV_READ_ARGS_FIELD_SEPARATOR.value,
column_names=_CSV_READ_ARGS_COLUMN_NAMES.value,
column_count=_CSV_READ_ARGS_COLUMN_COUNT.value,
)
read_csv_kwargs = csv_read_args.to_read_csv_kwargs()

# If output_dir is not set, use the parent directory of the dataset path.
output_dir = _OUTPUT_DIR.value
Expand Down
6 changes: 6 additions & 0 deletions dpsynth/local_mode/primitives.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,6 +277,12 @@ def select_partitions_sips(
rem_user_ids = rem_user_ids[mask]
rem_partitions = rem_partitions[mask]

if not selected_partitions:
return (
np.empty(0, dtype=data.dtype),
np.empty(0, dtype=float),
max_sigma,
)
selected_partitions = np.concatenate(selected_partitions)
selected_counts = np.concatenate(selected_counts)
return selected_partitions, selected_counts, max_sigma
Expand Down
3 changes: 3 additions & 0 deletions dpsynth/pipeline_transformations/diagnostic_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,9 @@
"""Module for updating diagnostic information."""

import copy
from dpsynth.pipeline_transformations import types
import pipeline_dp

from dataclasses import dataclass, field
from typing import Optional

Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ dependencies = [
"pipeline-dp",
"tensorflow",
"apache-beam",
"dp-accounting",
"dp_accounting @ git+https://github.com/google/differential-privacy.git#subdirectory=python/dp_accounting",
"jax[cpu]",
"more-itertools",
"networkx",
Expand Down
Loading