From fd4d99894d84d1b29667d8e9efa4785753a07944 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 15 Nov 2025 19:43:00 +0100 Subject: [PATCH 01/24] update gcd repo --- .../imagenet/resnet/sorrel/config.yaml | 73 +++++++++++++++++ .../imagenet/resnet/zebra/config.yaml | 73 +++++++++++++++++ gcd/src/classifiers/__init__.py | 3 +- gcd/src/classifiers/imagenet.py | 78 +++++++++++++++++++ gcd/src/dae/chexpert/utils/make_config.py | 3 + gcd/src/dae/default/templates.py | 16 ++++ gcd/src/dae/default/utils/make_config.py | 3 + 7 files changed, 248 insertions(+), 1 deletion(-) create mode 100644 gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml create mode 100644 gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml create mode 100644 gcd/src/classifiers/imagenet.py diff --git a/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml new file mode 100644 index 0000000..d9d3015 --- /dev/null +++ b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml @@ -0,0 +1,73 @@ +mode: RUN +seed: 0 +device: cuda:0 +output_dir: + +strategy: + _target_: strategies.SingleImageGMCMLPProxyTraining + src_img_path: ??? # path to source sorrel image + n_iters: 20 + n_steps_dae: 8 + n_steps_proxy: 512 + min_max_step_size: [0.0005, 2.0] + device: ${device} + output_dir: ${output_dir}/strategy + center_to_src_latent_sem: true + line_search_weight_cls_list: [1.0, 0.05] + line_search_weight_lpips_list: [0, 1.0] + n_hessian_eigenvecs: 16 + dae_type: default + + dae_kwargs: + config_name: imagenet256_autoenc + batch_size: 256 + std: [0.01, 0.05, 0.1, 0.2] + distribution: uniform_norm + max_norm: 1.0 + T_encode: 250 + T_render: 50 + path_ckpt: ??? # path to your trained DiffAE checkpoint (.ckpt) + device: ${device} + output_dir: ${output_dir}/dae + + mc_mlp_proxy_kwargs: + shapes: [512, 256, 128, 64, 2] + activ_fn: sigmoid + batch_size: 128 + n_val_batches: 1 + val_loss_history_length: 5 + optimizer: + _partial_: true + _target_: torch.optim.AdamW + lr: 0.001 + loss_kwargs: + weight_cls: 1.0 + weight_lpips: 3.0 + device: ${device} + output_dir: ${output_dir}/proxy + + ce_loss_kwargs: + weight_cls: 1.0 + weight_lpips: 0 + device: ${device} + output_dir: ${output_dir}/loss + components: + _target_: losses.CounterfactualLossGeneralComponents + lpips_net: vgg + src_img_path: ${strategy.src_img_path} + device: ${device} + clf: + _target_: classifiers.ImageNetResNet + img_size: 256 + query_label: 339 # sorrel class id + use_softmax_and_query_label: false + task: multiclass_classification + model_name: resnet50 + weights: IMAGENET1K_V2 + path_to_weights: null + +wandb: + project: null + group: null + name: imagenet_sorrel + diff --git a/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml new file mode 100644 index 0000000..5507cf8 --- /dev/null +++ b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml @@ -0,0 +1,73 @@ +mode: RUN +seed: 0 +device: cuda:0 +output_dir: + +strategy: + _target_: strategies.SingleImageGMCMLPProxyTraining + src_img_path: ??? # path to source zebra image (.png or .jpg) + n_iters: 20 + n_steps_dae: 8 + n_steps_proxy: 512 + min_max_step_size: [0.0005, 2.0] + device: ${device} + output_dir: ${output_dir}/strategy + center_to_src_latent_sem: true + line_search_weight_cls_list: [1.0, 0.05] + line_search_weight_lpips_list: [0, 1.0] + n_hessian_eigenvecs: 16 + dae_type: default + + dae_kwargs: + config_name: imagenet256_autoenc + batch_size: 256 + std: [0.01, 0.05, 0.1, 0.2] + distribution: uniform_norm + max_norm: 1.0 + T_encode: 250 + T_render: 50 + path_ckpt: ??? # path to your trained DiffAE checkpoint (.ckpt) + device: ${device} + output_dir: ${output_dir}/dae + + mc_mlp_proxy_kwargs: + shapes: [512, 256, 128, 64, 2] + activ_fn: sigmoid + batch_size: 128 + n_val_batches: 1 + val_loss_history_length: 5 + optimizer: + _partial_: true + _target_: torch.optim.AdamW + lr: 0.001 + loss_kwargs: + weight_cls: 1.0 + weight_lpips: 3.0 + device: ${device} + output_dir: ${output_dir}/proxy + + ce_loss_kwargs: + weight_cls: 1.0 + weight_lpips: 0 + device: ${device} + output_dir: ${output_dir}/loss + components: + _target_: losses.CounterfactualLossGeneralComponents + lpips_net: vgg + src_img_path: ${strategy.src_img_path} + device: ${device} + clf: + _target_: classifiers.ImageNetResNet + img_size: 256 + query_label: 340 # zebra class id + use_softmax_and_query_label: false + task: multiclass_classification + model_name: resnet50 + weights: IMAGENET1K_V2 + path_to_weights: null # optional: path to custom weights + +wandb: + project: null + group: null + name: imagenet_zebra + diff --git a/gcd/src/classifiers/__init__.py b/gcd/src/classifiers/__init__.py index 408a9a4..a83850b 100644 --- a/gcd/src/classifiers/__init__.py +++ b/gcd/src/classifiers/__init__.py @@ -1,2 +1,3 @@ from .dense_net import DenseNet -from .dense_net_chexpert import DenseNetCheXpert \ No newline at end of file +from .dense_net_chexpert import DenseNetCheXpert +from .imagenet import ImageNetResNet \ No newline at end of file diff --git a/gcd/src/classifiers/imagenet.py b/gcd/src/classifiers/imagenet.py new file mode 100644 index 0000000..39c47cf --- /dev/null +++ b/gcd/src/classifiers/imagenet.py @@ -0,0 +1,78 @@ +import torch +import torch.nn.functional as F +import torchvision +import torchvision.transforms as tt + +import logging +logging.basicConfig(level = logging.INFO) +log = logging.getLogger(__name__) + +IMAGENET_MEAN = (0.485, 0.456, 0.406) +IMAGENET_STD = (0.229, 0.224, 0.225) + +class ImageNetResNet(torch.nn.Module): + def __init__( + self, + img_size: int = 256, + query_label: int = 340, + use_softmax_and_query_label: bool = False, + task: str = 'multiclass_classification', + model_name: str = 'resnet50', + weights: str = 'IMAGENET1K_V2', + path_to_weights: str = None): + super().__init__() + self.img_size = img_size + self.task = task + self.query_label = int(query_label) + self.use_softmax_and_query_label = use_softmax_and_query_label + self.transforms = tt.Compose([ + tt.Normalize(mean = IMAGENET_MEAN, std = IMAGENET_STD) + ]) + self.model = self._build_model(model_name, weights, path_to_weights) + self.model.eval() + + def _build_model(self, model_name, weights, path_to_weights): + log.info(f'Building torchvision model: {model_name}') + if hasattr(torchvision.models, model_name): + ctor = getattr(torchvision.models, model_name) + else: + raise ValueError(f'Unknown model_name: {model_name}') + if path_to_weights is None: + # Load torchvision pretrained weights + try: + w_enum = getattr(torchvision.models, f'{model_name.upper()}_Weights') + w = getattr(w_enum, weights) + model = ctor(weights = w) + except Exception: + log.warning('Falling back to default pretrained weights') + model = ctor(weights = 'IMAGENET1K_V2') + else: + model = ctor(weights = None) + log.info(f'Loading classifier weights from: {path_to_weights}') + ckpt = torch.load(path_to_weights, map_location = 'cpu') + state_dict = ckpt.get('state_dict', ckpt) + # if state_dict keys are prefixed (e.g., 'model.'), try to strip common prefixes + if any(k.startswith('model.') for k in state_dict.keys()): + state_dict = {k.split('model.', 1)[1]: v for k, v in state_dict.items() if k.startswith('model.')} + missing, unexpected = model.load_state_dict(state_dict, strict = False) + if missing: + log.warning(f'Missing keys when loading classifier: {len(missing)}') + if unexpected: + log.warning(f'Unexpected keys when loading classifier: {len(unexpected)}') + return model + + @torch.no_grad() + def forward(self, x): + assert x.shape[-1] == self.img_size, 'Wrong input shape' + # Expect [0,1] range; rescale if in [-1,1] + if x.min() < 0: + assert x.min() >= -1. and x.max() <= 1. + log.info('Detected input outside [0,1], rescaling from [-1,1] to [0,1]') + x = (x + 1) / 2 + x = self.transforms(x) + logits = self.model(x) + if self.use_softmax_and_query_label: + probs = F.softmax(logits, dim = 1) + return probs[:, self.query_label] + return logits + diff --git a/gcd/src/dae/chexpert/utils/make_config.py b/gcd/src/dae/chexpert/utils/make_config.py index 1fda1d4..b67a41c 100644 --- a/gcd/src/dae/chexpert/utils/make_config.py +++ b/gcd/src/dae/chexpert/utils/make_config.py @@ -8,5 +8,8 @@ def make_config(config_name): if config_name == 'chexpert224_autoenc': log.info(f'Using config: {config_name}') return chexpert224_autoenc() + elif config_name == 'imagenet256_autoenc': + log.info(f'Using config: {config_name}') + return imagenet256_autoenc() else: raise NotImplementedError('Invalid config name or option not implemented yet') \ No newline at end of file diff --git a/gcd/src/dae/default/templates.py b/gcd/src/dae/default/templates.py index 5066961..1f13a33 100644 --- a/gcd/src/dae/default/templates.py +++ b/gcd/src/dae/default/templates.py @@ -66,4 +66,20 @@ def celeba128_autoenc(): conf.eval_every_samples = 10_000_000 conf.data_name = 'celeba128' conf.name = 'celeba128_autoenc' + return conf + +def imagenet256_autoenc(): + conf = ffhq256_autoenc() + conf.img_size = 256 + conf.net_ch = 128 + conf.net_ch_mult = (1, 1, 2, 2, 4, 4) + conf.net_enc_channel_mult = (1, 1, 2, 2, 4, 4, 4) + # Large defaults are OK for inference config stub; not used for training here + conf.eval_every_samples = 10_000_000 + conf.eval_ema_every_samples = 10_000_000 + conf.total_samples = 500_000_000 + conf.batch_size = 64 + conf.make_model_conf() + conf.name = 'imagenet256_autoenc' + conf.data_name = 'imagenet' return conf \ No newline at end of file diff --git a/gcd/src/dae/default/utils/make_config.py b/gcd/src/dae/default/utils/make_config.py index 10205ff..5cf4ab9 100644 --- a/gcd/src/dae/default/utils/make_config.py +++ b/gcd/src/dae/default/utils/make_config.py @@ -11,5 +11,8 @@ def make_config(config_name): elif config_name == 'celeba128_autoenc': log.info(f'Using config: {config_name}') return celeba128_autoenc() + elif config_name == 'imagenet256_autoenc': + log.info(f'Using config: {config_name}') + return imagenet256_autoenc() else: raise NotImplementedError('Invalid config name or option not implemented yet') \ No newline at end of file From 39e064e6269d660864f951dc751a72222771b983 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 15 Nov 2025 19:49:38 +0100 Subject: [PATCH 02/24] add commands --- commands.md | 73 +++++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 73 insertions(+) create mode 100644 commands.md diff --git a/commands.md b/commands.md new file mode 100644 index 0000000..022e580 --- /dev/null +++ b/commands.md @@ -0,0 +1,73 @@ +### GCD ImageNet (zebra ↔ sorrel) — quick commands + +Prereqs +- You already trained a DiffAE autoencoder at 256×256 and have its checkpoint path (last.ckpt). +- You have a 256×256 source image (.png/.jpg) for zebra or sorrel. +- Optional: a custom ImageNet classifier checkpoint (otherwise torchvision ResNet-50 pretrained is used). + +Setup (one-time) +```bash +cd /Users/mgalkowski/Desktop/diffae/gcd +python -m pip install -r requirements.txt +``` + +Set variables +```bash +# Path to your DiffAE autoencoder EMA checkpoint (last.ckpt) +export DAE_CKPT="/absolute/path/to/checkpoints//last.ckpt" + +# Source image for the run (256×256 PNG/JPG) +export SRC_IMG="/absolute/path/to/source_image.png" +``` + +Run: zebra (class 340, zebra→not‑zebra direction) +```bash +cd /Users/mgalkowski/Desktop/diffae/gcd/gcd +python src/main.py \ + --config-path ../configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra \ + --config-name config \ + strategy.src_img_path="$SRC_IMG" \ + strategy.dae_kwargs.path_ckpt="$DAE_CKPT" \ + device=cuda:0 \ + wandb.project=null +``` + +Run: sorrel (class 339, sorrel→not‑sorrel direction) +```bash +cd /Users/mgalkowski/Desktop/diffae/gcd/gcd +python src/main.py \ + --config-path ../configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel \ + --config-name config \ + strategy.src_img_path="$SRC_IMG" \ + strategy.dae_kwargs.path_ckpt="$DAE_CKPT" \ + device=cuda:0 \ + wandb.project=null +``` + +Notes +- The configs use the default 3‑channel DiffAE wrapper and a ResNet‑50 ImageNet classifier. +- To use a custom classifier, add at the end of the command: + - `strategy.ce_loss_kwargs.components.clf.path_to_weights="/path/to/classifier.ckpt"` +- Outputs are written under `gcd/gcd/outputs//strategy/`. Proxy data and candidate directions are in `strategy/proxy/...`. + +Direction transfer (apply a saved direction to another image) +```bash +# Example placeholders — update paths after a run completes +export LOG_DIR="/Users/mgalkowski/Desktop/diffae/gcd/gcd/outputs//strategy" +export DIR_PATH="$LOG_DIR/proxy//.pt" # e.g., *_grad_*.pt +export TARGET_IMG="/absolute/path/to/another_image.png" + +python src/direction_transfer.py \ + --log-dir-path "$LOG_DIR" \ + --direction-path "$DIR_PATH" \ + --img-path "$TARGET_IMG" \ + --subdir-name "transfer_run_1" +``` + +Tips for memory/time +- If you see OOM, reduce: `strategy.dae_kwargs.batch_size` (e.g., 128) or narrow `strategy.dae_kwargs.std`. +- To speed up renders, you can lower `strategy.dae_kwargs.T_render` (e.g., 20–50). + +Switching class direction +- Zebra run uses `query_label=340`; sorrel run uses `query_label=339`. Use the matching config. + From ada3babecae4c48af563fd19d7a09b8534bbac97 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 15 Nov 2025 20:21:18 +0100 Subject: [PATCH 03/24] update commands + add posibility to use hf datasets --- commands.md | 17 +++ .../single_image_gmc_mlp_proxy_training.py | 134 +++++++++++++----- 2 files changed, 117 insertions(+), 34 deletions(-) diff --git a/commands.md b/commands.md index 022e580..5f698e1 100644 --- a/commands.md +++ b/commands.md @@ -32,6 +32,23 @@ python src/main.py \ wandb.project=null ``` +Zebra example (label 340): +``` +python src/main.py \ + --config-path ../configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra \ + --config-name config \ + strategy.src_img_path=null \ + strategy.hf_dataset_name=imagenet-1k \ + strategy.hf_split=train \ + strategy.hf_label=340 \ + strategy.hf_index=0 \ + strategy.hf_token=$HF_TOKEN \ # optional + strategy.hf_cache_dir=/path/to/hf_cache \ # optional + strategy.dae_kwargs.path_ckpt="$DAE_CKPT" \ + device=cuda:0 \ + wandb.project=null +``` + Run: sorrel (class 339, sorrel→not‑sorrel direction) ```bash cd /Users/mgalkowski/Desktop/diffae/gcd/gcd diff --git a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py index 442f173..d845c8f 100644 --- a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py +++ b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py @@ -11,6 +11,9 @@ from losses import CounterfactualLossFromGeneralComponents from strategies import Strategy from utils import * +from datasets import load_dataset +import torchvision.transforms as T +from PIL import Image import logging logging.basicConfig(level = logging.INFO) @@ -32,7 +35,7 @@ class SingleImageGMCMLPProxyTraining(Strategy): def __init__( self, - src_img_path: str, + src_img_path: str = None, n_iters: int, n_steps_dae: int, n_steps_proxy: int, @@ -43,10 +46,17 @@ def __init__( n_hessian_eigenvecs: int, device: str, output_dir: str, - mc_mlp_proxy_kwargs: dict, + mc_mlp_proxy_kwargs: dict, dae_kwargs: dict, ce_loss_kwargs: dict, - dae_type: str = 'default'): + dae_type: str = 'default', + # Optional HF dataset source for picking the source image by label + hf_dataset_name: str = None, # e.g., 'imagenet-1k' + hf_split: str = 'train', + hf_label: int = None, # required if using HF + hf_index: int = 0, # which occurrence of that label to use + hf_token: str = None, + hf_cache_dir: str = None): """ src_img_path - path of the source image n_iters - number of outer loop iterations @@ -56,7 +66,7 @@ def __init__( n_steps_proxy - number of epochs for proxy training in each outer iteration Note: the size of the dataset for proxy training is n_steps_dae * self.dae.batch_size * n_iters min_max_step_size - coefficients that determine the min and max size of the gradient step, - intermediate steps are distributed uniformly between these values and total number of + intermediate steps are distributed uniformly between these values and total number of step sizes is equal to self.dae.batch_size center_to_src_latent_sem - whether to center all latent sems at src_latent_sem n_hessian_eigenvecs - number of hessian eigenvectors used in line search @@ -74,12 +84,26 @@ def __init__( self.device_cpu = torch.device('cpu') self.output_dir = Path(output_dir) self.output_dir.mkdir(parents = True) - + self.dae = self.get_dae_class(dae_type)(**dae_kwargs) self.proxy = GeneralMulticomponentMLP(**mc_mlp_proxy_kwargs) self.ce_loss = CounterfactualLossFromGeneralComponents(**ce_loss_kwargs) - self.load_src_img(src_img_path) + # Load source image either from local path or HF dataset by label + if src_img_path is not None: + self.load_src_img(src_img_path) + elif hf_dataset_name is not None and hf_label is not None: + self.load_src_img_from_hf( + dataset_name = hf_dataset_name, + split = hf_split, + label = hf_label, + index = hf_index, + hf_token = hf_token, + cache_dir = hf_cache_dir, + image_size = self.dae.config.img_size if hasattr(self.dae, 'config') else 256 + ) + else: + raise ValueError("Provide either src_img_path or (hf_dataset_name and hf_label).") self.clear_proxy_training_data() def get_dae_class(self, dae_type): @@ -96,6 +120,49 @@ def load_src_img(self, path): src_img = (src_img - 0.5) * 2 self.src_img = src_img.to(self.device) + def load_src_img_from_hf( + self, + dataset_name: str, + split: str, + label: int, + index: int = 0, + hf_token: str = None, + cache_dir: str = None, + image_size: int = 256, + ): + """ + Load one source image from a HF dataset (e.g., 'imagenet-1k') by label id. + """ + log.info(f'Loading source image from HF dataset: {dataset_name}, split: {split}, label: {label}, index: {index}') + kwargs = {} + if hf_token is not None: + kwargs['token'] = hf_token + if cache_dir is not None: + kwargs['cache_dir'] = cache_dir + ds = load_dataset(dataset_name, split=split, **kwargs) + # Filter by label; select occurrence by index + ds_label = ds.filter(lambda ex: ex.get('label', None) == label) + if len(ds_label) == 0: + raise ValueError(f"No samples found for label {label} in dataset {dataset_name}/{split}") + if index >= len(ds_label) or index < 0: + raise IndexError(f"index {index} out of bounds for filtered dataset of size {len(ds_label)}") + sample = ds_label[index] + img = sample['image'] + if not isinstance(img, Image.Image): + img = Image.fromarray(img) + # Resize/center-crop to target image_size + transform = T.Compose([ + T.CenterCrop(image_size), + T.Resize(image_size), + T.ToTensor(), # [0,1] + ]) + src_img = transform(img).unsqueeze(0) + # Save [0,1] preview + torchvision.utils.save_image(src_img, self.output_dir / 'src_img.png') + # Scale to [-1,1] for DAE + src_img = (src_img - 0.5) * 2 + self.src_img = src_img.to(self.device) + def get_src_img(self, scale = True): if scale: return (self.src_img + 1) / 2 @@ -134,7 +201,7 @@ def dae_step(self, step_idx, iter): log.info(f'Saving synthetic images to {output_dir}') torchvision.utils.save_image(batch_imgs, output_dir / 'imgs.png') # NOTE: DAE requires input to be in [-1, 1] but outputs imgs in [0, 1] - + # Compute counterfactual loss components log.info('Calculating counterfactual loss components') # In general case, we use counterfactual loss to provide us with @@ -147,20 +214,20 @@ def dae_step(self, step_idx, iter): batch_predictions = batch_components['predictions'] batch_probs = self.ce_loss.get_query_label_probability(batch_predictions) batch_logits = self.ce_loss.get_query_label_logit(batch_predictions) - + # Log pareto fronts log_wandb_scatter( - batch_logits.flatten().numpy(force = True), - batch_lpips.flatten().numpy(force = True), - 'logits from classifier', 'lpips', + batch_logits.flatten().numpy(force = True), + batch_lpips.flatten().numpy(force = True), + 'logits from classifier', 'lpips', f'dae/pareto/logits/iter:{iter}/{step_idx}') log_wandb_scatter( - batch_probs.flatten().numpy(force = True), - batch_lpips.flatten().numpy(force = True), - 'probabilities from classifier', 'lpips', + batch_probs.flatten().numpy(force = True), + batch_lpips.flatten().numpy(force = True), + 'probabilities from classifier', 'lpips', f'dae/pareto/probs/iter:{iter}/{step_idx}') - # Log info + # Log info ce_ids = batch_probs < 0.5 if self.ce_loss.components.clf_pred_label == 1 else batch_probs > 0.5 log.info(f'Number of CEs in synthetic images: {batch_probs[ce_ids].shape[0]}') log.info(f'Saving info file to {output_dir}') @@ -170,7 +237,7 @@ def dae_step(self, step_idx, iter): batch_lpips.numpy(force = True).flatten(), ce_ids.numpy(force = True)]).T df = pd.DataFrame( - df_data, + df_data, columns = ['norm', 'clf_prob', 'lpips', 'is_ce']) df.to_csv(output_dir / 'info.csv') @@ -197,7 +264,7 @@ def dae_step(self, step_idx, iter): torchvision.utils.save_image(diffs_ces_grid, output_dir / 'diffs_ces.png') # Save proxy training data - log.info('Saving data for proxy training') + log.info('Saving data for proxy training') if self.center_to_src_latent_sem: batch_latent_sem = batch_latent_sem - self.src_latent_sem self.save_proxy_training_data(batch_latent_sem, batch_components) @@ -211,7 +278,7 @@ def pre_proxy_loop(self, iter): def proxy_step(self, step_idx, iter): log.info(f'Step: {step_idx}') self.proxy.run_epoch(self.proxy_training_data, step_idx, iter) - + def post_proxy_loop(self, iter): # Take steps of different magnitude in negative gradient direction aka line search # Gradient is calculated for one or more (weight_cls, weight_lpips) pairs and @@ -220,7 +287,7 @@ def post_proxy_loop(self, iter): chunk_size = self.dae.batch_size // len(grads_dict) n_directions = len(self.weight_cls_lpips_pairs) + self.n_hessian_eigenvecs step_sizes_stack = torch.linspace( - *self.min_max_step_size, + *self.min_max_step_size, chunk_size, device = self.device).repeat(n_directions).unsqueeze(1) grads_stack = torch.stack([*grads_dict.values()]).repeat_interleave(chunk_size, 0).squeeze() @@ -252,9 +319,9 @@ def post_proxy_loop(self, iter): weights_pairs += [(1.0, 0.0) for _ in range(chunk_size * self.n_hessian_eigenvecs)] self.proxy.line_search_validation( - batch_latent_sem = batch_latent_sem, - batch_components = batch_components, - query_label = self.ce_loss.components.classifier.query_label, + batch_latent_sem = batch_latent_sem, + batch_components = batch_components, + query_label = self.ce_loss.components.classifier.query_label, iter = iter, step_sizes = step_sizes_stack, weights_pairs = weights_pairs, @@ -263,7 +330,7 @@ def post_proxy_loop(self, iter): # Calculate counterfactual loss on line search images batch_losses = self.ce_loss( - batch_components, + batch_components, pos_label_idx = self.ce_loss.components.classifier.query_label) # Log info about line search @@ -279,14 +346,14 @@ def post_proxy_loop(self, iter): # Log pareto fronts log_wandb_scatter( - batch_logits.flatten().numpy(force = True), - batch_lpips.flatten().numpy(force = True), - 'logits from classifier', 'lpips', + batch_logits.flatten().numpy(force = True), + batch_lpips.flatten().numpy(force = True), + 'logits from classifier', 'lpips', f'proxy/line_search/pareto/logits/iter:{iter}') log_wandb_scatter( - batch_probs.flatten().numpy(force = True), - batch_lpips.flatten().numpy(force = True), - 'probabilities from classifier', 'lpips', + batch_probs.flatten().numpy(force = True), + batch_lpips.flatten().numpy(force = True), + 'probabilities from classifier', 'lpips', f'proxy/line_search/pareto/probs/iter:{iter}') # Log info and counterfactuals if any @@ -303,10 +370,10 @@ def post_proxy_loop(self, iter): batch_lpips.numpy(force = True).flatten(), ce_ids.numpy(force = True)]).T df = pd.DataFrame( - df_data, + df_data, columns = ['step_size', 'from_hess_vec', 'weight_cls', 'weight_lps', 'clf_prob', 'lpips', 'is_ce']) df.to_csv(output_dir / 'info.csv') - + src_img = self.get_src_img() diffs = (batch_imgs - src_img).abs() diffs_scaled = diffs / diffs.amax(dim = (1, 2, 3)).view(-1, 1, 1, 1) @@ -361,7 +428,7 @@ def get_grads_dict(self): log.info(f'Computing hessian for weight_cls: 1.0 and weight_lpips: 0.0') input.requires_grad_() func = lambda x: self.ce_loss( - self.proxy(x), + self.proxy(x), pos_label_idx = self.ce_loss.components.classifier.query_label, weight_cls = 1.0, weight_lpips = 0.0) @@ -401,10 +468,9 @@ def get_grads_dict(self): def save_proxy_training_data(self, batch_latent_sem, batch_components): self.proxy_training_data['latent_sem'].append(batch_latent_sem) self.proxy_training_data['components'].append(batch_components) - + def clear_proxy_training_data(self): self.proxy.increment_data_save_counter() self.proxy_training_data = { 'latent_sem': [], 'components': []} - \ No newline at end of file From adc7dec4c363a4257118bbf0d974a4704d9f729c Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 15 Nov 2025 20:36:51 +0100 Subject: [PATCH 04/24] add posibility to use hf datasets v2 --- .../counterfactual_loss_general_components.py | 85 +++++++++++++++++-- 1 file changed, 78 insertions(+), 7 deletions(-) diff --git a/gcd/src/losses/counterfactual_loss_general_components.py b/gcd/src/losses/counterfactual_loss_general_components.py index 969d20f..a4057ea 100644 --- a/gcd/src/losses/counterfactual_loss_general_components.py +++ b/gcd/src/losses/counterfactual_loss_general_components.py @@ -3,6 +3,10 @@ import torch.nn as nn import torch.nn.functional as F import torchvision +import torchvision.transforms as T +from datasets import load_dataset +from PIL import Image +import os import logging log = logging.getLogger(__name__) @@ -15,11 +19,19 @@ class CounterfactualLossGeneralComponents(nn.Module): """ def __init__( - self, + self, clf, lpips_net: str, - src_img_path: str, - device: str): + src_img_path: str = None, + device: str = 'cuda:0', + # Optional: load source image from HF dataset by label if src_img_path is None + hf_dataset_name: str = None, # e.g., 'imagenet-1k' + hf_split: str = 'train', + hf_label: int = None, + hf_index: int = 0, + hf_token: str = None, + cache_dir: str = None, + image_size: int = 256): """ label_idx - label of interest (for counterfactual explanation) """ @@ -28,7 +40,29 @@ def __init__( self.to(self.device) self.classifier = clf.to(self.device) self.lpips = lpips.LPIPS(net = lpips_net).to(self.device) - self.load_src_img(src_img_path, device) + # Load source image: prefer local path, otherwise try HF dataset + if src_img_path is not None: + self.load_src_img(src_img_path, device) + else: + # Fallback to HF if params provided or env vars present + ds_name = hf_dataset_name or os.getenv('HF_DATASET_NAME', None) + ds_split = hf_split or os.getenv('HF_SPLIT', 'train') + ds_label = hf_label if hf_label is not None else os.getenv('HF_LABEL', None) + ds_index = hf_index if hf_index is not None else int(os.getenv('HF_INDEX', 0)) + ds_token = hf_token or os.getenv('HF_TOKEN', None) + ds_cache = cache_dir or os.getenv('HF_CACHE_DIR', None) + if ds_name is None or ds_label is None: + raise ValueError("CounterfactualLossGeneralComponents: Provide src_img_path or HF dataset parameters (hf_dataset_name & hf_label).") + ds_label = int(ds_label) + self.load_src_img_from_hf( + dataset_name = ds_name, + split = ds_split, + label = ds_label, + index = ds_index, + hf_token = ds_token, + cache_dir = ds_cache, + image_size = image_size, + ) self.clf_pred_label = self.get_clf_pred_label() def forward(self, x): @@ -71,7 +105,7 @@ def get_clf_pred_label(self, img = None): else: output_probs = F.softmax(output, dim = 1) output_prob = output_probs[:, self.classifier.query_label] - + if save_src_img_preds: self.src_img_probs = output_probs self.src_img_logits = output @@ -90,17 +124,54 @@ def get_clf_pred_label(self, img = None): if save_src_img_preds: self.src_img_probs = output_probs self.src_img_logits = output - + label = 1 if output_prob > 0.5 else 0 log.info(f"Class predicted for source image: {label}") log.info(f"Probability of positive class: {output_prob.item()}") return label - + def load_src_img(self, path, device): src_img = torchvision.io.read_image(path).unsqueeze(0) / 255 src_img = (src_img - 0.5) * 2 self.src_img = src_img.to(self.device) + @torch.no_grad() + def load_src_img_from_hf( + self, + dataset_name: str, + split: str, + label: int, + index: int = 0, + hf_token: str = None, + cache_dir: str = None, + image_size: int = 256, + ): + log.info(f'Loading source image from HF dataset: {dataset_name}, split: {split}, label: {label}, index: {index}') + kwargs = {} + if hf_token is not None: + kwargs['token'] = hf_token + if cache_dir is not None: + kwargs['cache_dir'] = cache_dir + ds = load_dataset(dataset_name, split=split, **kwargs) + ds_label = ds.filter(lambda ex: ex.get('label', None) == label) + if len(ds_label) == 0: + raise ValueError(f"No samples found for label {label} in dataset {dataset_name}/{split}") + if index >= len(ds_label) or index < 0: + raise IndexError(f"index {index} out of bounds for filtered dataset of size {len(ds_label)}") + sample = ds_label[index] + img = sample['image'] + if not isinstance(img, Image.Image): + img = Image.fromarray(img) + transform = T.Compose([ + T.CenterCrop(image_size), + T.Resize(image_size), + T.ToTensor(), # [0,1] + ]) + src_img = transform(img).unsqueeze(0) + # Scale to [-1,1] for loss computations (classifier handles rescale internally if needed) + src_img = (src_img - 0.5) * 2 + self.src_img = src_img.to(self.device) + From 3f6d99c879fb1aaae71b409ec78b8dfef32e0c53 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 15 Nov 2025 20:55:31 +0100 Subject: [PATCH 05/24] update files --- commands.md | 9 ++++----- .../losses/counterfactual_loss_general_components.py | 8 +++----- gcd/src/main.py | 3 +++ .../strategies/single_image_gmc_mlp_proxy_training.py | 11 +++++------ 4 files changed, 15 insertions(+), 16 deletions(-) diff --git a/commands.md b/commands.md index 5f698e1..bd2c86f 100644 --- a/commands.md +++ b/commands.md @@ -8,7 +8,8 @@ Prereqs Setup (one-time) ```bash cd /Users/mgalkowski/Desktop/diffae/gcd -python -m pip install -r requirements.txt +conda create --name gcd python=3.11 -y +pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 ``` Set variables @@ -33,7 +34,8 @@ python src/main.py \ ``` Zebra example (label 340): -``` + +```bash python src/main.py \ --config-path ../configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra \ --config-name config \ @@ -42,9 +44,6 @@ python src/main.py \ strategy.hf_split=train \ strategy.hf_label=340 \ strategy.hf_index=0 \ - strategy.hf_token=$HF_TOKEN \ # optional - strategy.hf_cache_dir=/path/to/hf_cache \ # optional - strategy.dae_kwargs.path_ckpt="$DAE_CKPT" \ device=cuda:0 \ wandb.project=null ``` diff --git a/gcd/src/losses/counterfactual_loss_general_components.py b/gcd/src/losses/counterfactual_loss_general_components.py index a4057ea..fa616ae 100644 --- a/gcd/src/losses/counterfactual_loss_general_components.py +++ b/gcd/src/losses/counterfactual_loss_general_components.py @@ -29,8 +29,6 @@ def __init__( hf_split: str = 'train', hf_label: int = None, hf_index: int = 0, - hf_token: str = None, - cache_dir: str = None, image_size: int = 256): """ label_idx - label of interest (for counterfactual explanation) @@ -46,11 +44,11 @@ def __init__( else: # Fallback to HF if params provided or env vars present ds_name = hf_dataset_name or os.getenv('HF_DATASET_NAME', None) - ds_split = hf_split or os.getenv('HF_SPLIT', 'train') + ds_split = hf_split ds_label = hf_label if hf_label is not None else os.getenv('HF_LABEL', None) ds_index = hf_index if hf_index is not None else int(os.getenv('HF_INDEX', 0)) - ds_token = hf_token or os.getenv('HF_TOKEN', None) - ds_cache = cache_dir or os.getenv('HF_CACHE_DIR', None) + ds_token = os.getenv('HF_TOKEN', None) + ds_cache = os.getenv('DATASET_CACHE', None) if ds_name is None or ds_label is None: raise ValueError("CounterfactualLossGeneralComponents: Provide src_img_path or HF dataset parameters (hf_dataset_name & hf_label).") ds_label = int(ds_label) diff --git a/gcd/src/main.py b/gcd/src/main.py index a5b74a4..678c513 100644 --- a/gcd/src/main.py +++ b/gcd/src/main.py @@ -5,6 +5,9 @@ from omegaconf import DictConfig from utils import extract_output_dir_path_from_config, set_seed +from dotenv import load_dotenv + +load_dotenv() log = logging.getLogger(__name__) diff --git a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py index d845c8f..3efcec3 100644 --- a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py +++ b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py @@ -35,7 +35,6 @@ class SingleImageGMCMLPProxyTraining(Strategy): def __init__( self, - src_img_path: str = None, n_iters: int, n_steps_dae: int, n_steps_proxy: int, @@ -49,14 +48,13 @@ def __init__( mc_mlp_proxy_kwargs: dict, dae_kwargs: dict, ce_loss_kwargs: dict, + src_img_path: str = None, dae_type: str = 'default', # Optional HF dataset source for picking the source image by label hf_dataset_name: str = None, # e.g., 'imagenet-1k' hf_split: str = 'train', hf_label: int = None, # required if using HF - hf_index: int = 0, # which occurrence of that label to use - hf_token: str = None, - hf_cache_dir: str = None): + hf_index: int = 0): # which occurrence of that label to use """ src_img_path - path of the source image n_iters - number of outer loop iterations @@ -87,6 +85,7 @@ def __init__( self.dae = self.get_dae_class(dae_type)(**dae_kwargs) self.proxy = GeneralMulticomponentMLP(**mc_mlp_proxy_kwargs) + print('\n\nInitializing CounterfactualLossGeneralComponents from kwargs: ', components_kwargs, '\n\n') self.ce_loss = CounterfactualLossFromGeneralComponents(**ce_loss_kwargs) # Load source image either from local path or HF dataset by label @@ -98,8 +97,8 @@ def __init__( split = hf_split, label = hf_label, index = hf_index, - hf_token = hf_token, - cache_dir = hf_cache_dir, + hf_token = os.getenv('HF_TOKEN', None), + cache_dir = os.getenv('DATASET_CACHE', None), image_size = self.dae.config.img_size if hasattr(self.dae, 'config') else 256 ) else: From f617baf9671eb2101db9048228fe0b9db79ef2f1 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 15 Nov 2025 20:57:43 +0100 Subject: [PATCH 06/24] speed up filtering --- .../counterfactual_loss_general_components.py | 26 +++++++++++------- .../single_image_gmc_mlp_proxy_training.py | 27 +++++++++++-------- 2 files changed, 32 insertions(+), 21 deletions(-) diff --git a/gcd/src/losses/counterfactual_loss_general_components.py b/gcd/src/losses/counterfactual_loss_general_components.py index fa616ae..d756a57 100644 --- a/gcd/src/losses/counterfactual_loss_general_components.py +++ b/gcd/src/losses/counterfactual_loss_general_components.py @@ -145,18 +145,24 @@ def load_src_img_from_hf( image_size: int = 256, ): log.info(f'Loading source image from HF dataset: {dataset_name}, split: {split}, label: {label}, index: {index}') - kwargs = {} + # Use streaming to avoid slow filter materialization + stream_kwargs = {"streaming": True} if hf_token is not None: - kwargs['token'] = hf_token + stream_kwargs["token"] = hf_token if cache_dir is not None: - kwargs['cache_dir'] = cache_dir - ds = load_dataset(dataset_name, split=split, **kwargs) - ds_label = ds.filter(lambda ex: ex.get('label', None) == label) - if len(ds_label) == 0: - raise ValueError(f"No samples found for label {label} in dataset {dataset_name}/{split}") - if index >= len(ds_label) or index < 0: - raise IndexError(f"index {index} out of bounds for filtered dataset of size {len(ds_label)}") - sample = ds_label[index] + stream_kwargs["cache_dir"] = cache_dir + ds_stream = load_dataset(dataset_name, split=split, **stream_kwargs) + + found = -1 + sample = None + for ex in ds_stream: + if ex.get("label", None) == label: + found += 1 + if found == index: + sample = ex + break + if sample is None: + raise ValueError(f"Could not find occurrence index {index} for label {label} in {dataset_name}/{split}") img = sample['image'] if not isinstance(img, Image.Image): img = Image.fromarray(img) diff --git a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py index 3efcec3..b61c418 100644 --- a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py +++ b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py @@ -133,19 +133,24 @@ def load_src_img_from_hf( Load one source image from a HF dataset (e.g., 'imagenet-1k') by label id. """ log.info(f'Loading source image from HF dataset: {dataset_name}, split: {split}, label: {label}, index: {index}') - kwargs = {} + # Use streaming iteration to avoid materializing a full filtered copy (much faster) + stream_kwargs = {"streaming": True} if hf_token is not None: - kwargs['token'] = hf_token + stream_kwargs["token"] = hf_token if cache_dir is not None: - kwargs['cache_dir'] = cache_dir - ds = load_dataset(dataset_name, split=split, **kwargs) - # Filter by label; select occurrence by index - ds_label = ds.filter(lambda ex: ex.get('label', None) == label) - if len(ds_label) == 0: - raise ValueError(f"No samples found for label {label} in dataset {dataset_name}/{split}") - if index >= len(ds_label) or index < 0: - raise IndexError(f"index {index} out of bounds for filtered dataset of size {len(ds_label)}") - sample = ds_label[index] + stream_kwargs["cache_dir"] = cache_dir + ds_stream = load_dataset(dataset_name, split=split, **stream_kwargs) + + found = -1 + sample = None + for ex in ds_stream: + if ex.get("label", None) == label: + found += 1 + if found == index: + sample = ex + break + if sample is None: + raise ValueError(f"Could not find occurrence index {index} for label {label} in {dataset_name}/{split}") img = sample['image'] if not isinstance(img, Image.Image): img = Image.fromarray(img) From e1cede79840b267cdb911e092f48a3a24efff354 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 15 Nov 2025 21:11:31 +0100 Subject: [PATCH 07/24] add debugging options --- gcd/src/classifiers/imagenet.py | 11 ++++++++++- .../losses/counterfactual_loss_general_components.py | 11 +++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/gcd/src/classifiers/imagenet.py b/gcd/src/classifiers/imagenet.py index 39c47cf..f99ff0e 100644 --- a/gcd/src/classifiers/imagenet.py +++ b/gcd/src/classifiers/imagenet.py @@ -19,7 +19,8 @@ def __init__( task: str = 'multiclass_classification', model_name: str = 'resnet50', weights: str = 'IMAGENET1K_V2', - path_to_weights: str = None): + path_to_weights: str = None, + debug_save_path: str = None): super().__init__() self.img_size = img_size self.task = task @@ -28,6 +29,7 @@ def __init__( self.transforms = tt.Compose([ tt.Normalize(mean = IMAGENET_MEAN, std = IMAGENET_STD) ]) + self.debug_save_path = debug_save_path self.model = self._build_model(model_name, weights, path_to_weights) self.model.eval() @@ -69,6 +71,13 @@ def forward(self, x): assert x.min() >= -1. and x.max() <= 1. log.info('Detected input outside [0,1], rescaling from [-1,1] to [0,1]') x = (x + 1) / 2 + # Save a debug preview (pre-normalization) if requested + if self.debug_save_path is not None: + try: + torchvision.utils.save_image(x, self.debug_save_path) + log.info(f"Saved classifier input preview (pre-normalization) to: {self.debug_save_path}") + except Exception as e: + log.warning(f"Failed to save debug image to {self.debug_save_path}: {e}") x = self.transforms(x) logits = self.model(x) if self.use_softmax_and_query_label: diff --git a/gcd/src/losses/counterfactual_loss_general_components.py b/gcd/src/losses/counterfactual_loss_general_components.py index d756a57..3b44890 100644 --- a/gcd/src/losses/counterfactual_loss_general_components.py +++ b/gcd/src/losses/counterfactual_loss_general_components.py @@ -103,6 +103,17 @@ def get_clf_pred_label(self, img = None): else: output_probs = F.softmax(output, dim = 1) output_prob = output_probs[:, self.classifier.query_label] + # For multiclass, threshold 0.5 is often too strict (1k classes). + # Use argmax to decide if query_label is the predicted class. + if self.classifier.task == 'multiclass_classification': + pred_idx = output_probs.argmax(dim = 1) + label = (pred_idx == self.classifier.query_label).long().item() + if save_src_img_preds: + self.src_img_probs = output_probs + self.src_img_logits = output + log.info(f"Argmax predicted class: {pred_idx.item()}, query_label: {self.classifier.query_label}") + log.info(f"Probability of query class: {output_prob.item()}") + return label if save_src_img_preds: self.src_img_probs = output_probs From 390a3ecab164ce3684f2a3ea674172b27e0be03e Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 15 Nov 2025 21:13:00 +0100 Subject: [PATCH 08/24] fix file --- gcd/src/strategies/single_image_gmc_mlp_proxy_training.py | 1 - 1 file changed, 1 deletion(-) diff --git a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py index b61c418..e413b19 100644 --- a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py +++ b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py @@ -85,7 +85,6 @@ def __init__( self.dae = self.get_dae_class(dae_type)(**dae_kwargs) self.proxy = GeneralMulticomponentMLP(**mc_mlp_proxy_kwargs) - print('\n\nInitializing CounterfactualLossGeneralComponents from kwargs: ', components_kwargs, '\n\n') self.ce_loss = CounterfactualLossFromGeneralComponents(**ce_loss_kwargs) # Load source image either from local path or HF dataset by label From a9d72e2cad8de5514460b3373ecb6955659b6d47 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sun, 16 Nov 2025 13:16:32 +0100 Subject: [PATCH 09/24] add asertions --- gcd/src/losses/general_multicomponent_proxy_loss.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/gcd/src/losses/general_multicomponent_proxy_loss.py b/gcd/src/losses/general_multicomponent_proxy_loss.py index 7302d37..0edd0dc 100644 --- a/gcd/src/losses/general_multicomponent_proxy_loss.py +++ b/gcd/src/losses/general_multicomponent_proxy_loss.py @@ -30,8 +30,8 @@ def forward(self, inputs, targets, labels: list = None): pred_lpips = inputs['lpips'] target_lpips = targets['lpips'] - assert pred_preds.shape == target_preds.shape - assert pred_lpips.shape == target_lpips.shape + assert pred_preds.shape == target_preds.shape, f"Invalid shapes: {pred_preds.shape = } and {target_preds.shape = }" + assert pred_lpips.shape == target_lpips.shape, f"Invalid shapes: {pred_lpips.shape = } and {target_lpips.shape = }" loss = self.weight_cls * F.mse_loss(pred_preds, target_preds) + \ self.weight_lpips * F.mse_loss(pred_lpips, target_lpips) From 4e25d7a063bad16198c72b07ad97088da4ee7b60 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sun, 16 Nov 2025 13:20:16 +0100 Subject: [PATCH 10/24] update configs --- .../imagenet/resnet/sorrel/config.yaml | 2 +- .../imagenet/resnet/zebra/config.yaml | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml index d9d3015..36f1077 100644 --- a/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml +++ b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel/config.yaml @@ -31,7 +31,7 @@ strategy: output_dir: ${output_dir}/dae mc_mlp_proxy_kwargs: - shapes: [512, 256, 128, 64, 2] + shapes: [512, 256, 128, 64, 1001] activ_fn: sigmoid batch_size: 128 n_val_batches: 1 diff --git a/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml index 5507cf8..2320f9f 100644 --- a/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml +++ b/gcd/configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra/config.yaml @@ -31,7 +31,7 @@ strategy: output_dir: ${output_dir}/dae mc_mlp_proxy_kwargs: - shapes: [512, 256, 128, 64, 2] + shapes: [512, 256, 128, 64, 1001] activ_fn: sigmoid batch_size: 128 n_val_batches: 1 From fcb65c96d4b1fe8a6e7a7b88ea5204eea502176d Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sun, 16 Nov 2025 17:23:16 +0100 Subject: [PATCH 11/24] fix mismatch --- gcd/src/strategies/single_image_gmc_mlp_proxy_training.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py index e413b19..be9fc82 100644 --- a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py +++ b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py @@ -295,7 +295,9 @@ def post_proxy_loop(self, iter): device = self.device).repeat(n_directions).unsqueeze(1) grads_stack = torch.stack([*grads_dict.values()]).repeat_interleave(chunk_size, 0).squeeze() batch_latent_sem = self.src_latent_sem - step_sizes_stack * grads_stack - batch_latent_ddim = self.dae.make_batch_latent_ddim(self.src_latent_ddim) + # Match DDIM batch to the exact number of latent_sem samples (may be < dae.batch_size) + n_ls = batch_latent_sem.shape[0] + batch_latent_ddim = self.src_latent_ddim.repeat(n_ls, 1, 1, 1) # Generate images from line search latent sems log.info('Rendering line search images') From b2a018d58ea7dd84a4f48119431409fd2dd67148 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Mon, 17 Nov 2025 21:07:57 +0100 Subject: [PATCH 12/24] try to fix OOM --- gcd/src/dae/chexpert/dae.py | 36 +++++++++++++++++++++++------------- gcd/src/dae/default/dae.py | 37 ++++++++++++++++++++++++------------- 2 files changed, 47 insertions(+), 26 deletions(-) diff --git a/gcd/src/dae/chexpert/dae.py b/gcd/src/dae/chexpert/dae.py index e7c5093..88cd3ee 100644 --- a/gcd/src/dae/chexpert/dae.py +++ b/gcd/src/dae/chexpert/dae.py @@ -17,7 +17,7 @@ class DAECheXpert(nn.Module): """ def __init__( - self, + self, config_name: str, batch_size: int, std: Union[float, List[float]], @@ -30,7 +30,7 @@ def __init__( max_norm: float = 1., make_output_dir: bool = True): """ - batch_size - size of the batch containing perturbed latent representations + batch_size - size of the batch containing perturbed latent representations of source image T_encode - number of diffusion steps for encoding T_render - number of diffusion steps for rendering @@ -93,18 +93,28 @@ def encode_stochastic(self, x, cond, T = None): x, model_kwargs = {'cond': cond}) return out['sample'] - + @torch.no_grad() - def render(self, noise, cond, T = None): + def render(self, noise, cond, T = None, max_batch: int = None): if T is not None: sampler = self.config._make_diffusion_conf(T).make_sampler() else: sampler = self.sampler_render - pred_img = render_condition(self.config, - self.model, - noise, - sampler = sampler, - cond = cond) + # Chunked rendering to reduce peak memory + N = noise.shape[0] + if max_batch is None: + max_batch = min(N, 32) + outs = [] + for i in range(0, N, max_batch): + n_chunk = noise[i:i + max_batch] + c_chunk = cond[i:i + max_batch] if cond is not None else None + out = render_condition(self.config, + self.model, + n_chunk, + sampler = sampler, + cond = c_chunk) + outs.append(out) + pred_img = torch.cat(outs, dim = 0) pred_img = (pred_img + 1) / 2 return pred_img @@ -116,8 +126,8 @@ def make_batch_latent_sem(self, latent_sem): chunk_size = self.batch_size // len(self.std) for std in self.std: noise = torch.normal( - mean = 0., - std = std, + mean = 0., + std = std, size = (chunk_size, latent_sem.shape[1]), device = self.device) noises.append(noise) @@ -126,8 +136,8 @@ def make_batch_latent_sem(self, latent_sem): elif self.distribution == 'uniform_norm': log.info(f'Sampling latent sem perturbations from uniform distribution over norm from the interval [0, {self.max_norm}]') noise = torch.normal( - mean = 0., - std = 1., + mean = 0., + std = 1., size = (self.batch_size, latent_sem.shape[1]), device = self.device) noise /= noise.norm(dim = 1).unsqueeze(1) diff --git a/gcd/src/dae/default/dae.py b/gcd/src/dae/default/dae.py index f0b7b14..53d30ff 100644 --- a/gcd/src/dae/default/dae.py +++ b/gcd/src/dae/default/dae.py @@ -17,7 +17,7 @@ class DAE(nn.Module): """ def __init__( - self, + self, config_name: str, batch_size: int, std: Union[float, List[float]], # [0.01, 0.05, 0.1, 0.2] @@ -30,7 +30,7 @@ def __init__( max_norm: float = 1., make_output_dir: bool = True): """ - batch_size - size of the batch containing perturbed latent representations + batch_size - size of the batch containing perturbed latent representations of source image T_encode - number of diffusion steps for encoding T_render - number of diffusion steps for rendering @@ -93,18 +93,29 @@ def encode_stochastic(self, x, cond, T = None): x, model_kwargs = {'cond': cond}) return out['sample'] - + @torch.no_grad() - def render(self, noise, cond, T = None): + def render(self, noise, cond, T = None, max_batch: int = None): if T is not None: sampler = self.config._make_diffusion_conf(T).make_sampler() else: sampler = self.sampler_render - pred_img = render_condition(self.config, - self.model, - noise, - sampler = sampler, - cond = cond) + # Chunked rendering to reduce peak memory + N = noise.shape[0] + if max_batch is None: + # default to something conservative + max_batch = min(N, 32) + outs = [] + for i in range(0, N, max_batch): + n_chunk = noise[i:i + max_batch] + c_chunk = cond[i:i + max_batch] if cond is not None else None + out = render_condition(self.config, + self.model, + n_chunk, + sampler = sampler, + cond = c_chunk) + outs.append(out) + pred_img = torch.cat(outs, dim = 0) pred_img = (pred_img + 1) / 2 return pred_img @@ -116,8 +127,8 @@ def make_batch_latent_sem(self, latent_sem): chunk_size = self.batch_size // len(self.std) for std in self.std: noise = torch.normal( - mean = 0., - std = std, + mean = 0., + std = std, size = (chunk_size, latent_sem.shape[1]), device = self.device) noises.append(noise) @@ -126,8 +137,8 @@ def make_batch_latent_sem(self, latent_sem): elif self.distribution == 'uniform_norm': log.info(f'Sampling latent sem perturbations from uniform distribution over norm from the interval [0, {self.max_norm}]') noise = torch.normal( - mean = 0., - std = 1., + mean = 0., + std = 1., size = (self.batch_size, latent_sem.shape[1]), device = self.device) noise /= noise.norm(dim = 1).unsqueeze(1) From ac337f5b3187d59281f6d850a60c410a8bb4d21d Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sun, 23 Nov 2025 13:40:00 +0100 Subject: [PATCH 13/24] update commands --- commands.md | 21 --------------------- 1 file changed, 21 deletions(-) diff --git a/commands.md b/commands.md index bd2c86f..f89dfe4 100644 --- a/commands.md +++ b/commands.md @@ -12,27 +12,6 @@ conda create --name gcd python=3.11 -y pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118 ``` -Set variables -```bash -# Path to your DiffAE autoencoder EMA checkpoint (last.ckpt) -export DAE_CKPT="/absolute/path/to/checkpoints//last.ckpt" - -# Source image for the run (256×256 PNG/JPG) -export SRC_IMG="/absolute/path/to/source_image.png" -``` - -Run: zebra (class 340, zebra→not‑zebra direction) -```bash -cd /Users/mgalkowski/Desktop/diffae/gcd/gcd -python src/main.py \ - --config-path ../configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra \ - --config-name config \ - strategy.src_img_path="$SRC_IMG" \ - strategy.dae_kwargs.path_ckpt="$DAE_CKPT" \ - device=cuda:0 \ - wandb.project=null -``` - Zebra example (label 340): ```bash From c46988c38db6fc47834d032c10def70a27dea414 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sun, 23 Nov 2025 17:34:01 +0100 Subject: [PATCH 14/24] add direction transfer script --- commands.md | 13 ++ gcd/src/direction_transfer_hf_batch.py | 190 +++++++++++++++++++++++++ 2 files changed, 203 insertions(+) create mode 100644 gcd/src/direction_transfer_hf_batch.py diff --git a/commands.md b/commands.md index f89dfe4..fa6631f 100644 --- a/commands.md +++ b/commands.md @@ -59,6 +59,19 @@ python src/direction_transfer.py \ --subdir-name "transfer_run_1" ``` +```bash +python src/direction_transfer_hf_batch.py \ + --direction-path "/path/to/outputs//proxy/0/0_grad_*.pt" \ + --log-dir-path "/path/to/outputs/" \ + --dataset-name imagenet-1k \ + --split train \ + --target-class 340 \ + --n-samples 500 \ + --step-size 1.0 \ + --T-render 100 \ + --subdir-name "hf_batch_zebra" +``` + Tips for memory/time - If you see OOM, reduce: `strategy.dae_kwargs.batch_size` (e.g., 128) or narrow `strategy.dae_kwargs.std`. - To speed up renders, you can lower `strategy.dae_kwargs.T_render` (e.g., 20–50). diff --git a/gcd/src/direction_transfer_hf_batch.py b/gcd/src/direction_transfer_hf_batch.py new file mode 100644 index 0000000..2fa0397 --- /dev/null +++ b/gcd/src/direction_transfer_hf_batch.py @@ -0,0 +1,190 @@ +import sys +import os +import torch +import pandas as pd +import torchvision +import omegaconf +from pathlib import Path +from argparse import ArgumentParser +from datasets import load_dataset +from PIL import Image +import torchvision.transforms as T + +from proxies import * +from losses import * +from classifiers import * +from dae import * + +import logging +logging.basicConfig(level = logging.INFO) +log = logging.getLogger(__name__) + +def init_class_from_string(class_name): + return getattr(sys.modules[__name__], class_name) + +def get_cfg(args): + path_cfg = (Path(args.log_dir_path) if args.log_dir_path else Path(args.direction_path).parents[2]) / '.hydra' / 'config.yaml' + cfg = omegaconf.OmegaConf.load(path_cfg) + return cfg + +def init_loss(cfg): + loss_kwargs = cfg.strategy.ce_loss_kwargs + loss_kwargs['make_output_dir'] = False + clf_kwargs = loss_kwargs.components.clf + clf_name = clf_kwargs.pop('_target_').split('.')[-1] + log.info(f'Using {clf_name} classifier') + if clf_name == 'DenseNet': + clf_kwargs.use_probs_and_query_label = False + elif clf_name == 'ResNet': + clf_kwargs.use_softmax_and_query_label = True + clf = init_class_from_string(clf_name)(**clf_kwargs) + comps_kwargs = loss_kwargs.components + comps_name = comps_kwargs._target_.split('.')[-1] + comps_kwargs.pop('_target_') + # We don't need src_img_path here; components will only be used for per-image eval + comps_kwargs.pop('src_img_path', None) + comps = init_class_from_string(comps_name)( + src_img_path = None, + clf = clf, + **comps_kwargs) + loss_kwargs.pop('components') + loss_name = 'CounterfactualLossFromGeneralComponents' + loss_kwargs.weight_cls = cfg.strategy.ce_loss_kwargs.weight_cls + loss_kwargs.weight_lpips = cfg.strategy.ce_loss_kwargs.weight_lpips + loss = init_class_from_string(loss_name)(components = comps, **loss_kwargs) + return loss + +def init_dae(cfg): + dae_kwargs = cfg.strategy.dae_kwargs + dae_kwargs['make_output_dir'] = False + dae_type = cfg.strategy.dae_type + if dae_type == 'default': + dae_class = DAE + elif dae_type == 'chexpert': + dae_class = DAECheXpert + else: + raise NotImplementedError('DAE type not recognized') + log.info(f'Using {dae_class}') + dae = dae_class(**dae_kwargs) + return dae + +def load_direction(path, device): + grad = torch.load(Path(path), map_location='cpu').to(device) + if torch.count_nonzero(grad) == 0: + raise ValueError("Direction tensor contains only zeros.") + return grad + +def iter_hf_by_label(dataset_name, split, label, start_index, n_samples, token=None, cache_dir=None): + kwargs = {"streaming": True} + if token is None: + token = os.getenv('HF_TOKEN', None) + if cache_dir is None: + cache_dir = os.getenv('DATASET_CACHE', None) + if token is not None: + kwargs["token"] = token + if cache_dir is not None: + kwargs["cache_dir"] = cache_dir + ds = load_dataset(dataset_name, split=split, **kwargs) + count = 0 + emitted = 0 + for ex in ds: + if ex.get('label', None) == label: + if count >= start_index: + yield ex + emitted += 1 + if emitted >= n_samples: + break + count += 1 + +def pil_to_tensor_01(img: Image.Image, image_size: int): + transform = T.Compose([ + T.CenterCrop(image_size), + T.Resize(image_size), + T.ToTensor(), # [0,1] + ]) + return transform(img) + +def to_dae_range(x01: torch.Tensor): + return (x01 - 0.5) * 2 + +def from_dae_range(xm1p1: torch.Tensor): + return (xm1p1 + 1) / 2 + +def main(args): + log.info("Batch direction transfer (HF ImageNet)") + cfg = get_cfg(args) + device = torch.device(cfg.device) + loss = init_loss(cfg) + dae = init_dae(cfg) + direction = load_direction(args.direction_path, device) + + out_dir = (Path(args.log_dir_path) if args.log_dir_path else Path(args.direction_path).parents[2]) / 'direction_transfer' / args.subdir_name + out_dir.mkdir(parents=True, exist_ok=True) + csv_rows = [] + + image_size = dae.config.img_size if hasattr(dae, 'config') else 256 + n_done = 0 + for ex in iter_hf_by_label( + dataset_name=args.dataset_name, + split=args.split, + label=args.target_class, + start_index=args.start_index, + n_samples=args.n_samples, + token=args.hf_token, + cache_dir=args.hf_cache_dir + ): + img_pil = ex['image'] if isinstance(ex['image'], Image.Image) else Image.fromarray(ex['image']) + idx = ex['__index_level_0__'] if '__index_level_0__' in ex else n_done + # Prepare tensors + img01 = pil_to_tensor_01(img_pil, image_size).unsqueeze(0).to(device) + img_m1p1 = to_dae_range(img01) + # Encode + latent_sem = dae.encode(img_m1p1) + latent_ddim = dae.encode_stochastic(img_m1p1, latent_sem) + # One-step transfer using fixed step size + step = torch.tensor([[args.step_size]], device=device) + grad = direction + latent_sem_new = latent_sem - step * grad + img_trans = dae.render(latent_ddim, latent_sem_new, T=args.T_render) + # Save pair + orig_name = f"original_{n_done+1}.png" + trans_name = f"inpaint_{n_done+1}.png" + torchvision.utils.save_image(img01, out_dir / orig_name) + torchvision.utils.save_image(img_trans, out_dir / trans_name) + # Log classifier probabilities (optional) + with torch.no_grad(): + comps = loss.get_components(img_trans) + prob_after = loss.get_query_label_probability(comps['predictions']).item() + pred_before_logits = loss.components.classifier(img01) + prob_before = loss.get_query_label_probability(pred_before_logits).item() + csv_rows.append({ + "idx": int(idx), + "orig_path": orig_name, + "inpaint_path": trans_name, + "prob_before": prob_before, + "prob_after": prob_after + }) + n_done += 1 + if n_done % 25 == 0: + log.info(f"Processed {n_done}/{args.n_samples}") + # Save CSV + pd.DataFrame(csv_rows).to_csv(out_dir / "pairs.csv", index=False) + log.info(f"Saved {n_done} pairs to {out_dir}") + +if __name__ == '__main__': + parser = ArgumentParser() + parser.add_argument('--direction-path', type=str, required=True, help='Path to direction .pt file from proxy run') + parser.add_argument('--log-dir-path', type=str, default=None, help='Log dir path with .hydra config (defaults to infer from direction path)') + parser.add_argument('--dataset-name', type=str, default='imagenet-1k', help='HF dataset name') + parser.add_argument('--split', type=str, default='train', help='HF split') + parser.add_argument('--target-class', type=int, required=True, help='HF class id to sample') + parser.add_argument('--n-samples', type=int, default=500, help='Number of images to transfer') + parser.add_argument('--start-index', type=int, default=0, help='Skip N matching samples before collecting') + parser.add_argument('--hf-token', type=str, default=None, help='HF token (or set HF_TOKEN env)') + parser.add_argument('--hf-cache-dir', type=str, default=None, help='HF cache (or set DATASET_CACHE env)') + parser.add_argument('--step-size', type=float, default=1.0, help='Step size multiplier for the direction') + parser.add_argument('--T-render', type=int, default=100, help='Number of DDIM steps for rendering') + parser.add_argument('--subdir-name', type=str, default='hf_batch', help='Output subdir name under direction_transfer') + args = parser.parse_args() + main(args) + From 16003458deec4ea851c0a809dde93c26d2539010 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sun, 23 Nov 2025 17:57:52 +0100 Subject: [PATCH 15/24] fix script --- gcd/src/direction_transfer_hf_batch.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/gcd/src/direction_transfer_hf_batch.py b/gcd/src/direction_transfer_hf_batch.py index 2fa0397..4dda4dc 100644 --- a/gcd/src/direction_transfer_hf_batch.py +++ b/gcd/src/direction_transfer_hf_batch.py @@ -41,7 +41,7 @@ def init_loss(cfg): comps_kwargs = loss_kwargs.components comps_name = comps_kwargs._target_.split('.')[-1] comps_kwargs.pop('_target_') - # We don't need src_img_path here; components will only be used for per-image eval + comps_kwargs.pop('clf', None) comps_kwargs.pop('src_img_path', None) comps = init_class_from_string(comps_name)( src_img_path = None, From 18bcfb0a628d88e51842ceecfa2bb4aef3a2cd64 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sun, 23 Nov 2025 18:28:02 +0100 Subject: [PATCH 16/24] fix scripts --- gcd/src/direction_transfer_hf_batch.py | 2 ++ gcd/src/losses/counterfactual_loss_general_components.py | 2 ++ gcd/src/strategies/single_image_gmc_mlp_proxy_training.py | 2 ++ 3 files changed, 6 insertions(+) diff --git a/gcd/src/direction_transfer_hf_batch.py b/gcd/src/direction_transfer_hf_batch.py index 4dda4dc..258edb1 100644 --- a/gcd/src/direction_transfer_hf_batch.py +++ b/gcd/src/direction_transfer_hf_batch.py @@ -97,6 +97,8 @@ def iter_hf_by_label(dataset_name, split, label, start_index, n_samples, token=N count += 1 def pil_to_tensor_01(img: Image.Image, image_size: int): + if img.mode != 'RGB': + img = img.convert('RGB') transform = T.Compose([ T.CenterCrop(image_size), T.Resize(image_size), diff --git a/gcd/src/losses/counterfactual_loss_general_components.py b/gcd/src/losses/counterfactual_loss_general_components.py index 3b44890..39f54ca 100644 --- a/gcd/src/losses/counterfactual_loss_general_components.py +++ b/gcd/src/losses/counterfactual_loss_general_components.py @@ -177,6 +177,8 @@ def load_src_img_from_hf( img = sample['image'] if not isinstance(img, Image.Image): img = Image.fromarray(img) + if img.mode != 'RGB': + img = img.convert('RGB') transform = T.Compose([ T.CenterCrop(image_size), T.Resize(image_size), diff --git a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py index be9fc82..60f260a 100644 --- a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py +++ b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py @@ -153,6 +153,8 @@ def load_src_img_from_hf( img = sample['image'] if not isinstance(img, Image.Image): img = Image.fromarray(img) + if img.mode != 'RGB': + img = img.convert('RGB') # Resize/center-crop to target image_size transform = T.Compose([ T.CenterCrop(image_size), From 39e51181e9948c1d0ed6156746f0f75090932dfb Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sun, 23 Nov 2025 19:29:08 +0100 Subject: [PATCH 17/24] extend script --- gcd/src/direction_transfer_hf_batch.py | 23 ++++++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/gcd/src/direction_transfer_hf_batch.py b/gcd/src/direction_transfer_hf_batch.py index 258edb1..537c423 100644 --- a/gcd/src/direction_transfer_hf_batch.py +++ b/gcd/src/direction_transfer_hf_batch.py @@ -27,7 +27,7 @@ def get_cfg(args): cfg = omegaconf.OmegaConf.load(path_cfg) return cfg -def init_loss(cfg): +def init_loss(cfg, args): loss_kwargs = cfg.strategy.ce_loss_kwargs loss_kwargs['make_output_dir'] = False clf_kwargs = loss_kwargs.components.clf @@ -38,6 +38,11 @@ def init_loss(cfg): elif clf_name == 'ResNet': clf_kwargs.use_softmax_and_query_label = True clf = init_class_from_string(clf_name)(**clf_kwargs) + # Override target/query class if provided for this direction + if getattr(args, "direction_target_class", None) is not None: + if hasattr(clf, "query_label"): + clf.query_label = int(args.direction_target_class) + log.info(f"Set classifier query_label to direction target class: {clf.query_label}") comps_kwargs = loss_kwargs.components comps_name = comps_kwargs._target_.split('.')[-1] comps_kwargs.pop('_target_') @@ -116,7 +121,7 @@ def main(args): log.info("Batch direction transfer (HF ImageNet)") cfg = get_cfg(args) device = torch.device(cfg.device) - loss = init_loss(cfg) + loss = init_loss(cfg, args) dae = init_dae(cfg) direction = load_direction(args.direction_path, device) @@ -156,15 +161,26 @@ def main(args): # Log classifier probabilities (optional) with torch.no_grad(): comps = loss.get_components(img_trans) + # Query-label probabilities prob_after = loss.get_query_label_probability(comps['predictions']).item() pred_before_logits = loss.components.classifier(img01) prob_before = loss.get_query_label_probability(pred_before_logits).item() + # Max-arg probabilities and classes (multiclass) + probs_after_all = torch.softmax(comps['predictions'], dim = 1) + best_after_prob, best_after_cls = probs_after_all.max(dim = 1) + probs_before_all = torch.softmax(pred_before_logits, dim = 1) + best_before_prob, best_before_cls = probs_before_all.max(dim = 1) csv_rows.append({ "idx": int(idx), "orig_path": orig_name, "inpaint_path": trans_name, "prob_before": prob_before, - "prob_after": prob_after + "prob_after": prob_after, + "pred_before_best_cls": int(best_before_cls.item()), + "pred_before_best_prob": float(best_before_prob.item()), + "pred_after_best_cls": int(best_after_cls.item()), + "pred_after_best_prob": float(best_after_prob.item()), + "direction_target_class": int(getattr(args, "direction_target_class", -1)) }) n_done += 1 if n_done % 25 == 0: @@ -176,6 +192,7 @@ def main(args): if __name__ == '__main__': parser = ArgumentParser() parser.add_argument('--direction-path', type=str, required=True, help='Path to direction .pt file from proxy run') + parser.add_argument('--direction-target-class', type=int, required=True, help='Target class id the direction was trained for (query label)') parser.add_argument('--log-dir-path', type=str, default=None, help='Log dir path with .hydra config (defaults to infer from direction path)') parser.add_argument('--dataset-name', type=str, default='imagenet-1k', help='HF dataset name') parser.add_argument('--split', type=str, default='train', help='HF split') From c27c1a7f80bd6f8eec79734ba48584d60b19976a Mon Sep 17 00:00:00 2001 From: galkowskim Date: Fri, 28 Nov 2025 09:14:51 +0100 Subject: [PATCH 18/24] add script --- gcd/src/tools/copy_and_rename_pairs.py | 67 ++++++++++++++++++++++++++ 1 file changed, 67 insertions(+) create mode 100644 gcd/src/tools/copy_and_rename_pairs.py diff --git a/gcd/src/tools/copy_and_rename_pairs.py b/gcd/src/tools/copy_and_rename_pairs.py new file mode 100644 index 0000000..42fd8b2 --- /dev/null +++ b/gcd/src/tools/copy_and_rename_pairs.py @@ -0,0 +1,67 @@ +import argparse +import os +import re +import shutil +from pathlib import Path + + +def main(): + parser = argparse.ArgumentParser( + description="Copy direction_transfer outputs to a new folder with standardized names." + ) + parser.add_argument("--src-dir", required=True, help="Path to direction_transfer subdir (contains original_*.png, inpaint_*.png)") + parser.add_argument("--dst-dir", required=True, help="Destination directory to write renamed copies") + parser.add_argument("--only-pairs", action="store_true", help="Copy only when both original_k.png and inpaint_k.png exist") + args = parser.parse_args() + + src = Path(args.src_dir) + dst = Path(args.dst_dir) + dst.mkdir(parents=True, exist_ok=True) + + # Patterns: original_1.png, inpaint_1.png + pat_orig = re.compile(r"^original_(\d+)\.png$") + pat_inpt = re.compile(r"^inpaint_(\d+)\.png$") + + indices_orig = {} + indices_inpt = {} + + for entry in os.listdir(src): + m1 = pat_orig.match(entry) + if m1: + idx = int(m1.group(1)) + indices_orig[idx] = src / entry + continue + m2 = pat_inpt.match(entry) + if m2: + idx = int(m2.group(1)) + indices_inpt[idx] = src / entry + + # Decide which indices to process + if args.only_pairs: + all_idx = sorted(set(indices_orig.keys()).intersection(indices_inpt.keys())) + else: + all_idx = sorted(set(indices_orig.keys()).union(indices_inpt.keys())) + + copied_images = 0 + copied_inpaints = 0 + + for i in all_idx: + z = str(i).zfill(5) + if i in indices_orig: + src_path = indices_orig[i] + tgt_path = dst / f"images_{z}.png" + shutil.copy2(src_path, tgt_path) + copied_images += 1 + if i in indices_inpt: + src_path = indices_inpt[i] + tgt_path = dst / f"inpaints_{z}.png" + shutil.copy2(src_path, tgt_path) + copied_inpaints += 1 + + print(f"Done. Wrote {copied_images} images_*.png and {copied_inpaints} inpaints_*.png into {dst}") + + +if __name__ == "__main__": + main() + + From 4813f1efb775fbe07eee0bbd93f5ead2339e4adb Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 29 Nov 2025 08:19:18 +0100 Subject: [PATCH 19/24] modify script --- gcd/src/direction_transfer_hf_batch.py | 117 ++++++++++++++++++++++--- 1 file changed, 106 insertions(+), 11 deletions(-) diff --git a/gcd/src/direction_transfer_hf_batch.py b/gcd/src/direction_transfer_hf_batch.py index 537c423..03c632d 100644 --- a/gcd/src/direction_transfer_hf_batch.py +++ b/gcd/src/direction_transfer_hf_batch.py @@ -9,6 +9,7 @@ from datasets import load_dataset from PIL import Image import torchvision.transforms as T +from typing import Optional, Dict from proxies import * from losses import * @@ -19,6 +20,17 @@ logging.basicConfig(level = logging.INFO) log = logging.getLogger(__name__) +# +# Fill these with your paths if you want per-split defaults instead of passing --predictions-csv +# +PREDICTIONS_PATHS: Dict[str, Dict[str, Optional[str]]] = { + # Example: + # 'imagenet-1k': { + # 'train': '/abs/path/to/imagenet1k_train_predictions.csv', + # 'validation': '/abs/path/to/imagenet1k_val_predictions.csv', + # }, +} + def init_class_from_string(class_name): return getattr(sys.modules[__name__], class_name) @@ -79,7 +91,51 @@ def load_direction(path, device): raise ValueError("Direction tensor contains only zeros.") return grad -def iter_hf_by_label(dataset_name, split, label, start_index, n_samples, token=None, cache_dir=None): +def get_predictions_df(dataset_name: str, split: str, explicit_csv: Optional[str]) -> Optional[pd.DataFrame]: + """ + Loads predictions CSV expected to have: + - index column: 'idx' + - column: 'pred_label' + If explicit_csv is None, attempts to use PREDICTIONS_PATHS[dataset_name][split]. + Returns None if no path provided. + """ + csv_path = explicit_csv + if csv_path is None: + csv_path = PREDICTIONS_PATHS.get(dataset_name, {}).get(split, None) + if csv_path is None: + return None + if not os.path.exists(csv_path): + raise FileNotFoundError(f"Predictions CSV not found: {csv_path}") + df = pd.read_csv(csv_path, index_col="idx") + if "pred_label" not in df.columns: + raise ValueError(f"Predictions CSV missing 'pred_label' column: {csv_path}") + return df + +def _extract_ex_idx(example: dict, fallback: Optional[int] = None) -> Optional[int]: + """ + Try to extract a stable dataset index from a HF streaming example. + Prefers '__index_level_0__', then 'id', then 'idx'. Falls back to provided fallback. + """ + for key in ("__index_level_0__", "id", "idx"): + if key in example: + try: + return int(example[key]) + except Exception: + continue + return fallback + +def iter_hf_by_label( + dataset_name, + split, + label, + start_index, + n_samples, + token=None, + cache_dir=None, + predictions_df: Optional[pd.DataFrame] = None, + filter_id: Optional[int] = None, + n_skip: int = 0, +): kwargs = {"streaming": True} if token is None: token = os.getenv('HF_TOKEN', None) @@ -90,16 +146,40 @@ def iter_hf_by_label(dataset_name, split, label, start_index, n_samples, token=N if cache_dir is not None: kwargs["cache_dir"] = cache_dir ds = load_dataset(dataset_name, split=split, **kwargs) - count = 0 + count_label_hits = 0 emitted = 0 + skipped = 0 + global_idx = -1 for ex in ds: - if ex.get('label', None) == label: - if count >= start_index: - yield ex - emitted += 1 - if emitted >= n_samples: - break - count += 1 + global_idx += 1 + if ex.get('label', None) != label: + continue + # Count only matching label occurrences for start_index logic + if count_label_hits < start_index: + count_label_hits += 1 + continue + + # If predictions filtering is requested, check it here + if predictions_df is not None and filter_id is not None: + ex_idx = _extract_ex_idx(ex, fallback=None) + if ex_idx is None or ex_idx not in predictions_df.index: + # If we can't find a matching prediction idx, skip + continue + pred_label = predictions_df.loc[ex_idx, "pred_label"] + # Keep only correctly predicted examples for the requested class + # Since ex['label'] == label, correctness means pred_label == label == filter_id + if int(pred_label) != int(filter_id): + continue + + # Apply n_skip on the filtered stream + if skipped < n_skip: + skipped += 1 + continue + + yield ex + emitted += 1 + if emitted >= n_samples: + break def pil_to_tensor_01(img: Image.Image, image_size: int): if img.mode != 'RGB': @@ -131,6 +211,14 @@ def main(args): image_size = dae.config.img_size if hasattr(dae, 'config') else 256 n_done = 0 + # Optional predictions filtering + predictions_df = None + if args.filter_id is not None: + predictions_df = get_predictions_df(args.dataset_name, args.split, args.predictions_csv) + if predictions_df is None: + log.error(f"Filtering requested but no predictions CSV provided for {args.dataset_name}/{args.split}.") + log.error("Provide --predictions-csv or fill PREDICTIONS_PATHS in the script.") + sys.exit(1) for ex in iter_hf_by_label( dataset_name=args.dataset_name, split=args.split, @@ -138,10 +226,13 @@ def main(args): start_index=args.start_index, n_samples=args.n_samples, token=args.hf_token, - cache_dir=args.hf_cache_dir + cache_dir=args.hf_cache_dir, + predictions_df=predictions_df, + filter_id=args.filter_id, + n_skip=args.n_skip ): img_pil = ex['image'] if isinstance(ex['image'], Image.Image) else Image.fromarray(ex['image']) - idx = ex['__index_level_0__'] if '__index_level_0__' in ex else n_done + idx = _extract_ex_idx(ex, fallback=n_done) # Prepare tensors img01 = pil_to_tensor_01(img_pil, image_size).unsqueeze(0).to(device) img_m1p1 = to_dae_range(img01) @@ -204,6 +295,10 @@ def main(args): parser.add_argument('--step-size', type=float, default=1.0, help='Step size multiplier for the direction') parser.add_argument('--T-render', type=int, default=100, help='Number of DDIM steps for rendering') parser.add_argument('--subdir-name', type=str, default='hf_batch', help='Output subdir name under direction_transfer') + # Optional predictions-based filtering (keep only correctly predicted examples for --filter-id) + parser.add_argument('--filter-id', type=int, default=None, help='If set, keep only samples with this GT label that are correctly predicted as this label') + parser.add_argument('--n-skip', type=int, default=0, help='Skip first N samples after filtering (useful for sharding)') + parser.add_argument('--predictions-csv', type=str, default=None, help='Optional override path to predictions CSV (index=idx, col=pred_label)') args = parser.parse_args() main(args) From 1f834dcc3bf985a13b5600cb87298ad5b95a6571 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 29 Nov 2025 10:01:23 +0100 Subject: [PATCH 20/24] adapt script --- gcd/src/direction_transfer_hf_batch.py | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/gcd/src/direction_transfer_hf_batch.py b/gcd/src/direction_transfer_hf_batch.py index 03c632d..e8ca6cc 100644 --- a/gcd/src/direction_transfer_hf_batch.py +++ b/gcd/src/direction_transfer_hf_batch.py @@ -135,6 +135,7 @@ def iter_hf_by_label( predictions_df: Optional[pd.DataFrame] = None, filter_id: Optional[int] = None, n_skip: int = 0, + pred_index_scope: str = "global", # 'global' or 'label' ): kwargs = {"streaming": True} if token is None: @@ -146,22 +147,22 @@ def iter_hf_by_label( if cache_dir is not None: kwargs["cache_dir"] = cache_dir ds = load_dataset(dataset_name, split=split, **kwargs) - count_label_hits = 0 + count_label_hits = 0 # counts examples with ex['label'] == label (zero-based) emitted = 0 skipped = 0 - global_idx = -1 - for ex in ds: - global_idx += 1 + for global_idx, ex in enumerate(ds): if ex.get('label', None) != label: continue # Count only matching label occurrences for start_index logic - if count_label_hits < start_index: - count_label_hits += 1 + current_label_idx = count_label_hits # zero-based index of label-matching sample + count_label_hits += 1 + if current_label_idx < start_index: continue # If predictions filtering is requested, check it here if predictions_df is not None and filter_id is not None: - ex_idx = _extract_ex_idx(ex, fallback=None) + # Choose index to match predictions: global or label-only enumeration + ex_idx = global_idx if pred_index_scope == "global" else current_label_idx if ex_idx is None or ex_idx not in predictions_df.index: # If we can't find a matching prediction idx, skip continue @@ -176,6 +177,10 @@ def iter_hf_by_label( skipped += 1 continue + # Attach derived indices for downstream logging if needed + ex = dict(ex) + ex["_global_idx_enumerate"] = global_idx + ex["_label_idx_enumerate"] = current_label_idx yield ex emitted += 1 if emitted >= n_samples: @@ -229,10 +234,12 @@ def main(args): cache_dir=args.hf_cache_dir, predictions_df=predictions_df, filter_id=args.filter_id, - n_skip=args.n_skip + n_skip=args.n_skip, + pred_index_scope=args.pred_index_scope ): img_pil = ex['image'] if isinstance(ex['image'], Image.Image) else Image.fromarray(ex['image']) - idx = _extract_ex_idx(ex, fallback=n_done) + # Log the index used to match predictions for traceability + idx = ex.get("_global_idx_enumerate", n_done) if args.pred_index_scope == "global" else ex.get("_label_idx_enumerate", n_done) # Prepare tensors img01 = pil_to_tensor_01(img_pil, image_size).unsqueeze(0).to(device) img_m1p1 = to_dae_range(img01) From 33a8f895ef27a289a9d2bc15e8fadd7035aff3a7 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 29 Nov 2025 10:21:06 +0100 Subject: [PATCH 21/24] modify script --- .../single_image_gmc_mlp_proxy_training.py | 65 +++++++++++++++++-- 1 file changed, 59 insertions(+), 6 deletions(-) diff --git a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py index 60f260a..66ccaac 100644 --- a/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py +++ b/gcd/src/strategies/single_image_gmc_mlp_proxy_training.py @@ -14,6 +14,7 @@ from datasets import load_dataset import torchvision.transforms as T from PIL import Image +import os import logging logging.basicConfig(level = logging.INFO) @@ -54,7 +55,13 @@ def __init__( hf_dataset_name: str = None, # e.g., 'imagenet-1k' hf_split: str = 'train', hf_label: int = None, # required if using HF - hf_index: int = 0): # which occurrence of that label to use + hf_index: int = 0, # which occurrence of that label to use + # Optional prediction-based filtering: pick a source image the classifier predicts correctly + hf_predictions_csv: str = None, # CSV with index 'idx' and column 'pred_label' + hf_pred_index_scope: str = 'global', # 'global' (enumerate whole split) or 'label' (enumerate only matching label) + hf_n_skip_filtered: int = 0, # skip first N after filtering (sharding/offset) + hf_filter_by_predictions: bool = False, # enable filtering + ): """ src_img_path - path of the source image n_iters - number of outer loop iterations @@ -96,6 +103,10 @@ def __init__( split = hf_split, label = hf_label, index = hf_index, + predictions_csv = hf_predictions_csv, + pred_index_scope = hf_pred_index_scope, + n_skip_filtered = hf_n_skip_filtered, + filter_by_predictions = hf_filter_by_predictions, hf_token = os.getenv('HF_TOKEN', None), cache_dir = os.getenv('DATASET_CACHE', None), image_size = self.dae.config.img_size if hasattr(self.dae, 'config') else 256 @@ -124,12 +135,20 @@ def load_src_img_from_hf( split: str, label: int, index: int = 0, + predictions_csv: str = None, + pred_index_scope: str = 'label', # 'global' | 'label' + n_skip_filtered: int = 0, + filter_by_predictions: bool = False, hf_token: str = None, cache_dir: str = None, image_size: int = 256, ): """ Load one source image from a HF dataset (e.g., 'imagenet-1k') by label id. + If filter_by_predictions is True, uses predictions_csv to keep only samples correctly + predicted as 'label'. You can choose how indices are matched via pred_index_scope: + - 'global': idx equals enumerate() over whole split (0-based) + - 'label' : idx equals enumerate() over only samples with ex['label'] == label (0-based) """ log.info(f'Loading source image from HF dataset: {dataset_name}, split: {split}, label: {label}, index: {index}') # Use streaming iteration to avoid materializing a full filtered copy (much faster) @@ -140,16 +159,50 @@ def load_src_img_from_hf( stream_kwargs["cache_dir"] = cache_dir ds_stream = load_dataset(dataset_name, split=split, **stream_kwargs) - found = -1 + preds_df = None + if filter_by_predictions: + if predictions_csv is None or not os.path.exists(predictions_csv): + raise FileNotFoundError(f"Filtering by predictions requested but CSV not found: {predictions_csv}") + preds_df = pd.read_csv(predictions_csv, index_col='idx') + if 'pred_label' not in preds_df.columns: + raise ValueError(f"Predictions CSV missing 'pred_label' column: {predictions_csv}") + + label_enum = -1 # counts only samples with matching label + global_enum = -1 # counts all samples + filtered_seen = -1 # counts only samples that passed predictions filter sample = None for ex in ds_stream: - if ex.get("label", None) == label: - found += 1 - if found == index: + global_enum += 1 + # keep only target label + if ex.get("label", None) != label: + continue + label_enum += 1 + + if filter_by_predictions: + idx_to_check = global_enum if pred_index_scope == 'global' else label_enum + if idx_to_check not in preds_df.index: + continue + pred_label = int(preds_df.loc[idx_to_check, 'pred_label']) + if pred_label != int(label): + continue + filtered_seen += 1 + if filtered_seen < n_skip_filtered: + continue + # pick by filtered order + if filtered_seen == index: + sample = ex + break + else: + # original behavior: pick by label occurrence + if label_enum == index: sample = ex break if sample is None: - raise ValueError(f"Could not find occurrence index {index} for label {label} in {dataset_name}/{split}") + if filter_by_predictions: + raise ValueError(f"Could not find filtered occurrence index {index} for label {label} in {dataset_name}/{split} " + f"(pred_index_scope={pred_index_scope}, n_skip_filtered={n_skip_filtered})") + else: + raise ValueError(f"Could not find occurrence index {index} for label {label} in {dataset_name}/{split}") img = sample['image'] if not isinstance(img, Image.Image): img = Image.fromarray(img) From 8b402ab9a22800b9597dbbfe63ac29f47d6e26fa Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 29 Nov 2025 13:37:53 +0100 Subject: [PATCH 22/24] add script --- images_prep.py | 58 ++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 58 insertions(+) create mode 100644 images_prep.py diff --git a/images_prep.py b/images_prep.py new file mode 100644 index 0000000..85a9393 --- /dev/null +++ b/images_prep.py @@ -0,0 +1,58 @@ +#!/usr/bin/env python3 +import argparse +import re +from pathlib import Path +import shutil + +def main(): + p = argparse.ArgumentParser(description="Pair originals and 'ours' inpaints, rename to images_00001.png/inpaints_00001.png") + p.add_argument("input_folder", type=Path) + p.add_argument("output_folder", type=Path) + p.add_argument("--ext", default="png", choices=["png","jpg","jpeg"], help="Image extension to process") + p.add_argument("--start-index", type=int, default=1, help="Starting index for numbering") + p.add_argument("--dry-run", action="store_true") + args = p.parse_args() + + inp = args.input_folder + out = args.output_folder + out.mkdir(parents=True, exist_ok=True) + + # originals look like: 13.png + original_re = re.compile(rf"^(\d+)\.{re.escape(args.ext)}$", re.IGNORECASE) + + originals = [] + for pth in sorted(inp.glob(f"*.{args.ext}")): + m = original_re.match(pth.name) + if m: + originals.append((int(m.group(1)), pth)) + + idx = args.start_index + for img_id, orig_path in originals: + # inpaint like: 13_340_zebra_ours.png (anything ending with _ours.) + candidates = list(inp.glob(f"{img_id}_*_ours.{args.ext}")) + if not candidates: + # try case-insensitive ext variations + candidates = [p for p in inp.iterdir() + if p.is_file() and p.name.lower().startswith(f"{img_id}_") + and p.name.lower().endswith(f"_ours.{args.ext}")] + if not candidates: + # no pair -> skip + continue + + inpaint_path = sorted(candidates)[0] # pick first if multiple + + z = str(idx).zfill(5) + dst_img = out / f"images_{z}.{args.ext}" + dst_inp = out / f"inpaints_{z}.{args.ext}" + + if args.dry_run: + print(f"[DRY] {orig_path} -> {dst_img}") + print(f"[DRY] {inpaint_path} -> {dst_inp}") + else: + shutil.copy2(orig_path, dst_img) + shutil.copy2(inpaint_path, dst_inp) + + idx += 1 + +if __name__ == "__main__": + main() \ No newline at end of file From c1bfc57d04b7dcf9e0365836d90d97e60193c049 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 29 Nov 2025 13:38:00 +0100 Subject: [PATCH 23/24] update --- images_prep.py | 1 - 1 file changed, 1 deletion(-) diff --git a/images_prep.py b/images_prep.py index 85a9393..5d4838b 100644 --- a/images_prep.py +++ b/images_prep.py @@ -1,4 +1,3 @@ -#!/usr/bin/env python3 import argparse import re from pathlib import Path From 2791071b78c8e3e61dd905d8167cf906e9663934 Mon Sep 17 00:00:00 2001 From: galkowskim Date: Sat, 29 Nov 2025 13:41:11 +0100 Subject: [PATCH 24/24] rm script --- images_prep.py | 57 -------------------------------------------------- 1 file changed, 57 deletions(-) delete mode 100644 images_prep.py diff --git a/images_prep.py b/images_prep.py deleted file mode 100644 index 5d4838b..0000000 --- a/images_prep.py +++ /dev/null @@ -1,57 +0,0 @@ -import argparse -import re -from pathlib import Path -import shutil - -def main(): - p = argparse.ArgumentParser(description="Pair originals and 'ours' inpaints, rename to images_00001.png/inpaints_00001.png") - p.add_argument("input_folder", type=Path) - p.add_argument("output_folder", type=Path) - p.add_argument("--ext", default="png", choices=["png","jpg","jpeg"], help="Image extension to process") - p.add_argument("--start-index", type=int, default=1, help="Starting index for numbering") - p.add_argument("--dry-run", action="store_true") - args = p.parse_args() - - inp = args.input_folder - out = args.output_folder - out.mkdir(parents=True, exist_ok=True) - - # originals look like: 13.png - original_re = re.compile(rf"^(\d+)\.{re.escape(args.ext)}$", re.IGNORECASE) - - originals = [] - for pth in sorted(inp.glob(f"*.{args.ext}")): - m = original_re.match(pth.name) - if m: - originals.append((int(m.group(1)), pth)) - - idx = args.start_index - for img_id, orig_path in originals: - # inpaint like: 13_340_zebra_ours.png (anything ending with _ours.) - candidates = list(inp.glob(f"{img_id}_*_ours.{args.ext}")) - if not candidates: - # try case-insensitive ext variations - candidates = [p for p in inp.iterdir() - if p.is_file() and p.name.lower().startswith(f"{img_id}_") - and p.name.lower().endswith(f"_ours.{args.ext}")] - if not candidates: - # no pair -> skip - continue - - inpaint_path = sorted(candidates)[0] # pick first if multiple - - z = str(idx).zfill(5) - dst_img = out / f"images_{z}.{args.ext}" - dst_inp = out / f"inpaints_{z}.{args.ext}" - - if args.dry_run: - print(f"[DRY] {orig_path} -> {dst_img}") - print(f"[DRY] {inpaint_path} -> {dst_inp}") - else: - shutil.copy2(orig_path, dst_img) - shutil.copy2(inpaint_path, dst_inp) - - idx += 1 - -if __name__ == "__main__": - main() \ No newline at end of file