7 changes: 0 additions & 7 deletions dpdl/cli.py
@@ -522,13 +522,6 @@ def cli(
             rich_help_panel='Opacus options',
         )
     ] = 8,
-    noise_batch_ratio: Annotated[
-        Optional[float],
-        typer.Option(
-            help='Noise-batch ratio (https://arxiv.org/abs/2501.18914)',
-            rich_help_panel='Opacus options',
-        )
-    ] = None,
     target_hypers: Annotated[
         Optional[List[str]],
         typer.Option(
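For reference, the dropped option followed typer's `Annotated` pattern. Below is a minimal, self-contained sketch of that pattern; the `app` wrapper is an illustrative assumption, and the example option name merely mirrors a hyperparameter from configurationmanager.py (its exact presence in the real CLI is not shown in this diff).

from typing import Annotated, Optional

import typer

app = typer.Typer()

@app.command()
def cli(
    # Hypothetical option for illustration; dpdl's real CLI defines many more.
    noise_multiplier: Annotated[
        Optional[float],
        typer.Option(
            help='Gaussian noise multiplier for DP-SGD',
            rich_help_panel='Opacus options',
        ),
    ] = None,
):
    typer.echo(f'noise_multiplier={noise_multiplier}')

if __name__ == '__main__':
    app()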
20 changes: 0 additions & 20 deletions dpdl/configurationmanager.py
@@ -17,7 +17,6 @@ class Hyperparameters(BaseModel):
     noise_multiplier: Optional[float]
     max_grad_norm: Optional[float]
     target_epsilon: Optional[float]
-    noise_batch_ratio: Optional[float]
     privacy: bool = True  # Only used in __str__
     max_length: Optional[int] = None
 
@@ -39,24 +38,6 @@ def check_target_epsilon_or_noise_multiplier(cls, values):
 
         return values
 
-    @root_validator(pre=True)
-    def check_target_epsilon_or_noise_batch_ratio(cls, values):
-        target_epsilon, noise_batch_ratio = values.get('target_epsilon'), values.get('noise_batch_ratio')
-
-        if all([target_epsilon, noise_batch_ratio]):
-            raise ValueError('Both, target_epsilon and noise_batch_ratio given.')
-
-        return values
-
-    @root_validator(pre=True)
-    def check_noise_batch_ratio_or_noise_multiplier(cls, values):
-        noise_multiplier, noise_batch_ratio = values.get('noise_multiplier'), values.get('noise_batch_ratio')
-
-        if all([noise_multiplier, noise_batch_ratio]):
-            raise ValueError('Both, noise_multiplier and noise_batch_ratio given.')
-
-        return values
-
     @root_validator(pre=True)
     def check_epochs(cls, values):
         epochs = values.get('epochs')
@@ -84,7 +65,6 @@ def __str__(self):
             ('Noise multiplier', self.noise_multiplier),
             ('Max grad norm', self.max_grad_norm),
             ('Target epsilon', self.target_epsilon),
-            ('Noise-batch ratio', self.noise_batch_ratio),
         ]
         hypers.extend(privacy_hypers)
 
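The removed checks all follow one mutual-exclusion pattern, which the surviving check_target_epsilon_or_noise_multiplier validator still uses. A minimal sketch of that pattern, assuming pydantic v1-style root_validator (matching the decorator this file evidently imports) and reduced to the two relevant fields:

from typing import Optional

from pydantic import BaseModel, root_validator


class Hyperparameters(BaseModel):
    noise_multiplier: Optional[float] = None
    target_epsilon: Optional[float] = None

    @root_validator(pre=True)
    def check_target_epsilon_or_noise_multiplier(cls, values):
        # Reject configs that pin both the noise level and the privacy
        # budget, since either one determines the other.
        if all([values.get('target_epsilon'), values.get('noise_multiplier')]):
            raise ValueError('Both target_epsilon and noise_multiplier given.')

        return values


# Usage: Hyperparameters(target_epsilon=3.0) is fine, while
# Hyperparameters(target_epsilon=3.0, noise_multiplier=1.1) raises ValueError.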
59 changes: 46 additions & 13 deletions dpdl/metrics_factory.py
@@ -45,25 +45,58 @@ def _get_classification_metrics(
     return torchmetrics.MetricCollection(metrics)
 
 
+class LanguageModelMetrics(torchmetrics.MetricCollection):
+    def __init__(self, vocab_size: int, ignore_index: int, sync: bool) -> None:
+        metrics = {
+            'MulticlassAccuracy': torchmetrics.classification.MulticlassAccuracy(
+                num_classes=vocab_size,
+                average='micro',
+                ignore_index=ignore_index,
+                sync_on_compute=sync,
+            ),
+            'Perplexity': Perplexity(
+                ignore_index=ignore_index,
+                sync_on_compute=sync,
+            ),
+        }
+        super().__init__(metrics)
+
+    def update(self, preds, target) -> None:
+        # Non-tensor inputs take the standard MetricCollection path
+        if not hasattr(preds, 'ndim'):
+            return super().update(preds, target)
+
+        # Shift for next-token prediction: Perplexity expects 3D logits and
+        # 2D labels, while the remaining metrics expect flattened inputs
+        if preds.ndim == 3:
+            shift_logits = preds[:, :-1, :].contiguous()  # (batch, seq_len-1, vocab)
+            shift_labels = target[:, 1:].contiguous()  # (batch, seq_len-1)
+            shift_logits_flat = shift_logits.view(-1, shift_logits.size(-1))  # (batch*(seq_len-1), vocab)
+            shift_labels_flat = shift_labels.view(-1)  # (batch*(seq_len-1),)
+
+            self['Perplexity'].update(shift_logits, shift_labels)
+
+            for name, metric in self.items():
+                if name == 'Perplexity':
+                    continue
+
+                metric.update(shift_logits_flat, shift_labels_flat)
+
+            return
+
+        return super().update(preds, target)
+
 def _get_language_model_metrics(
     vocab_size: int,
     ignore_index: int,
     sync: bool,
 ) -> torchmetrics.MetricCollection:
-    metrics = {
-        'MulticlassAccuracy': torchmetrics.classification.MulticlassAccuracy(
-            num_classes=vocab_size,
-            average='micro',
-            ignore_index=ignore_index,
-            sync_on_compute=sync,
-        ),
-        'Perplexity': Perplexity(
-            ignore_index=ignore_index,
-            sync_on_compute=sync,
-        ),
-    }
 
-    return torchmetrics.MetricCollection(metrics)
+    return LanguageModelMetrics(
+        vocab_size=vocab_size,
+        ignore_index=ignore_index,
+        sync=sync,
+    )
 
 
 class MetricsFactory:
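A short usage sketch of the new collection, assuming LanguageModelMetrics is importable from dpdl.metrics_factory and using dummy tensors with the shapes documented in its update method: raw (batch, seq_len, vocab) logits and (batch, seq_len) labels go in unshifted, and the class aligns them internally for next-token prediction.

import torch

from dpdl.metrics_factory import LanguageModelMetrics  # assumed import path

batch, seq_len, vocab, pad_id = 2, 8, 50, 0

logits = torch.randn(batch, seq_len, vocab)         # raw model output, unshifted
labels = torch.randint(1, vocab, (batch, seq_len))  # token ids

metrics = LanguageModelMetrics(vocab_size=vocab, ignore_index=pad_id, sync=False)
metrics.update(logits, labels)  # shifts logits/labels by one position internally

print(metrics.compute())  # {'MulticlassAccuracy': tensor(...), 'Perplexity': tensor(...)}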
19 changes: 1 addition & 18 deletions dpdl/trainer.py
@@ -403,7 +403,6 @@ def __init__(
         secure_mode: bool = False,
         target_epsilon: float | None = None,
         target_delta: float | None = None,
-        noise_batch_ratio: float | None = None,
         seed: int = 0,
         **kwargs,
     ):
@@ -412,7 +411,6 @@ def __init__(
         self.clipping_mode = clipping_mode
         self.target_epsilon = target_epsilon
         self.target_delta = target_delta
-        self.noise_batch_ratio = noise_batch_ratio
         self.seed = seed
         self.poisson_sampling = poisson_sampling
         self.normalize_clipping = normalize_clipping
@@ -437,18 +435,9 @@ def _has_target_privacy_params(self):
         if self.target_epsilon and not self.target_delta:
             raise ValueError('Parameter "target_epsilon" and "target_delta" not given.')
 
-        if self.noise_batch_ratio and not self.target_delta:
-            raise ValueError('Parameter "target_epsilon" and "target_delta" not given.')
-
-        if all([self.target_epsilon, self.noise_batch_ratio]):
-            raise ValueError('Parameters "target_epsilon" and "noise_batch_ratio" are exclusive.')
-
         if all([self.target_epsilon, self.noise_multiplier]):
             raise ValueError('Parameters "target_epsilon" and "noise_multiplier" are exclusive.')
 
-        if all([self.noise_batch_ratio, self.noise_multiplier]):
-            raise ValueError('Parameters "noise_batch_ratio" and "noise_multiplier" are exclusive.')
-
         if self.target_epsilon and not self.target_delta:
             raise ValueError('Parameter "target_epsilon" present, but "target_delta" is missing.')
 
@@ -489,9 +478,6 @@ def setup(self):
         if self.target_epsilon == -1:
             self.noise_multiplier = 0
 
-        if self.noise_batch_ratio:
-            self.noise_multiplier = self.noise_batch_ratio * self.datamodule.batch_size
-
         dp_model, dp_optimizer, dp_dataloader = self.privacy_engine.make_private(
             module=model,
             optimizer=optimizer,
@@ -816,15 +802,13 @@ def compute_loss(self, model, batch, forward_output, normalize_by: int | None =
 
     def update_metrics(self, model, batch, forward_output, metrics = None):
         _, y = batch
-        preds, y_flat = shift_and_flatten(forward_output, y)
-
         if metrics is not None:
             metrics_to_update = metrics
         else:
             metrics_to_update = model.train_metrics if model.training else model.valid_metrics
 
         with torch.no_grad():
-            metrics_to_update.update(preds, y_flat)
+            metrics_to_update.update(forward_output, y)
 
 # Define task specific adapters
 _ADAPTERS = {
@@ -1026,7 +1010,6 @@ def _get_target_privacy_params(hyperparams):
         max_grad_norm=hyperparams.max_grad_norm,
         target_epsilon=target_epsilon,
         target_delta=target_delta,
-        noise_batch_ratio=hyperparams.noise_batch_ratio,
         poisson_sampling=configuration.poisson_sampling,
         normalize_clipping=configuration.normalize_clipping,
         # config
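With noise_batch_ratio gone, the checks remaining in _has_target_privacy_params reduce to two rules. A condensed sketch of that logic as a standalone function; the free-function form and the name check_target_privacy_params are illustrative, while the error messages mirror the trainer's:

def check_target_privacy_params(
    target_epsilon: float | None,
    target_delta: float | None,
    noise_multiplier: float | None,
) -> None:
    # Epsilon-targeting and an explicit noise level are exclusive: when an
    # epsilon target is given, the noise multiplier is derived from
    # (target_epsilon, target_delta) rather than supplied directly.
    if all([target_epsilon, noise_multiplier]):
        raise ValueError('Parameters "target_epsilon" and "noise_multiplier" are exclusive.')

    # A target epsilon is only meaningful together with a target delta.
    if target_epsilon and not target_delta:
        raise ValueError('Parameter "target_epsilon" present, but "target_delta" is missing.')


check_target_privacy_params(target_epsilon=8.0, target_delta=1e-5, noise_multiplier=None)  # ok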