81 changes: 81 additions & 0 deletions commands.md
@@ -0,0 +1,81 @@
### GCD ImageNet (zebra ↔ sorrel) — quick commands

Prereqs
- You already trained a DiffAE autoencoder at 256×256 and have its checkpoint path (last.ckpt).
- You have a 256×256 source image (.png/.jpg) for zebra or sorrel.
- Optional: a custom ImageNet classifier checkpoint (otherwise torchvision ResNet-50 pretrained is used).

Setup (one-time)
```bash
cd /Users/mgalkowski/Desktop/diffae/gcd
conda create --name gcd python=3.11 -y
conda activate gcd
pip install -r requirements.txt --extra-index-url https://download.pytorch.org/whl/cu118
```

Zebra example (label 340; the source image is pulled from the `imagenet-1k` HF dataset instead of a local file):

```bash
export DAE_CKPT="/absolute/path/to/last.ckpt"  # placeholder: the config marks path_ckpt as required (???)
python src/main.py \
--config-path ../configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/zebra \
--config-name config \
strategy.dae_kwargs.path_ckpt="$DAE_CKPT" \
strategy.src_img_path=null \
strategy.hf_dataset_name=imagenet-1k \
strategy.hf_split=train \
strategy.hf_label=340 \
strategy.hf_index=0 \
device=cuda:0 \
wandb.project=null
```
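The `key=value` arguments are Hydra command-line overrides applied to dotted config paths (e.g. `strategy.hf_label` sets `hf_label` inside the `strategy` block). A minimal sketch of the mechanics, assuming a plain nested dict (Hydra itself additionally handles type parsing, `null`, and interpolation):

```python
def apply_override(cfg: dict, dotted_key: str, value):
    """Set a nested dict key from a dotted path, e.g. 'strategy.hf_label'."""
    *parents, leaf = dotted_key.split(".")
    node = cfg
    for key in parents:
        node = node.setdefault(key, {})  # create intermediate dicts as needed
    node[leaf] = value
    return cfg

cfg = {"strategy": {"hf_label": None}, "device": "cpu"}
apply_override(cfg, "strategy.hf_label", 340)
apply_override(cfg, "device", "cuda:0")
```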

Sorrel example (label 339, sorrel→not-sorrel direction):
```bash
cd /Users/mgalkowski/Desktop/diffae/gcd/gcd
export SRC_IMG="/absolute/path/to/sorrel.png"   # placeholder: 256×256 source image
export DAE_CKPT="/absolute/path/to/last.ckpt"   # placeholder: trained DiffAE checkpoint
python src/main.py \
--config-path ../configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/sorrel \
--config-name config \
strategy.src_img_path="$SRC_IMG" \
strategy.dae_kwargs.path_ckpt="$DAE_CKPT" \
device=cuda:0 \
wandb.project=null
```

Notes
- The configs use the default 3‑channel DiffAE wrapper and a ResNet‑50 ImageNet classifier.
- To use a custom classifier, add at the end of the command:
- `strategy.ce_loss_kwargs.components.clf.path_to_weights="/path/to/classifier.ckpt"`
- Outputs are written under `gcd/gcd/outputs/<date_time>/strategy/`. Proxy data and candidate directions are in `strategy/proxy/...`.
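After a run finishes, the saved candidate directions can be collected by filename pattern. A small sketch using the `*_grad_*.pt` pattern mentioned above (the exact subdirectory layout under `proxy/` may vary per run):

```python
from pathlib import Path

def find_direction_files(log_dir: str, pattern: str = "*_grad_*.pt"):
    """Recursively collect saved direction tensors under the proxy output dir."""
    return sorted(Path(log_dir, "proxy").rglob(pattern))
```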

Direction transfer (apply a saved direction to another image)
```bash
# Example placeholders — update paths after a run completes
export LOG_DIR="/Users/mgalkowski/Desktop/diffae/gcd/gcd/outputs/<date_time>/strategy"
export DIR_PATH="$LOG_DIR/proxy/<some_subdir>/<direction_file>.pt" # e.g., *_grad_*.pt
export TARGET_IMG="/absolute/path/to/another_image.png"

python src/direction_transfer.py \
--log-dir-path "$LOG_DIR" \
--direction-path "$DIR_PATH" \
--img-path "$TARGET_IMG" \
--subdir-name "transfer_run_1"
```

Batch direction transfer over a Hugging Face dataset split (e.g., apply a saved zebra direction to many `imagenet-1k` images):
```bash
python src/direction_transfer_hf_batch.py \
--direction-path "/path/to/outputs/<date_time>/proxy/0/0_grad_*.pt" \
--log-dir-path "/path/to/outputs/<date_time>" \
--dataset-name imagenet-1k \
--split train \
--target-class 340 \
--n-samples 500 \
--step-size 1.0 \
--T-render 100 \
--subdir-name "hf_batch_zebra"
```

Tips for memory/time
- If you hit OOM, reduce `strategy.dae_kwargs.batch_size` (e.g., to 128) or narrow the `strategy.dae_kwargs.std` range.
- To speed up renders, lower `strategy.dae_kwargs.T_render` (e.g., to 20–50).
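As a rough sanity check on the `batch_size` knob, the raw input tensors alone scale linearly with it (this counts only float32 image batches, not activations or model weights, so actual usage is much higher):

```python
def batch_input_mb(batch_size: int, img_size: int = 256, channels: int = 3,
                   bytes_per_elem: int = 4) -> float:
    """Memory for one batch of images in MiB (float32 by default)."""
    return batch_size * channels * img_size * img_size * bytes_per_elem / 2**20

print(batch_input_mb(256))  # 192.0 MiB of raw inputs
print(batch_input_mb(128))  # 96.0 -- halving batch_size halves this
```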

Switching class direction
- Zebra run uses `query_label=340`; sorrel run uses `query_label=339`. Use the matching config.
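Both ids follow the standard ImageNet-1k class index (339 = sorrel, 340 = zebra). A tiny helper to keep the label/config pairing straight; the directory layout is taken from the commands above:

```python
QUERY_LABELS = {"zebra": 340, "sorrel": 339}  # standard ImageNet-1k class ids

def config_dir_for(animal: str) -> str:
    """Return the matching config directory for a class name."""
    if animal not in QUERY_LABELS:
        raise ValueError(f"unknown class: {animal}")
    return f"../configs/single_image_gmc_mlp_proxy_training/imagenet/resnet/{animal}"
```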

@@ -0,0 +1,73 @@
mode: RUN
seed: 0
device: cuda:0
output_dir:

strategy:
  _target_: strategies.SingleImageGMCMLPProxyTraining
  src_img_path: ???  # path to source sorrel image
  n_iters: 20
  n_steps_dae: 8
  n_steps_proxy: 512
  min_max_step_size: [0.0005, 2.0]
  device: ${device}
  output_dir: ${output_dir}/strategy
  center_to_src_latent_sem: true
  line_search_weight_cls_list: [1.0, 0.05]
  line_search_weight_lpips_list: [0, 1.0]
  n_hessian_eigenvecs: 16
  dae_type: default

  dae_kwargs:
    config_name: imagenet256_autoenc
    batch_size: 256
    std: [0.01, 0.05, 0.1, 0.2]
    distribution: uniform_norm
    max_norm: 1.0
    T_encode: 250
    T_render: 50
    path_ckpt: ???  # path to your trained DiffAE checkpoint (.ckpt)
    device: ${device}
    output_dir: ${output_dir}/dae

  mc_mlp_proxy_kwargs:
    shapes: [512, 256, 128, 64, 1001]
    activ_fn: sigmoid
    batch_size: 128
    n_val_batches: 1
    val_loss_history_length: 5
    optimizer:
      _partial_: true
      _target_: torch.optim.AdamW
      lr: 0.001
    loss_kwargs:
      weight_cls: 1.0
      weight_lpips: 3.0
    device: ${device}
    output_dir: ${output_dir}/proxy

  ce_loss_kwargs:
    weight_cls: 1.0
    weight_lpips: 0
    device: ${device}
    output_dir: ${output_dir}/loss
    components:
      _target_: losses.CounterfactualLossGeneralComponents
      lpips_net: vgg
      src_img_path: ${strategy.src_img_path}
      device: ${device}
      clf:
        _target_: classifiers.ImageNetResNet
        img_size: 256
        query_label: 339  # sorrel class id
        use_softmax_and_query_label: false
        task: multiclass_classification
        model_name: resnet50
        weights: IMAGENET1K_V2
        path_to_weights: null

wandb:
  project: null
  group: null
  name: imagenet_sorrel

@@ -0,0 +1,73 @@
mode: RUN
seed: 0
device: cuda:0
output_dir:

strategy:
  _target_: strategies.SingleImageGMCMLPProxyTraining
  src_img_path: ???  # path to source zebra image (.png or .jpg)
  n_iters: 20
  n_steps_dae: 8
  n_steps_proxy: 512
  min_max_step_size: [0.0005, 2.0]
  device: ${device}
  output_dir: ${output_dir}/strategy
  center_to_src_latent_sem: true
  line_search_weight_cls_list: [1.0, 0.05]
  line_search_weight_lpips_list: [0, 1.0]
  n_hessian_eigenvecs: 16
  dae_type: default

  dae_kwargs:
    config_name: imagenet256_autoenc
    batch_size: 256
    std: [0.01, 0.05, 0.1, 0.2]
    distribution: uniform_norm
    max_norm: 1.0
    T_encode: 250
    T_render: 50
    path_ckpt: ???  # path to your trained DiffAE checkpoint (.ckpt)
    device: ${device}
    output_dir: ${output_dir}/dae

  mc_mlp_proxy_kwargs:
    shapes: [512, 256, 128, 64, 1001]
    activ_fn: sigmoid
    batch_size: 128
    n_val_batches: 1
    val_loss_history_length: 5
    optimizer:
      _partial_: true
      _target_: torch.optim.AdamW
      lr: 0.001
    loss_kwargs:
      weight_cls: 1.0
      weight_lpips: 3.0
    device: ${device}
    output_dir: ${output_dir}/proxy

  ce_loss_kwargs:
    weight_cls: 1.0
    weight_lpips: 0
    device: ${device}
    output_dir: ${output_dir}/loss
    components:
      _target_: losses.CounterfactualLossGeneralComponents
      lpips_net: vgg
      src_img_path: ${strategy.src_img_path}
      device: ${device}
      clf:
        _target_: classifiers.ImageNetResNet
        img_size: 256
        query_label: 340  # zebra class id
        use_softmax_and_query_label: false
        task: multiclass_classification
        model_name: resnet50
        weights: IMAGENET1K_V2
        path_to_weights: null  # optional: path to custom weights

wandb:
  project: null
  group: null
  name: imagenet_zebra
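The proxy MLP `shapes: [512, 256, 128, 64, 1001]` maps a 512-wide input (presumably the DiffAE semantic latent) down to 1001 outputs. Assuming plain fully connected layers with bias, the parameter count is easy to sanity-check:

```python
def mlp_param_count(shapes):
    """Total weights + biases for a plain fully connected stack."""
    return sum((n_in + 1) * n_out for n_in, n_out in zip(shapes, shapes[1:]))

print(mlp_param_count([512, 256, 128, 64, 1001]))  # 237545
```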

3 changes: 2 additions & 1 deletion gcd/src/classifiers/__init__.py
@@ -1,2 +1,3 @@
from .dense_net import DenseNet
from .dense_net_chexpert import DenseNetCheXpert
from .imagenet import ImageNetResNet
87 changes: 87 additions & 0 deletions gcd/src/classifiers/imagenet.py
@@ -0,0 +1,87 @@
import logging

import torch
import torch.nn.functional as F
import torchvision
import torchvision.transforms as tt

logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)


class ImageNetResNet(torch.nn.Module):
    def __init__(
            self,
            img_size: int = 256,
            query_label: int = 340,
            use_softmax_and_query_label: bool = False,
            task: str = 'multiclass_classification',
            model_name: str = 'resnet50',
            weights: str = 'IMAGENET1K_V2',
            path_to_weights: str = None,
            debug_save_path: str = None):
        super().__init__()
        self.img_size = img_size
        self.task = task
        self.query_label = int(query_label)
        self.use_softmax_and_query_label = use_softmax_and_query_label
        self.transforms = tt.Compose([
            tt.Normalize(mean=IMAGENET_MEAN, std=IMAGENET_STD)
        ])
        self.debug_save_path = debug_save_path
        self.model = self._build_model(model_name, weights, path_to_weights)
        self.model.eval()

    def _build_model(self, model_name, weights, path_to_weights):
        log.info(f'Building torchvision model: {model_name}')
        if hasattr(torchvision.models, model_name):
            ctor = getattr(torchvision.models, model_name)
        else:
            raise ValueError(f'Unknown model_name: {model_name}')
        if path_to_weights is None:
            # Load torchvision pretrained weights; resolve the weights enum via
            # get_model_weights (e.g. ResNet50_Weights) rather than upper-casing
            # the model name, which produces a non-existent attribute. Fall back
            # to the plain weights string, which torchvision also accepts.
            try:
                w_enum = torchvision.models.get_model_weights(model_name)
                w = getattr(w_enum, weights)
                model = ctor(weights=w)
            except Exception:
                log.warning(f'Could not resolve weights enum for {model_name}; '
                            f'passing weights string {weights!r} directly')
                model = ctor(weights=weights)
        else:
            model = ctor(weights=None)
            log.info(f'Loading classifier weights from: {path_to_weights}')
            ckpt = torch.load(path_to_weights, map_location='cpu')
            state_dict = ckpt.get('state_dict', ckpt)
            # If state_dict keys are prefixed (e.g. 'model.'), strip that common prefix
            if any(k.startswith('model.') for k in state_dict.keys()):
                state_dict = {k.split('model.', 1)[1]: v
                              for k, v in state_dict.items()
                              if k.startswith('model.')}
            missing, unexpected = model.load_state_dict(state_dict, strict=False)
            if missing:
                log.warning(f'Missing keys when loading classifier: {len(missing)}')
            if unexpected:
                log.warning(f'Unexpected keys when loading classifier: {len(unexpected)}')
        return model

    @torch.no_grad()
    def forward(self, x):
        assert x.shape[-1] == self.img_size, f'Wrong input shape: {tuple(x.shape)}'
        # Expect [0,1] range; rescale if in [-1,1]
        if x.min() < 0:
            assert x.min() >= -1. and x.max() <= 1.
            log.info('Detected input outside [0,1], rescaling from [-1,1] to [0,1]')
            x = (x + 1) / 2
        # Save a debug preview (pre-normalization) if requested
        if self.debug_save_path is not None:
            try:
                torchvision.utils.save_image(x, self.debug_save_path)
                log.info(f'Saved classifier input preview (pre-normalization) to: {self.debug_save_path}')
            except Exception as e:
                log.warning(f'Failed to save debug image to {self.debug_save_path}: {e}')
        x = self.transforms(x)
        logits = self.model(x)
        if self.use_softmax_and_query_label:
            probs = F.softmax(logits, dim=1)
            return probs[:, self.query_label]
        return logits
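The preprocessing in `forward` above is two affine steps: an optional rescale from [-1, 1] to [0, 1], then channel-wise ImageNet normalization. The same arithmetic on a single pixel value, as a self-contained sketch:

```python
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD = (0.229, 0.224, 0.225)

def preprocess_pixel(v: float, channel: int, signed_input: bool = False) -> float:
    """Rescale [-1,1] -> [0,1] if needed, then normalize as the classifier does."""
    if signed_input:
        v = (v + 1) / 2  # map [-1,1] into [0,1]
    return (v - IMAGENET_MEAN[channel]) / IMAGENET_STD[channel]
```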
