diff --git a/commands.md b/commands.md new file mode 100644 index 0000000..fa6631f --- /dev/null +++ b/commands.md @@ -0,0 +1,81 @@ +### GCD ImageNet (zebra ↔ sorrel) — quick commands + +Prereqs +- You already trained a DiffAE autoencoder at 256×256 and have its checkpoint path (last.ckpt). +- You have a 256×256 source image (.png/.jpg) for zebra or sorrel. +- Optional: a custom ImageNet classifier checkpoint (otherwise torchvision ResNet-50 pretrained is used). + +Setup (one-time) +```bash +cd /Users/mgalkowski/Desktop/diffae/gcd +conda create --name gcd python=3.11 -y +conda activate gcd +pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 +``` + +Zebra example (label 340): + +```bash +python src/main.py \ + --config-path ../configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra \ + --config-name config \ + strategy.src_img_path=null \ + strategy.hf_dataset_name=imagenet-1k \ + strategy.hf_split=train \ + strategy.hf_label=340 \ + strategy.hf_index=0 \ + device=cuda:0 \ + wandb.project=null +``` + +Run: sorrel (class 339, sorrel→not‑sorrel direction) +```bash +cd /Users/mgalkowski/Desktop/diffae/gcd/gcd +python src/main.py \ + --config-path ../configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel \ + --config-name config \ + strategy.src_img_path="$SRC_IMG" \ + strategy.dae_kwargs.path_ckpt="$DAE_CKPT" \ + device=cuda:0 \ + wandb.project=null +``` + +Notes +- The configs use the default 3‑channel DiffAE wrapper and a ResNet‑50 ImageNet classifier. +- To use a custom classifier, add at the end of the command: + - `strategy.ce_loss_kwargs.components.clf.path_to_weights="/path/to/classifier.ckpt"` +- Outputs are written under `gcd/gcd/outputs//strategy/`. Proxy data and candidate directions are in `strategy/proxy/...`. 
+ +Direction transfer (apply a saved direction to another image) +```bash +# Example placeholders — update paths after a run completes +export LOG_DIR="/Users/mgalkowski/Desktop/diffae/gcd/gcd/outputs//strategy" +export DIR_PATH="$LOG_DIR/proxy//.pt" # e.g., *_grad_*.pt +export TARGET_IMG="/absolute/path/to/another_image.png" + +python src/direction_transfer.py \ + --log-dir-path "$LOG_DIR" \ + --direction-path "$DIR_PATH" \ + --img-path "$TARGET_IMG" \ + --subdir-name "transfer_run_1" +``` + +```bash +python src/direction_transfer_hf_batch.py \ + --direction-path "/path/to/outputs//proxy/0/0_grad_*.pt" \ + --direction-target-class 340 \ + --log-dir-path "/path/to/outputs/" \ + --dataset-name imagenet-1k \ + --split train \ + --target-class 340 \ + --n-samples 500 \ + --step-size 1.0 \ + --T-render 100 \ + --subdir-name "hf_batch_zebra" +``` + +Tips for memory/time +- If you see OOM, reduce: `strategy.dae_kwargs.batch_size` (e.g., 128) or narrow `strategy.dae_kwargs.std`. +- To speed up renders, you can lower `strategy.dae_kwargs.T_render` (e.g., 20–50). + +Switching class direction +- Zebra run uses `query_label=340`; sorrel run uses `query_label=339`. Use the matching config. + diff --git a/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml new file mode 100644 index 0000000..36f1077 --- /dev/null +++ b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml @@ -0,0 +1,73 @@ +mode: RUN +seed: 0 +device: cuda:0 +output_dir: + +strategy: + _target_: strategies.SingleImageGMCMLPProxyTraining + src_img_path: ??? 
# path to source sorrel image + n_iters: 20 + n_steps_dae: 8 + n_steps_proxy: 512 + min_max_step_size: [0.0005, 2.0] + device: ${device} + output_dir: ${output_dir}/strategy + center_to_src_latent_sem: true + line_search_weight_cls_list: [1.0, 0.05] + line_search_weight_lpips_list: [0, 1.0] + n_hessian_eigenvecs: 16 + dae_type: default + + dae_kwargs: + config_name: imagenet256_autoenc + batch_size: 256 + std: [0.01, 0.05, 0.1, 0.2] + distribution: uniform_norm + max_norm: 1.0 + T_encode: 250 + T_render: 50 + path_ckpt: ??? # path to your trained DiffAE checkpoint (.ckpt) + device: ${device} + output_dir: ${output_dir}/dae + + mc_mlp_proxy_kwargs: + shapes: [512, 256, 128, 64, 1001] + activ_fn: sigmoid + batch_size: 128 + n_val_batches: 1 + val_loss_history_length: 5 + optimizer: + _partial_: true + _target_: torch.optim.AdamW + lr: 0.001 + loss_kwargs: + weight_cls: 1.0 + weight_lpips: 3.0 + device: ${device} + output_dir: ${output_dir}/proxy + + ce_loss_kwargs: + weight_cls: 1.0 + weight_lpips: 0 + device: ${device} + output_dir: ${output_dir}/loss + components: + _target_: losses.CounterfactualLossGeneralComponents + lpips_net: vgg + src_img_path: ${strategy.src_img_path} + device: ${device} + clf: + _target_: classifiers.ImageNetResNet + img_size: 256 + query_label: 339 # sorrel class id + use_softmax_and_query_label: false + task: multiclass_classification + model_name: resnet50 + weights: IMAGENET1K_V2 + path_to_weights: null + +wandb: + project: null + group: null + name: imagenet_sorrel + diff --git a/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml new file mode 100644 index 0000000..2320f9f --- /dev/null +++ b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml @@ -0,0 +1,73 @@ +mode: RUN +seed: 0 +device: cuda:0 +output_dir: + +strategy: + _target_: strategies.SingleImageGMCMLPProxyTraining + src_img_path: ??? 
# path to source zebra image (.png or .jpg) + n_iters: 20 + n_steps_dae: 8 + n_steps_proxy: 512 + min_max_step_size: [0.0005, 2.0] + device: ${device} + output_dir: ${output_dir}/strategy + center_to_src_latent_sem: true + line_search_weight_cls_list: [1.0, 0.05] + line_search_weight_lpips_list: [0, 1.0] + n_hessian_eigenvecs: 16 + dae_type: default + + dae_kwargs: + config_name: imagenet256_autoenc + batch_size: 256 + std: [0.01, 0.05, 0.1, 0.2] + distribution: uniform_norm + max_norm: 1.0 + T_encode: 250 + T_render: 50 + path_ckpt: ??? # path to your trained DiffAE checkpoint (.ckpt) + device: ${device} + output_dir: ${output_dir}/dae + + mc_mlp_proxy_kwargs: + shapes: [512, 256, 128, 64, 1001] + activ_fn: sigmoid + batch_size: 128 + n_val_batches: 1 + val_loss_history_length: 5 + optimizer: + _partial_: true + _target_: torch.optim.AdamW + lr: 0.001 + loss_kwargs: + weight_cls: 1.0 + weight_lpips: 3.0 + device: ${device} + output_dir: ${output_dir}/proxy + + ce_loss_kwargs: + weight_cls: 1.0 + weight_lpips: 0 + device: ${device} + output_dir: ${output_dir}/loss + components: + _target_: losses.CounterfactualLossGeneralComponents + lpips_net: vgg + src_img_path: ${strategy.src_img_path} + device: ${device} + clf: + _target_: classifiers.ImageNetResNet + img_size: 256 + query_label: 340 # zebra class id + use_softmax_and_query_label: false + task: multiclass_classification + model_name: resnet50 + weights: IMAGENET1K_V2 + path_to_weights: null # optional: path to custom weights + +wandb: + project: null + group: null + name: imagenet_zebra + diff --git a/gcd/src/classifiers/__init__.py b/gcd/src/classifiers/__init__.py index 408a9a4..a83850b 100644 --- a/gcd/src/classifiers/__init__.py +++ b/gcd/src/classifiers/__init__.py @@ -1,2 +1,3 @@ from .dense_net import DenseNet -from .dense_net_chexpert import DenseNetCheXpert \ No newline at end of file +from .dense_net_chexpert import DenseNetCheXpert +from .imagenet import ImageNetResNet \ No newline at end of file 
diff --git a/gcd/src/classifiers/imagenet.py b/gcd/src/classifiers/imagenet.py new file mode 100644 index 0000000..f99ff0e --- /dev/null +++ b/gcd/src/classifiers/imagenet.py @@ -0,0 +1,87 @@ +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms as tt + +import logging +logging.basicConfig(level = logging.INFO) +log = logging.getLogger(__name__) + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + +class ImageNetResNet(torch.nn.Module): + def __init__( + self, + img_size: int = 256, + query_label: int = 340, + use_softmax_and_query_label: bool = False, + task: str = 'multiclass_classification', + model_name: str = 'resnet50', + weights: str = 'IMAGENET1K_V2', + path_to_weights: str = None, + debug_save_path: str = None): + super().__init__() + self.img_size = img_size + self.task = task + self.query_label = int(query_label) + self.use_softmax_and_query_label = use_softmax_and_query_label + self.transforms = tt.Compose([ + tt.Normalize(mean = IMAGENET_MEAN, std = IMAGENET_STD) + ]) + self.debug_save_path = debug_save_path + self.model = self._build_model(model_name, weights, path_to_weights) + self.model.eval() + + def _build_model(self, model_name, weights, path_to_weights): + log.info(f'Building torchvision model: {model_name}') + if hasattr(torchvision.models, model_name): + ctor = getattr(torchvision.models, model_name) + else: + raise ValueError(f'Unknown model_name: {model_name}') + if path_to_weights is None: + # Load torchvision pretrained weights + try: + w_enum = getattr(torchvision.models, f'{model_name.upper()}_Weights') + w = getattr(w_enum, weights) + model = ctor(weights = w) + except Exception: + log.warning('Falling back to default pretrained weights') + model = ctor(weights = 'IMAGENET1K_V2') + else: + model = ctor(weights = None) + log.info(f'Loading classifier weights from: {path_to_weights}') + ckpt = torch.load(path_to_weights, map_location = 'cpu') + state_dict = 
ckpt.get('state_dict', ckpt) + # if state_dict keys are prefixed (e.g., 'model.'), try to strip common prefixes + if any(k.startswith('model.') for k in state_dict.keys()): + state_dict = {k.split('model.', 1)[1]: v for k, v in state_dict.items() if k.startswith('model.')} + missing, unexpected = model.load_state_dict(state_dict, strict = False) + if missing: + log.warning(f'Missing keys when loading classifier: {len(missing)}') + if unexpected: + log.warning(f'Unexpected keys when loading classifier: {len(unexpected)}') + return model + + @torch.no_grad() + def forward(self, x): + assert x.shape[-1] == self.img_size, 'Wrong input shape' + # Expect [0,1] range; rescale if in [-1,1] + if x.min() < 0: + assert x.min() >= -1. and x.max() <= 1. + log.info('Detected input outside [0,1], rescaling from [-1,1] to [0,1]') + x = (x + 1) / 2 + # Save a debug preview (pre-normalization) if requested + if self.debug_save_path is not None: + try: + torchvision.utils.save_image(x, self.debug_save_path) + log.info(f"Saved classifier input preview (pre-normalization) to: {self.debug_save_path}") + except Exception as e: + log.warning(f"Failed to save debug image to {self.debug_save_path}: {e}") + x = self.transforms(x) + logits = self.model(x) + if self.use_softmax_and_query_label: + probs = F.softmax(logits, dim = 1) + return probs[:, self.query_label] + return logits + diff --git a/gcd/src/dae/chexpert/dae.py b/gcd/src/dae/chexpert/dae.py index e7c5093..88cd3ee 100644 --- a/gcd/src/dae/chexpert/dae.py +++ b/gcd/src/dae/chexpert/dae.py @@ -17,7 +17,7 @@ class DAECheXpert(nn.Module): """ def __init__( - self, + self, config_name: str, batch_size: int, std: Union[float, List[float]], @@ -30,7 +30,7 @@ def __init__( max_norm: float = 1., make_output_dir: bool = True): """ - batch_size - size of the batch containing perturbed latent representations + batch_size - size of the batch containing perturbed latent representations of source image T_encode - number of diffusion steps for 
encoding T_render - number of diffusion steps for rendering @@ -93,18 +93,28 @@ def encode_stochastic(self, x, cond, T = None): x, model_kwargs = {'cond': cond}) return out['sample'] - + @torch.no_grad() - def render(self, noise, cond, T = None): + def render(self, noise, cond, T = None, max_batch: int = None): if T is not None: sampler = self.config._make_diffusion_conf(T).make_sampler() else: sampler = self.sampler_render - pred_img = render_condition(self.config, - self.model, - noise, - sampler = sampler, - cond = cond) + # Chunked rendering to reduce peak memory + N = noise.shape[0] + if max_batch is None: + max_batch = min(N, 32) + outs = [] + for i in range(0, N, max_batch): + n_chunk = noise[i:i + max_batch] + c_chunk = cond[i:i + max_batch] if cond is not None else None + out = render_condition(self.config, + self.model, + n_chunk, + sampler = sampler, + cond = c_chunk) + outs.append(out) + pred_img = torch.cat(outs, dim = 0) pred_img = (pred_img + 1) / 2 return pred_img @@ -116,8 +126,8 @@ def make_batch_latent_sem(self, latent_sem): chunk_size = self.batch_size // len(self.std) for std in self.std: noise = torch.normal( - mean = 0., - std = std, + mean = 0., + std = std, size = (chunk_size, latent_sem.shape[1]), device = self.device) noises.append(noise) @@ -126,8 +136,8 @@ def make_batch_latent_sem(self, latent_sem): elif self.distribution == 'uniform_norm': log.info(f'Sampling latent sem perturbations from uniform distribution over norm from the interval [0, {self.max_norm}]') noise = torch.normal( - mean = 0., - std = 1., + mean = 0., + std = 1., size = (self.batch_size, latent_sem.shape[1]), device = self.device) noise /= noise.norm(dim = 1).unsqueeze(1) diff --git a/gcd/src/dae/chexpert/utils/make_config.py b/gcd/src/dae/chexpert/utils/make_config.py index 1fda1d4..b67a41c 100644 --- a/gcd/src/dae/chexpert/utils/make_config.py +++ b/gcd/src/dae/chexpert/utils/make_config.py @@ -8,5 +8,8 @@ def make_config(config_name): if config_name == 
'chexpert224_autoenc': log.info(f'Using config: {config_name}') return chexpert224_autoenc() + elif config_name == 'imagenet256_autoenc': + log.info(f'Using config: {config_name}') + return imagenet256_autoenc() else: raise NotImplementedError('Invalid config name or option not implemented yet') \ No newline at end of file diff --git a/gcd/src/dae/default/dae.py b/gcd/src/dae/default/dae.py index f0b7b14..53d30ff 100644 --- a/gcd/src/dae/default/dae.py +++ b/gcd/src/dae/default/dae.py @@ -17,7 +17,7 @@ class DAE(nn.Module): """ def __init__( - self, + self, config_name: str, batch_size: int, std: Union[float, List[float]], # [0.01, 0.05, 0.1, 0.2] @@ -30,7 +30,7 @@ def __init__( max_norm: float = 1., make_output_dir: bool = True): """ - batch_size - size of the batch containing perturbed latent representations + batch_size - size of the batch containing perturbed latent representations of source image T_encode - number of diffusion steps for encoding T_render - number of diffusion steps for rendering @@ -93,18 +93,29 @@ def encode_stochastic(self, x, cond, T = None): x, model_kwargs = {'cond': cond}) return out['sample'] - + @torch.no_grad() - def render(self, noise, cond, T = None): + def render(self, noise, cond, T = None, max_batch: int = None): if T is not None: sampler = self.config._make_diffusion_conf(T).make_sampler() else: sampler = self.sampler_render - pred_img = render_condition(self.config, - self.model, - noise, - sampler = sampler, - cond = cond) + # Chunked rendering to reduce peak memory + N = noise.shape[0] + if max_batch is None: + # default to something conservative + max_batch = min(N, 32) + outs = [] + for i in range(0, N, max_batch): + n_chunk = noise[i:i + max_batch] + c_chunk = cond[i:i + max_batch] if cond is not None else None + out = render_condition(self.config, + self.model, + n_chunk, + sampler = sampler, + cond = c_chunk) + outs.append(out) + pred_img = torch.cat(outs, dim = 0) pred_img = (pred_img + 1) / 2 return pred_img @@ -116,8 
+127,8 @@ def make_batch_latent_sem(self, latent_sem): chunk_size = self.batch_size // len(self.std) for std in self.std: noise = torch.normal( - mean = 0., - std = std, + mean = 0., + std = std, size = (chunk_size, latent_sem.shape[1]), device = self.device) noises.append(noise) @@ -126,8 +137,8 @@ def make_batch_latent_sem(self, latent_sem): elif self.distribution == 'uniform_norm': log.info(f'Sampling latent sem perturbations from uniform distribution over norm from the interval [0, {self.max_norm}]') noise = torch.normal( - mean = 0., - std = 1., + mean = 0., + std = 1., size = (self.batch_size, latent_sem.shape[1]), device = self.device) noise /= noise.norm(dim = 1).unsqueeze(1) diff --git a/gcd/src/dae/default/templates.py b/gcd/src/dae/default/templates.py index 5066961..1f13a33 100644 --- a/gcd/src/dae/default/templates.py +++ b/gcd/src/dae/default/templates.py @@ -66,4 +66,20 @@ def celeba128_autoenc(): conf.eval_every_samples = 10_000_000 conf.data_name = 'celeba128' conf.name = 'celeba128_autoenc' + return conf + +def imagenet256_autoenc(): + conf = ffhq256_autoenc() + conf.img_size = 256 + conf.net_ch = 128 + conf.net_ch_mult = (1, 1, 2, 2, 4, 4) + conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4) + # Large defaults are OK for inference config stub; not used for training here + conf.eval_every_samples = 10_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.total_samples = 500_000_000 + conf.batch_size = 64 + conf.make_model_conf() + conf.name = 'imagenet256_autoenc' + conf.data_name = 'imagenet' return conf \ No newline at end of file diff --git a/gcd/src/dae/default/utils/make_config.py b/gcd/src/dae/default/utils/make_config.py index 10205ff..5cf4ab9 100644 --- a/gcd/src/dae/default/utils/make_config.py +++ b/gcd/src/dae/default/utils/make_config.py @@ -11,5 +11,8 @@ def make_config(config_name): elif config_name == 'celeba128_autoenc': log.info(f'Using config: {config_name}') return celeba128_autoenc() + elif config_name == 
'imagenet256_autoenc': + log.info(f'Using config: {config_name}') + return imagenet256_autoenc() else: raise NotImplementedError('Invalid config name or option not implemented yet') \ No newline at end of file diff --git a/gcd/src/direction_transfer_hf_batch.py b/gcd/src/direction_transfer_hf_batch.py new file mode 100644 index 0000000..e8ca6cc --- /dev/null +++ b/gcd/src/direction_transfer_hf_batch.py @@ -0,0 +1,311 @@ +import sys +import os +import torch +import pandas as pd +import torchvision +import omegaconf +from pathlib import Path +from argparse import ArgumentParser +from datasets import load_dataset +from PIL import Image +import torchvision.transforms as T +from typing import Optional, Dict + +from proxies import * +from losses import * +from classifiers import * +from dae import * + +import logging +logging.basicConfig(level = logging.INFO) +log = logging.getLogger(__name__) + +# +# Fill these with your paths if you want per-split defaults instead of passing --predictions-csv +# +PREDICTIONS_PATHS: Dict[str, Dict[str, Optional[str]]] = { + # Example: + # 'imagenet-1k': { + # 'train': '/abs/path/to/imagenet1k_train_predictions.csv', + # 'validation': '/abs/path/to/imagenet1k_val_predictions.csv', + # }, +} + +def init_class_from_string(class_name): + return getattr(sys.modules[__name__], class_name) + +def get_cfg(args): + path_cfg = (Path(args.log_dir_path) if args.log_dir_path else Path(args.direction_path).parents[2]) / '.hydra' / 'config.yaml' + cfg = omegaconf.OmegaConf.load(path_cfg) + return cfg + +def init_loss(cfg, args): + loss_kwargs = cfg.strategy.ce_loss_kwargs + loss_kwargs['make_output_dir'] = False + clf_kwargs = loss_kwargs.components.clf + clf_name = clf_kwargs.pop('_target_').split('.')[-1] + log.info(f'Using {clf_name} classifier') + if clf_name == 'DenseNet': + clf_kwargs.use_probs_and_query_label = False + elif clf_name == 'ResNet': + clf_kwargs.use_softmax_and_query_label = True + clf = 
init_class_from_string(clf_name)(**clf_kwargs) + # Override target/query class if provided for this direction + if getattr(args, "direction_target_class", None) is not None: + if hasattr(clf, "query_label"): + clf.query_label = int(args.direction_target_class) + log.info(f"Set classifier query_label to direction target class: {clf.query_label}") + comps_kwargs = loss_kwargs.components + comps_name = comps_kwargs._target_.split('.')[-1] + comps_kwargs.pop('_target_') + comps_kwargs.pop('clf', None) + comps_kwargs.pop('src_img_path', None) + comps = init_class_from_string(comps_name)( + src_img_path = None, + clf = clf, + **comps_kwargs) + loss_kwargs.pop('components') + loss_name = 'CounterfactualLossFromGeneralComponents' + loss_kwargs.weight_cls = cfg.strategy.ce_loss_kwargs.weight_cls + loss_kwargs.weight_lpips = cfg.strategy.ce_loss_kwargs.weight_lpips + loss = init_class_from_string(loss_name)(components = comps, **loss_kwargs) + return loss + +def init_dae(cfg): + dae_kwargs = cfg.strategy.dae_kwargs + dae_kwargs['make_output_dir'] = False + dae_type = cfg.strategy.dae_type + if dae_type == 'default': + dae_class = DAE + elif dae_type == 'chexpert': + dae_class = DAECheXpert + else: + raise NotImplementedError('DAE type not recognized') + log.info(f'Using {dae_class}') + dae = dae_class(**dae_kwargs) + return dae + +def load_direction(path, device): + grad = torch.load(Path(path), map_location='cpu').to(device) + if torch.count_nonzero(grad) == 0: + raise ValueError("Direction tensor contains only zeros.") + return grad + +def get_predictions_df(dataset_name: str, split: str, explicit_csv: Optional[str]) -> Optional[pd.DataFrame]: + """ + Loads predictions CSV expected to have: + - index column: 'idx' + - column: 'pred_label' + If explicit_csv is None, attempts to use PREDICTIONS_PATHS[dataset_name][split]. + Returns None if no path provided. 
+ """ + csv_path = explicit_csv + if csv_path is None: + csv_path = PREDICTIONS_PATHS.get(dataset_name, {}).get(split, None) + if csv_path is None: + return None + if not os.path.exists(csv_path): + raise FileNotFoundError(f"Predictions CSV not found: {csv_path}") + df = pd.read_csv(csv_path, index_col="idx") + if "pred_label" not in df.columns: + raise ValueError(f"Predictions CSV missing 'pred_label' column: {csv_path}") + return df + +def _extract_ex_idx(example: dict, fallback: Optional[int] = None) -> Optional[int]: + """ + Try to extract a stable dataset index from a HF streaming example. + Prefers '__index_level_0__', then 'id', then 'idx'. Falls back to provided fallback. + """ + for key in ("__index_level_0__", "id", "idx"): + if key in example: + try: + return int(example[key]) + except Exception: + continue + return fallback + +def iter_hf_by_label( + dataset_name, + split, + label, + start_index, + n_samples, + token=None, + cache_dir=None, + predictions_df: Optional[pd.DataFrame] = None, + filter_id: Optional[int] = None, + n_skip: int = 0, + pred_index_scope: str = "global", # 'global' or 'label' +): + kwargs = {"streaming": True} + if token is None: + token = os.getenv('HF_TOKEN', None) + if cache_dir is None: + cache_dir = os.getenv('DATASET_CACHE', None) + if token is not None: + kwargs["token"] = token + if cache_dir is not None: + kwargs["cache_dir"] = cache_dir + ds = load_dataset(dataset_name, split=split, **kwargs) + count_label_hits = 0 # counts examples with ex['label'] == label (zero-based) + emitted = 0 + skipped = 0 + for global_idx, ex in enumerate(ds): + if ex.get('label', None) != label: + continue + # Count only matching label occurrences for start_index logic + current_label_idx = count_label_hits # zero-based index of label-matching sample + count_label_hits += 1 + if current_label_idx < start_index: + continue + + # If predictions filtering is requested, check it here + if predictions_df is not None and filter_id is not None: + # 
Choose index to match predictions: global or label-only enumeration + ex_idx = global_idx if pred_index_scope == "global" else current_label_idx + if ex_idx is None or ex_idx not in predictions_df.index: + # If we can't find a matching prediction idx, skip + continue + pred_label = predictions_df.loc[ex_idx, "pred_label"] + # Keep only correctly predicted examples for the requested class + # Since ex['label'] == label, correctness means pred_label == label == filter_id + if int(pred_label) != int(filter_id): + continue + + # Apply n_skip on the filtered stream + if skipped < n_skip: + skipped += 1 + continue + + # Attach derived indices for downstream logging if needed + ex = dict(ex) + ex["_global_idx_enumerate"] = global_idx + ex["_label_idx_enumerate"] = current_label_idx + yield ex + emitted += 1 + if emitted >= n_samples: + break + +def pil_to_tensor_01(img: Image.Image, image_size: int): + if img.mode != 'RGB': + img = img.convert('RGB') + transform = T.Compose([ + T.CenterCrop(image_size), + T.Resize(image_size), + T.ToTensor(), # [0,1] + ]) + return transform(img) + +def to_dae_range(x01: torch.Tensor): + return (x01 - 0.5) * 2 + +def from_dae_range(xm1p1: torch.Tensor): + return (xm1p1 + 1) / 2 + +def main(args): + log.info("Batch direction transfer (HF ImageNet)") + cfg = get_cfg(args) + device = torch.device(cfg.device) + loss = init_loss(cfg, args) + dae = init_dae(cfg) + direction = load_direction(args.direction_path, device) + + out_dir = (Path(args.log_dir_path) if args.log_dir_path else Path(args.direction_path).parents[2]) / 'direction_transfer' / args.subdir_name + out_dir.mkdir(parents=True, exist_ok=True) + csv_rows = [] + + image_size = dae.config.img_size if hasattr(dae, 'config') else 256 + n_done = 0 + # Optional predictions filtering + predictions_df = None + if args.filter_id is not None: + predictions_df = get_predictions_df(args.dataset_name, args.split, args.predictions_csv) + if predictions_df is None: + log.error(f"Filtering requested 
but no predictions CSV provided for {args.dataset_name}/{args.split}.") + log.error("Provide --predictions-csv or fill PREDICTIONS_PATHS in the script.") + sys.exit(1) + for ex in iter_hf_by_label( + dataset_name=args.dataset_name, + split=args.split, + label=args.target_class, + start_index=args.start_index, + n_samples=args.n_samples, + token=args.hf_token, + cache_dir=args.hf_cache_dir, + predictions_df=predictions_df, + filter_id=args.filter_id, + n_skip=args.n_skip, + pred_index_scope=args.pred_index_scope + ): + img_pil = ex['image'] if isinstance(ex['image'], Image.Image) else Image.fromarray(ex['image']) + # Log the index used to match predictions for traceability + idx = ex.get("_global_idx_enumerate", n_done) if args.pred_index_scope == "global" else ex.get("_label_idx_enumerate", n_done) + # Prepare tensors + img01 = pil_to_tensor_01(img_pil, image_size).unsqueeze(0).to(device) + img_m1p1 = to_dae_range(img01) + # Encode + latent_sem = dae.encode(img_m1p1) + latent_ddim = dae.encode_stochastic(img_m1p1, latent_sem) + # One-step transfer using fixed step size + step = torch.tensor([[args.step_size]], device=device) + grad = direction + latent_sem_new = latent_sem - step * grad + img_trans = dae.render(latent_ddim, latent_sem_new, T=args.T_render) + # Save pair + orig_name = f"original_{n_done+1}.png" + trans_name = f"inpaint_{n_done+1}.png" + torchvision.utils.save_image(img01, out_dir / orig_name) + torchvision.utils.save_image(img_trans, out_dir / trans_name) + # Log classifier probabilities (optional) + with torch.no_grad(): + comps = loss.get_components(img_trans) + # Query-label probabilities + prob_after = loss.get_query_label_probability(comps['predictions']).item() + pred_before_logits = loss.components.classifier(img01) + prob_before = loss.get_query_label_probability(pred_before_logits).item() + # Max-arg probabilities and classes (multiclass) + probs_after_all = torch.softmax(comps['predictions'], dim = 1) + best_after_prob, best_after_cls = 
probs_after_all.max(dim = 1) + probs_before_all = torch.softmax(pred_before_logits, dim = 1) + best_before_prob, best_before_cls = probs_before_all.max(dim = 1) + csv_rows.append({ + "idx": int(idx), + "orig_path": orig_name, + "inpaint_path": trans_name, + "prob_before": prob_before, + "prob_after": prob_after, + "pred_before_best_cls": int(best_before_cls.item()), + "pred_before_best_prob": float(best_before_prob.item()), + "pred_after_best_cls": int(best_after_cls.item()), + "pred_after_best_prob": float(best_after_prob.item()), + "direction_target_class": int(getattr(args, "direction_target_class", -1)) + }) + n_done += 1 + if n_done % 25 == 0: + log.info(f"Processed {n_done}/{args.n_samples}") + # Save CSV + pd.DataFrame(csv_rows).to_csv(out_dir / "pairs.csv", index=False) + log.info(f"Saved {n_done} pairs to {out_dir}") + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--direction-path', type=str, required=True, help='Path to direction .pt file from proxy run') + parser.add_argument('--direction-target-class', type=int, required=True, help='Target class id the direction was trained for (query label)') + parser.add_argument('--log-dir-path', type=str, default=None, help='Log dir path with .hydra config (defaults to infer from direction path)') + parser.add_argument('--dataset-name', type=str, default='imagenet-1k', help='HF dataset name') + parser.add_argument('--split', type=str, default='train', help='HF split') + parser.add_argument('--target-class', type=int, required=True, help='HF class id to sample') + parser.add_argument('--n-samples', type=int, default=500, help='Number of images to transfer') + parser.add_argument('--start-index', type=int, default=0, help='Skip N matching samples before collecting') + parser.add_argument('--hf-token', type=str, default=None, help='HF token (or set HF_TOKEN env)') + parser.add_argument('--hf-cache-dir', type=str, default=None, help='HF cache (or set DATASET_CACHE env)') + 
parser.add_argument('--step-size', type=float, default=1.0, help='Step size multiplier for the direction') + parser.add_argument('--T-render', type=int, default=100, help='Number of DDIM steps for rendering') + parser.add_argument('--subdir-name', type=str, default='hf_batch', help='Output subdir name under direction_transfer') + # Optional predictions-based filtering (keep only correctly predicted examples for --filter-id) + parser.add_argument('--filter-id', type=int, default=None, help='If set, keep only samples with this GT label that are correctly predicted as this label') + parser.add_argument('--n-skip', type=int, default=0, help='Skip first N samples after filtering (useful for sharding)') + parser.add_argument('--predictions-csv', type=str, default=None, help='Optional override path to predictions CSV (index=idx, col=pred_label)') + args = parser.parse_args() + main(args) + diff --git a/gcd/src/losses/counterfactual_loss_general_components.py b/gcd/src/losses/counterfactual_loss_general_components.py index 969d20f..39f54ca 100644 --- a/gcd/src/losses/counterfactual_loss_general_components.py +++ b/gcd/src/losses/counterfactual_loss_general_components.py @@ -3,6 +3,10 @@ import torch.nn as nn import torch.nn.functional as F import torchvision +import torchvision.transforms as T +from datasets import load_dataset +from PIL import Image +import os import logging log = logging.getLogger(__name__) @@ -15,11 +19,17 @@ class CounterfactualLossGeneralComponents(nn.Module): """ def __init__( - self, + self, clf, lpips_net: str, - src_img_path: str, - device: str): + src_img_path: str = None, + device: str = 'cuda:0', + # Optional: load source image from HF dataset by label if src_img_path is None + hf_dataset_name: str = None, # e.g., 'imagenet-1k' + hf_split: str = 'train', + hf_label: int = None, + hf_index: int = 0, + image_size: int = 256): """ label_idx - label of interest (for counterfactual explanation) """ @@ -28,7 +38,29 @@ def __init__( 
self.to(self.device) self.classifier = clf.to(self.device) self.lpips = lpips.LPIPS(net = lpips_net).to(self.device) - self.load_src_img(src_img_path, device) + # Load source image: prefer local path, otherwise try HF dataset + if src_img_path is not None: + self.load_src_img(src_img_path, device) + else: + # Fallback to HF if params provided or env vars present + ds_name = hf_dataset_name or os.getenv('HF_DATASET_NAME', None) + ds_split = hf_split + ds_label = hf_label if hf_label is not None else os.getenv('HF_LABEL', None) + ds_index = hf_index if hf_index is not None else int(os.getenv('HF_INDEX', 0)) + ds_token = os.getenv('HF_TOKEN', None) + ds_cache = os.getenv('DATASET_CACHE', None) + if ds_name is None or ds_label is None: + raise ValueError("CounterfactualLossGeneralComponents: Provide src_img_path or HF dataset parameters (hf_dataset_name & hf_label).") + ds_label = int(ds_label) + self.load_src_img_from_hf( + dataset_name = ds_name, + split = ds_split, + label = ds_label, + index = ds_index, + hf_token = ds_token, + cache_dir = ds_cache, + image_size = image_size, + ) self.clf_pred_label = self.get_clf_pred_label() def forward(self, x): @@ -71,7 +103,18 @@ def get_clf_pred_label(self, img = None): else: output_probs = F.softmax(output, dim = 1) output_prob = output_probs[:, self.classifier.query_label] - + # For multiclass, threshold 0.5 is often too strict (1k classes). + # Use argmax to decide if query_label is the predicted class. 
+ if self.classifier.task == 'multiclass_classification': + pred_idx = output_probs.argmax(dim = 1) + label = (pred_idx == self.classifier.query_label).long().item() + if save_src_img_preds: + self.src_img_probs = output_probs + self.src_img_logits = output + log.info(f"Argmax predicted class: {pred_idx.item()}, query_label: {self.classifier.query_label}") + log.info(f"Probability of query class: {output_prob.item()}") + return label + if save_src_img_preds: self.src_img_probs = output_probs self.src_img_logits = output @@ -90,17 +133,62 @@ def get_clf_pred_label(self, img = None): if save_src_img_preds: self.src_img_probs = output_probs self.src_img_logits = output - + label = 1 if output_prob > 0.5 else 0 log.info(f"Class predicted for source image: {label}") log.info(f"Probability of positive class: {output_prob.item()}") return label - + def load_src_img(self, path, device): src_img = torchvision.io.read_image(path).unsqueeze(0) / 255 src_img = (src_img - 0.5) * 2 self.src_img = src_img.to(self.device) + @torch.no_grad() + def load_src_img_from_hf( + self, + dataset_name: str, + split: str, + label: int, + index: int = 0, + hf_token: str = None, + cache_dir: str = None, + image_size: int = 256, + ): + log.info(f'Loading source image from HF dataset: {dataset_name}, split: {split}, label: {label}, index: {index}') + # Use streaming to avoid slow filter materialization + stream_kwargs = {"streaming": True} + if hf_token is not None: + stream_kwargs["token"] = hf_token + if cache_dir is not None: + stream_kwargs["cache_dir"] = cache_dir + ds_stream = load_dataset(dataset_name, split=split, **stream_kwargs) + + found = -1 + sample = None + for ex in ds_stream: + if ex.get("label", None) == label: + found += 1 + if found == index: + sample = ex + break + if sample is None: + raise ValueError(f"Could not find occurrence index {index} for label {label} in {dataset_name}/{split}") + img = sample['image'] + if not isinstance(img, Image.Image): + img = 
Image.fromarray(img) + if img.mode != 'RGB': + img = img.convert('RGB') + transform = T.Compose([ + T.CenterCrop(image_size), + T.Resize(image_size), + T.ToTensor(), # [0,1] + ]) + src_img = transform(img).unsqueeze(0) + # Scale to [-1,1] for loss computations (classifier handles rescale internally if needed) + src_img = (src_img - 0.5) * 2 + self.src_img = src_img.to(self.device) + diff --git a/gcd/src/losses/general_multicomponent_proxy_loss.py b/gcd/src/losses/general_multicomponent_proxy_loss.py index 7302d37..0edd0dc 100644 --- a/gcd/src/losses/general_multicomponent_proxy_loss.py +++ b/gcd/src/losses/general_multicomponent_proxy_loss.py @@ -30,8 +30,8 @@ def forward(self, inputs, targets, labels: list = None): pred_lpips = inputs['lpips'] target_lpips = targets['lpips'] - assert pred_preds.shape == target_preds.shape - assert pred_lpips.shape == target_lpips.shape + assert pred_preds.shape == target_preds.shape, f"Invalid shapes: {pred_preds.shape = } and {target_preds.shape = }" + assert pred_lpips.shape == target_lpips.shape, f"Invalid shapes: {pred_lpips.shape = } and {target_lpips.shape = }" loss = self.weight_cls * F.mse_loss(pred_preds, target_preds) + \ self.weight_lpips * F.mse_loss(pred_lpips, target_lpips) diff --git a/gcd/src/main.py b/gcd/src/main.py index a5b74a4..678c513 100644 --- a/gcd/src/main.py +++ b/gcd/src/main.py @@ -5,6 +5,9 @@ from omegaconf import DictConfig from utils import extract_output_dir_path_from_config, set_seed +from dotenv import load_dotenv + +load_dotenv() log = logging.getLogger(__name__) diff --git a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py index 442f173..66ccaac 100644 --- a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py +++ b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py @@ -11,6 +11,10 @@ from losses import CounterfactualLossFromGeneralComponents from strategies import Strategy from utils import * +from datasets import 
load_dataset +import torchvision.transforms as T +from PIL import Image +import os import logging logging.basicConfig(level = logging.INFO) @@ -32,7 +36,6 @@ class SingleImageGMCMLPProxyTraining(Strategy): def __init__( self, - src_img_path: str, n_iters: int, n_steps_dae: int, n_steps_proxy: int, @@ -43,10 +46,22 @@ def __init__( n_hessian_eigenvecs: int, device: str, output_dir: str, - mc_mlp_proxy_kwargs: dict, + mc_mlp_proxy_kwargs: dict, dae_kwargs: dict, ce_loss_kwargs: dict, - dae_type: str = 'default'): + src_img_path: str = None, + dae_type: str = 'default', + # Optional HF dataset source for picking the source image by label + hf_dataset_name: str = None, # e.g., 'imagenet-1k' + hf_split: str = 'train', + hf_label: int = None, # required if using HF + hf_index: int = 0, # which occurrence of that label to use + # Optional prediction-based filtering: pick a source image the classifier predicts correctly + hf_predictions_csv: str = None, # CSV with index 'idx' and column 'pred_label' + hf_pred_index_scope: str = 'global', # 'global' (enumerate whole split) or 'label' (enumerate only matching label) + hf_n_skip_filtered: int = 0, # skip first N after filtering (sharding/offset) + hf_filter_by_predictions: bool = False, # enable filtering + ): """ src_img_path - path of the source image n_iters - number of outer loop iterations @@ -56,7 +71,7 @@ def __init__( n_steps_proxy - number of epochs for proxy training in each outer iteration Note: the size of the dataset for proxy training is n_steps_dae * self.dae.batch_size * n_iters min_max_step_size - coefficients that determine the min and max size of the gradient step, - intermediate steps are distributed uniformly between these values and total number of + intermediate steps are distributed uniformly between these values and total number of step sizes is equal to self.dae.batch_size center_to_src_latent_sem - whether to center all latent sems at src_latent_sem n_hessian_eigenvecs - number of hessian 
eigenvectors used in line search @@ -74,12 +89,30 @@ def __init__( self.device_cpu = torch.device('cpu') self.output_dir = Path(output_dir) self.output_dir.mkdir(parents = True) - + self.dae = self.get_dae_class(dae_type)(**dae_kwargs) self.proxy = GeneralMulticomponentMLP(**mc_mlp_proxy_kwargs) self.ce_loss = CounterfactualLossFromGeneralComponents(**ce_loss_kwargs) - self.load_src_img(src_img_path) + # Load source image either from local path or HF dataset by label + if src_img_path is not None: + self.load_src_img(src_img_path) + elif hf_dataset_name is not None and hf_label is not None: + self.load_src_img_from_hf( + dataset_name = hf_dataset_name, + split = hf_split, + label = hf_label, + index = hf_index, + predictions_csv = hf_predictions_csv, + pred_index_scope = hf_pred_index_scope, + n_skip_filtered = hf_n_skip_filtered, + filter_by_predictions = hf_filter_by_predictions, + hf_token = os.getenv('HF_TOKEN', None), + cache_dir = os.getenv('DATASET_CACHE', None), + image_size = self.dae.config.img_size if hasattr(self.dae, 'config') else 256 + ) + else: + raise ValueError("Provide either src_img_path or (hf_dataset_name and hf_label).") self.clear_proxy_training_data() def get_dae_class(self, dae_type): @@ -96,6 +129,98 @@ def load_src_img(self, path): src_img = (src_img - 0.5) * 2 self.src_img = src_img.to(self.device) + def load_src_img_from_hf( + self, + dataset_name: str, + split: str, + label: int, + index: int = 0, + predictions_csv: str = None, + pred_index_scope: str = 'label', # 'global' | 'label' + n_skip_filtered: int = 0, + filter_by_predictions: bool = False, + hf_token: str = None, + cache_dir: str = None, + image_size: int = 256, + ): + """ + Load one source image from a HF dataset (e.g., 'imagenet-1k') by label id. + If filter_by_predictions is True, uses predictions_csv to keep only samples correctly + predicted as 'label'. 
You can choose how indices are matched via pred_index_scope: + - 'global': idx equals enumerate() over whole split (0-based) + - 'label' : idx equals enumerate() over only samples with ex['label'] == label (0-based) + """ + log.info(f'Loading source image from HF dataset: {dataset_name}, split: {split}, label: {label}, index: {index}') + # Use streaming iteration to avoid materializing a full filtered copy (much faster) + stream_kwargs = {"streaming": True} + if hf_token is not None: + stream_kwargs["token"] = hf_token + if cache_dir is not None: + stream_kwargs["cache_dir"] = cache_dir + ds_stream = load_dataset(dataset_name, split=split, **stream_kwargs) + + preds_df = None + if filter_by_predictions: + if predictions_csv is None or not os.path.exists(predictions_csv): + raise FileNotFoundError(f"Filtering by predictions requested but CSV not found: {predictions_csv}") + preds_df = pd.read_csv(predictions_csv, index_col='idx') + if 'pred_label' not in preds_df.columns: + raise ValueError(f"Predictions CSV missing 'pred_label' column: {predictions_csv}") + + label_enum = -1 # counts only samples with matching label + global_enum = -1 # counts all samples + filtered_seen = -1 # counts only samples that passed predictions filter + sample = None + for ex in ds_stream: + global_enum += 1 + # keep only target label + if ex.get("label", None) != label: + continue + label_enum += 1 + + if filter_by_predictions: + idx_to_check = global_enum if pred_index_scope == 'global' else label_enum + if idx_to_check not in preds_df.index: + continue + pred_label = int(preds_df.loc[idx_to_check, 'pred_label']) + if pred_label != int(label): + continue + filtered_seen += 1 + if filtered_seen < n_skip_filtered: + continue + # pick by filtered order + if filtered_seen == index: + sample = ex + break + else: + # original behavior: pick by label occurrence + if label_enum == index: + sample = ex + break + if sample is None: + if filter_by_predictions: + raise ValueError(f"Could not find 
filtered occurrence index {index} for label {label} in {dataset_name}/{split} " + f"(pred_index_scope={pred_index_scope}, n_skip_filtered={n_skip_filtered})") + else: + raise ValueError(f"Could not find occurrence index {index} for label {label} in {dataset_name}/{split}") + img = sample['image'] + if not isinstance(img, Image.Image): + img = Image.fromarray(img) + if img.mode != 'RGB': + img = img.convert('RGB') + # Resize/center-crop to target image_size + transform = T.Compose([ + T.CenterCrop(image_size), + T.Resize(image_size), + T.ToTensor(), # [0,1] + ]) + src_img = transform(img).unsqueeze(0) + # Save [0,1] preview + torchvision.utils.save_image(src_img, self.output_dir / 'src_img.png') + # Scale to [-1,1] for DAE + src_img = (src_img - 0.5) * 2 + self.src_img = src_img.to(self.device) + def get_src_img(self, scale = True): if scale: return (self.src_img + 1) / 2 @@ -134,7 +259,7 @@ def dae_step(self, step_idx, iter): log.info(f'Saving synthetic images to {output_dir}') torchvision.utils.save_image(batch_imgs, output_dir / 'imgs.png') # NOTE: DAE requires input to be in [-1, 1] but outputs imgs in [0, 1] - + # Compute counterfactual loss components log.info('Calculating counterfactual loss components') # In general case, we use counterfactual loss to provide us with @@ -147,20 +272,20 @@ def dae_step(self, step_idx, iter): batch_predictions = batch_components['predictions'] batch_probs = self.ce_loss.get_query_label_probability(batch_predictions) batch_logits = self.ce_loss.get_query_label_logit(batch_predictions) - + # Log pareto fronts log_wandb_scatter( - batch_logits.flatten().numpy(force = True), - batch_lpips.flatten().numpy(force = True), - 'logits from classifier', 'lpips', + batch_logits.flatten().numpy(force = True), + batch_lpips.flatten().numpy(force = True), + 'logits from classifier', 'lpips', f'dae/pareto/logits/iter:{iter}/{step_idx}') log_wandb_scatter( - batch_probs.flatten().numpy(force = True), - batch_lpips.flatten().numpy(force = 
True), - 'probabilities from classifier', 'lpips', + batch_probs.flatten().numpy(force = True), + batch_lpips.flatten().numpy(force = True), + 'probabilities from classifier', 'lpips', f'dae/pareto/probs/iter:{iter}/{step_idx}') - # Log info + # Log info ce_ids = batch_probs < 0.5 if self.ce_loss.components.clf_pred_label == 1 else batch_probs > 0.5 log.info(f'Number of CEs in synthetic images: {batch_probs[ce_ids].shape[0]}') log.info(f'Saving info file to {output_dir}') @@ -170,7 +295,7 @@ def dae_step(self, step_idx, iter): batch_lpips.numpy(force = True).flatten(), ce_ids.numpy(force = True)]).T df = pd.DataFrame( - df_data, + df_data, columns = ['norm', 'clf_prob', 'lpips', 'is_ce']) df.to_csv(output_dir / 'info.csv') @@ -197,7 +322,7 @@ def dae_step(self, step_idx, iter): torchvision.utils.save_image(diffs_ces_grid, output_dir / 'diffs_ces.png') # Save proxy training data - log.info('Saving data for proxy training') + log.info('Saving data for proxy training') if self.center_to_src_latent_sem: batch_latent_sem = batch_latent_sem - self.src_latent_sem self.save_proxy_training_data(batch_latent_sem, batch_components) @@ -211,7 +336,7 @@ def pre_proxy_loop(self, iter): def proxy_step(self, step_idx, iter): log.info(f'Step: {step_idx}') self.proxy.run_epoch(self.proxy_training_data, step_idx, iter) - + def post_proxy_loop(self, iter): # Take steps of different magnitude in negative gradient direction aka line search # Gradient is calculated for one or more (weight_cls, weight_lpips) pairs and @@ -220,12 +345,14 @@ def post_proxy_loop(self, iter): chunk_size = self.dae.batch_size // len(grads_dict) n_directions = len(self.weight_cls_lpips_pairs) + self.n_hessian_eigenvecs step_sizes_stack = torch.linspace( - *self.min_max_step_size, + *self.min_max_step_size, chunk_size, device = self.device).repeat(n_directions).unsqueeze(1) grads_stack = torch.stack([*grads_dict.values()]).repeat_interleave(chunk_size, 0).squeeze() batch_latent_sem = self.src_latent_sem - 
step_sizes_stack * grads_stack - batch_latent_ddim = self.dae.make_batch_latent_ddim(self.src_latent_ddim) + # Match DDIM batch to the exact number of latent_sem samples (may be < dae.batch_size) + n_ls = batch_latent_sem.shape[0] + batch_latent_ddim = self.src_latent_ddim.repeat(n_ls, 1, 1, 1) # Generate images from line search latent sems log.info('Rendering line search images') @@ -252,9 +379,9 @@ def post_proxy_loop(self, iter): weights_pairs += [(1.0, 0.0) for _ in range(chunk_size * self.n_hessian_eigenvecs)] self.proxy.line_search_validation( - batch_latent_sem = batch_latent_sem, - batch_components = batch_components, - query_label = self.ce_loss.components.classifier.query_label, + batch_latent_sem = batch_latent_sem, + batch_components = batch_components, + query_label = self.ce_loss.components.classifier.query_label, iter = iter, step_sizes = step_sizes_stack, weights_pairs = weights_pairs, @@ -263,7 +390,7 @@ def post_proxy_loop(self, iter): # Calculate counterfactual loss on line search images batch_losses = self.ce_loss( - batch_components, + batch_components, pos_label_idx = self.ce_loss.components.classifier.query_label) # Log info about line search @@ -279,14 +406,14 @@ def post_proxy_loop(self, iter): # Log pareto fronts log_wandb_scatter( - batch_logits.flatten().numpy(force = True), - batch_lpips.flatten().numpy(force = True), - 'logits from classifier', 'lpips', + batch_logits.flatten().numpy(force = True), + batch_lpips.flatten().numpy(force = True), + 'logits from classifier', 'lpips', f'proxy/line_search/pareto/logits/iter:{iter}') log_wandb_scatter( - batch_probs.flatten().numpy(force = True), - batch_lpips.flatten().numpy(force = True), - 'probabilities from classifier', 'lpips', + batch_probs.flatten().numpy(force = True), + batch_lpips.flatten().numpy(force = True), + 'probabilities from classifier', 'lpips', f'proxy/line_search/pareto/probs/iter:{iter}') # Log info and counterfactuals if any @@ -303,10 +430,10 @@ def 
post_proxy_loop(self, iter): batch_lpips.numpy(force = True).flatten(), ce_ids.numpy(force = True)]).T df = pd.DataFrame( - df_data, + df_data, columns = ['step_size', 'from_hess_vec', 'weight_cls', 'weight_lps', 'clf_prob', 'lpips', 'is_ce']) df.to_csv(output_dir / 'info.csv') - + src_img = self.get_src_img() diffs = (batch_imgs - src_img).abs() diffs_scaled = diffs / diffs.amax(dim = (1, 2, 3)).view(-1, 1, 1, 1) @@ -361,7 +488,7 @@ def get_grads_dict(self): log.info(f'Computing hessian for weight_cls: 1.0 and weight_lpips: 0.0') input.requires_grad_() func = lambda x: self.ce_loss( - self.proxy(x), + self.proxy(x), pos_label_idx = self.ce_loss.components.classifier.query_label, weight_cls = 1.0, weight_lpips = 0.0) @@ -401,10 +528,9 @@ def get_grads_dict(self): def save_proxy_training_data(self, batch_latent_sem, batch_components): self.proxy_training_data['latent_sem'].append(batch_latent_sem) self.proxy_training_data['components'].append(batch_components) - + def clear_proxy_training_data(self): self.proxy.increment_data_save_counter() self.proxy_training_data = { 'latent_sem': [], 'components': []} - \ No newline at end of file diff --git a/gcd/src/tools/copy_and_rename_pairs.py b/gcd/src/tools/copy_and_rename_pairs.py new file mode 100644 index 0000000..42fd8b2 --- /dev/null +++ b/gcd/src/tools/copy_and_rename_pairs.py @@ -0,0 +1,67 @@ +import argparse +import os +import re +import shutil +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser( + description="Copy direction_transfer outputs to a new folder with standardized names." 
+ ) + parser.add_argument("--src-dir", required=True, help="Path to direction_transfer subdir (contains original_*.png, inpaint_*.png)") + parser.add_argument("--dst-dir", required=True, help="Destination directory to write renamed copies") + parser.add_argument("--only-pairs", action="store_true", help="Copy only when both original_k.png and inpaint_k.png exist") + args = parser.parse_args() + + src = Path(args.src_dir) + dst = Path(args.dst_dir) + dst.mkdir(parents=True, exist_ok=True) + + # Patterns: original_1.png, inpaint_1.png + pat_orig = re.compile(r"^original_(\d+)\.png$") + pat_inpt = re.compile(r"^inpaint_(\d+)\.png$") + + indices_orig = {} + indices_inpt = {} + + for entry in os.listdir(src): + m1 = pat_orig.match(entry) + if m1: + idx = int(m1.group(1)) + indices_orig[idx] = src / entry + continue + m2 = pat_inpt.match(entry) + if m2: + idx = int(m2.group(1)) + indices_inpt[idx] = src / entry + + # Decide which indices to process + if args.only_pairs: + all_idx = sorted(set(indices_orig.keys()).intersection(indices_inpt.keys())) + else: + all_idx = sorted(set(indices_orig.keys()).union(indices_inpt.keys())) + + copied_images = 0 + copied_inpaints = 0 + + for i in all_idx: + z = str(i).zfill(5) + if i in indices_orig: + src_path = indices_orig[i] + tgt_path = dst / f"images_{z}.png" + shutil.copy2(src_path, tgt_path) + copied_images += 1 + if i in indices_inpt: + src_path = indices_inpt[i] + tgt_path = dst / f"inpaints_{z}.png" + shutil.copy2(src_path, tgt_path) + copied_inpaints += 1 + + print(f"Done. Wrote {copied_images} images_*.png and {copied_inpaints} inpaints_*.png into {dst}") + + +if __name__ == "__main__": + main() + +