-
Notifications
You must be signed in to change notification settings - Fork 7
Expand file tree
/
Copy pathexploit.py
More file actions
116 lines (97 loc) · 4.54 KB
/
exploit.py
File metadata and controls
116 lines (97 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
from transformers import AutoModelForSequenceClassification
from trlx import trlx
from trlx.data.configs import TRLConfig
# from trlx.data.default_configs import default_ppo_config
from sklearn.metrics.pairwise import manhattan_distances
from lm_utils import *
# Silence all warnings globally (transformers/sklearn are noisy at model
# load time). NOTE(review): this also hides genuine warnings — confirm intended.
warnings.filterwarnings("ignore")
# Reward-mixing coefficients, consumed by get_reward_fn below:
LAM_ADV = 0.5  # weight on the adversarial classifier reward
# LAM_DIV = 500.0
LAM_DIV1 = 100  # divisor applied to pairwise L1 distances before exp(-d)
LAM_DIV2 = 5  # weight on the batch-diversity reward
EXPLOIT_MODEL = 'gpt2-large'  # policy model fine-tuned by trlx (config.model.model_path)
ENSEMBLE_SIZE = 5  # number of classifier checkpoints averaged in get_classifier_fn
def get_classifier_fn(classifier_model=CLASSIFIER_MODEL):
    """Build a scoring function over an ensemble of sequence classifiers.

    Loads ENSEMBLE_SIZE fine-tuned checkpoints from
    ``./models/{classifier_model}_classifier_{i}`` and returns a callable
    mapping a list of response strings to the ensemble-mean positive-class
    logit (``logits[:, 1]``), as a 1-D torch tensor on DEVICE.

    Fixes vs. the original:
    - the "small batch" branch was a verbatim copy of one iteration of the
      chunked loop; the chunked loop already covers every input size, so
      the duplication is removed;
    - models are explicitly put in ``eval()`` mode so any dropout is
      disabled and scores are deterministic at inference time.

    :param classifier_model: base model name; used for both the tokenizer
        and the checkpoint directory naming.
    :return: ``classify(responses: list[str]) -> torch.Tensor``
    """
    tokenizer = AutoTokenizer.from_pretrained(classifier_model)
    models = [
        AutoModelForSequenceClassification.from_pretrained(
            f'./models/{classifier_model}_classifier_{i}').to(DEVICE).eval()
        for i in range(ENSEMBLE_SIZE)
    ]
    sub_batch_size = 512  # tokenize/forward at most this many texts at once

    def classify(responses):
        with torch.no_grad():
            all_results = []
            for model in models:
                model_chunks = []
                # Chunked forward pass; a single chunk when
                # len(responses) <= sub_batch_size.
                for i in range(0, len(responses), sub_batch_size):
                    inputs = tokenizer(
                        responses[i: i + sub_batch_size], padding="max_length",
                        truncation=True, max_length=MAX_LENGTH,
                        return_tensors='pt').to(DEVICE)
                    outputs = model(**inputs).logits
                    model_chunks.append(outputs[:, 1])  # positive-class logit
                all_results.append(torch.cat(model_chunks))
            # Ensemble score: mean over the ENSEMBLE_SIZE models.
            return torch.mean(torch.stack(all_results), dim=0)
    return classify
def get_encoder_fn():
    """Return a callable that embeds texts with GPT-2 as a NumPy array.

    Wraps ``get_gpt2_embedding`` (from lm_utils) and moves the result to
    the CPU as NumPy so the sklearn-based distance function can consume it.
    """
    def encode(sentences):
        embeddings = get_gpt2_embedding(sentences)
        return embeddings.cpu().numpy()
    return encode
def get_dist_fn():
    """Return a pairwise L1 (Manhattan) distance function.

    The returned callable takes an ``(n, d)`` NumPy array of embeddings
    (as produced by the encoder from get_encoder_fn) and returns an
    ``(n, n)`` ``torch.Tensor`` of pairwise Manhattan distances on DEVICE.

    Fix vs. the original: the annotations were wrong — the parameter was
    annotated ``torch.tensor`` (a function, not a type) and the return was
    annotated ``np.ndarray`` although a ``torch.Tensor`` is returned.
    """
    def l1_dist(txt_embeddings: np.ndarray) -> torch.Tensor:
        # sklearn computes on CPU/NumPy; move the result onto DEVICE so it
        # composes with the other torch-based reward terms.
        return torch.tensor(manhattan_distances(txt_embeddings)).to(DEVICE)
    return l1_dist
def get_reward_fn(classifier_fn, encoder_fn, dist_fn):
    """Build the trlx reward function: adversarial + batch-diversity terms.

    :param classifier_fn: list[str] -> torch.Tensor of classifier scores
        (see get_classifier_fn).
    :param encoder_fn: list[str] -> np.ndarray of embeddings
        (see get_encoder_fn).
    :param dist_fn: embeddings -> torch.Tensor pairwise distance matrix
        (see get_dist_fn).
    :return: ``reward_fn(samples, **kwargs) -> list[float]`` as trlx expects.

    Fix vs. the original: the return annotation said ``torch.tensor``
    (a function, and also wrong — the function returns a Python list via
    ``.tolist()``); corrected to ``list``.
    """
    def reward_fn(samples, **kwargs) -> list:
        # Record original lengths before substituting a placeholder for
        # empty prompts (the target LM cannot generate from '').
        sample_lens = [len(s) for s in samples]
        samples = [s if len(s) > 0 else 'The' for s in samples]
        with torch.no_grad():
            # Generate target-LM continuations; 50256 is GPT-2's EOS/pad id.
            responses = target_lm(samples, pad_token_id=50256)
            responses = [r[0]['generated_text'] for r in responses]
            # Keep only the continuation (strip the echoed prompt),
            # then drop leading whitespace.
            responses = [r[len(s):] for r, s in zip(responses, samples)]
            responses = [remove_leading_whitespace(r) for r in responses]
            embeddings = encoder_fn(samples)
            dist_matrix = dist_fn(embeddings) / LAM_DIV1
            # Diversity reward: penalize samples close to others in the batch.
            div_reward = -1 * torch.mean(torch.exp(-dist_matrix), dim=1) * LAM_DIV2
            del dist_matrix  # free the (n, n) matrix early
            adv_reward = classifier_fn(responses) * LAM_ADV  # adversarial reward
            rewards = div_reward + adv_reward
            rewards = torch.clip(rewards, -5, 5)
            # Penalize prompts that are too short (<= 10 characters) outright.
            for i, sl in enumerate(sample_lens):
                if sl <= 10:
                    rewards[i] = -5
        return rewards.tolist()
    return reward_fn
def get_config():
    """Load the base PPO config and apply the exploit-phase overrides."""
    config = TRLConfig.load_yaml("trlx/configs/ppo_config.yml")

    # --- training schedule / logging ---
    train = config.train
    train.trackers = ('aim',)
    train.total_steps = 500000
    train.epochs = 1000
    train.checkpoint_interval = 1000
    train.eval_interval = 500
    train.batch_size = 4096

    # --- model ---
    config.model.model_path = EXPLOIT_MODEL
    config.model.num_layers_unfrozen = 1

    # --- PPO method ---
    method = config.method
    method.gen_kwargs.update({'max_new_tokens': 10})
    method.init_kl_coef = 0.05  # 0.15
    method.target = 6  # 7

    # --- optimizer ---
    config.optimizer.kwargs.update({'lr': 1e-6})  # 5e-7})

    return config
def _main():
    """Exploit phase: RL-train a prompt generator against the ensemble."""
    print(f'Running exploit step...')
    config = get_config()
    # Compose the reward from classifier, encoder, and distance components.
    reward_fn = get_reward_fn(get_classifier_fn(), get_encoder_fn(), get_dist_fn())
    print(f'Running rl training for {config.train.total_steps} steps...')
    trainer = trlx.train(reward_fn=reward_fn, config=config)
    print('Saving...')
    # trainer.save_pretrained('./models/exploit_generator')
    trainer.save('./models/exploit_generator')
    print('Done :)')


if __name__ == '__main__':
    _main()