-
Notifications
You must be signed in to change notification settings - Fork 6
Expand file tree
/
Copy pathrun_attacks.py
More file actions
115 lines (97 loc) · 4.61 KB
/
run_attacks.py
File metadata and controls
115 lines (97 loc) · 4.61 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import os
from datetime import datetime
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8" # determinism
import logging
import hydra
import torch
from omegaconf import DictConfig, ListConfig, OmegaConf
from adversariallm.attacks import Attack, AttackResult
from adversariallm.dataset import PromptDataset
from adversariallm.errors import print_exceptions
from adversariallm.io_utils import RunConfig, filter_config, free_vram, load_model_and_tokenizer, log_attack
from run_judges import run_judges
# Request deterministic algorithm implementations from PyTorch. With
# warn_only=True, operations that have no deterministic implementation emit a
# warning instead of raising.
torch.use_deterministic_algorithms(True, warn_only=True)
# Permit TF32 for CUDA matrix multiplications.
torch.backends.cuda.matmul.allow_tf32 = True
# Raise the torch._dynamo recompile limit for the whole process.
torch._dynamo.config.recompile_limit = 512 # needed for gemma 3 on AutoDAN
def select_configs(cfg: DictConfig, name: str | ListConfig | None) -> list[tuple[str, DictConfig]]:
    """Pick entries of ``cfg`` and return them as ``(key, sub-config)`` pairs.

    Args:
        cfg: Config group whose keys are the selectable entry names.
        name: ``None`` selects every entry of ``cfg``; a ``ListConfig`` selects
            each key it contains, in order; any other value is used as a
            single key.

    Returns:
        ``list(cfg.items())`` when ``name`` is ``None``, otherwise one
        ``(key, cfg[key])`` tuple per requested key.
    """
    if name is None:
        return list(cfg.items())
    # Normalise the single-key case to a one-element list so both shapes share
    # the same lookup below.
    requested = name if isinstance(name, ListConfig) else [name]
    return [(key, cfg[key]) for key in requested]
def collect_configs(cfg: DictConfig) -> list[RunConfig]:
    """Expand the selected models, datasets and attacks into run configs.

    Every (model, dataset, attack) combination is wrapped in a ``RunConfig``
    and passed through ``filter_config``; combinations for which it returns
    ``None`` are dropped.

    Side effect: each selected dataset config gets its ``"idx"`` entry
    overwritten with the ``config_idx`` of a freshly instantiated dataset.

    Args:
        cfg: Top-level config providing the ``models``/``datasets``/``attacks``
            groups, the ``model``/``dataset``/``attack`` selectors and the
            ``overwrite`` flag.

    Returns:
        The surviving run configs, ordered model -> dataset -> attack.
    """
    selected_models = select_configs(cfg.models, cfg.model)
    selected_datasets = select_configs(cfg.datasets, cfg.dataset)
    selected_attacks = select_configs(cfg.attacks, cfg.attack)

    pending: list[RunConfig] = []
    for model_name, model_cfg in selected_models:
        for dataset_name, dataset_cfg in selected_datasets:
            # Instantiate the dataset only to learn its length and the indices
            # it resolved; the instance itself is not kept.
            probe = PromptDataset.from_name(dataset_name)(dataset_cfg)
            n_prompts = len(probe)
            dataset_cfg["idx"] = probe.config_idx
            for attack_name, attack_cfg in selected_attacks:
                # NOTE(review): filter_config presumably skips work that is
                # already stored unless `overwrite` is set -- confirm there.
                candidate = filter_config(
                    RunConfig(
                        model_name,
                        dataset_name,
                        attack_name,
                        model_cfg,
                        dataset_cfg,
                        attack_cfg,
                    ),
                    n_prompts,
                    overwrite=cfg.overwrite,
                )
                if candidate is None:
                    continue
                pending.append(candidate)
    return pending
def run_attacks(all_run_configs: list[RunConfig], cfg: DictConfig, date_time_string: str) -> None:
    """Run every configured attack in order and log each result.

    The model/tokenizer, dataset and attack objects are reused between
    consecutive runs and only rebuilt when the corresponding name in the run
    config differs from the previous run's.

    Args:
        all_run_configs: Runs to execute, in execution order.
        cfg: Full run config, forwarded to ``log_attack``.
        date_time_string: Timestamp identifying this invocation, forwarded to
            ``log_attack``.
    """
    last_model = None
    last_dataset = None
    last_attack = None
    model = tokenizer = None
    for run_config in all_run_configs:
        # To avoid reloading the model and dataset for every attack,
        # we only reload something if it's different from the last run.
        # NOTE(review): the comparisons use names only; if the params of two
        # consecutive runs differ under the same name, the cached object from
        # the earlier run is reused -- confirm this is intended.
        if last_model != run_config.model:
            logging.info(f"Target: {run_config.model}\n{OmegaConf.to_yaml(run_config.model_params, resolve=True)}")
            last_model = run_config.model
            if model is not None:
                # Drop our references to the previous model *before* loading
                # the next one. Without this, the old model stays alive until
                # the assignment below completes, so two models are resident
                # at the same time.
                model = tokenizer = None
                free_vram()
            model, tokenizer = load_model_and_tokenizer(run_config.model_params)
        if last_dataset != run_config.dataset:
            logging.info(f"Dataset: {run_config.dataset}\n{OmegaConf.to_yaml(run_config.dataset_params, resolve=True)}")
            last_dataset = run_config.dataset
            dataset = PromptDataset.from_name(run_config.dataset)(run_config.dataset_params)
        if last_attack != run_config.attack:
            logging.info(f"Attack: {run_config.attack}\n{OmegaConf.to_yaml(run_config.attack_params, resolve=True)}")
            last_attack = run_config.attack
            attack: Attack[AttackResult] = Attack.from_name(run_config.attack)(run_config.attack_params)
        results = attack.run(model, tokenizer, dataset)  # type: ignore
        log_attack(run_config, results, cfg, date_time_string)
@hydra.main(config_path="./conf", config_name="config", version_base="1.3")
@print_exceptions
def main(cfg: DictConfig) -> None:
    """Entry point: run all selected attacks, then judge this run's outputs.

    Steps:
        1. Create ``cfg.save_dir`` and stamp the run with the current time.
        2. Pop ``classifiers`` out of ``cfg`` so it is not part of the config
           handed to the attack pipeline.
        3. Collect and execute the attack runs.
        4. For each popped classifier, call ``run_judges`` restricted to this
           run's time suffix.

    Args:
        cfg: Hydra-composed config loaded from ``./conf/config``.
    """
    os.makedirs(cfg.save_dir, exist_ok=True)
    date_time_string = datetime.now().strftime("%Y-%m-%d/%H-%M-%S")
    banner = "-------------------"
    logging.info(banner)
    logging.info(f"Commencing run at `{date_time_string}`")
    logging.info(banner)

    # 1. Split the config into the judge part and the attack part.
    # Struct mode is lifted only for the duration of the pop.
    OmegaConf.set_struct(cfg, False)
    # Classifiers are taken out of the config so they are not saved to mongodb
    # with it; this way attacks are not re-run when only the classifiers change.
    judges_to_run = cfg.pop('classifiers')
    OmegaConf.set_struct(cfg, True)

    # 2. Run the attacks.
    run_attacks(collect_configs(cfg), cfg, date_time_string)

    # 3. Run the judges (nothing to do when no classifiers are configured).
    if judges_to_run is None:
        return
    # Only the time-of-day part of the stamp is passed on as the suffix, so
    # that only the outputs of this run are judged.
    run_suffix = date_time_string.split('/')[-1]
    for judge in judges_to_run:
        judge_cfg = OmegaConf.create({
            "classifier": judge,
            "suffixes": [run_suffix],
            "filter_by": None
        })
        free_vram()
        run_judges(judge_cfg)
# Invoke the Hydra-decorated entry point only when executed as a script.
if __name__ == "__main__":
    main()