Skip to content

Commit 28ee322

Browse files
ganlerCopilot
andauthored
feat: data generation for ctx distillation and rule to code (#5)
* feat: data generation for ctx distillation and rule to code * hotfix * hotfix * chore * Update datagen/ctxdistill/ctxdistill.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update datagen/ctxdistill/ctxdistill.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
1 parent ed4a8db commit 28ee322

9 files changed

Lines changed: 1230 additions & 9 deletions

File tree

datagen/ctxdistill/ctxdistill.py

Lines changed: 365 additions & 0 deletions
Large diffs are not rendered by default.

datagen/ctxdistill/main.py

Lines changed: 57 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
# SPDX-FileCopyrightText: (c) UIUC PurpCode Team
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
from datagen.ctxdistill.ctxdistill import run_distillation
6+
7+
DEFAULT_SAMPLE_RATIO_MEDIUM = 0.25
8+
9+
# Core logic behind:
10+
# Pick a subset of data {D} total prompts -> {N} prompts
11+
# Ctx based sampling -> {N * S} where S is the sample size
12+
# Verfication and pick best of S -> getting {K} prompts with verified responses
13+
# SFT over {K} prompt-response pairs
14+
# RL on {D} - {N where most responses are right} prompts, i.e., rm very easy prompts
15+
16+
17+
def main(**kwargs):
18+
single_turn_datasets = [
19+
# mal event / single-turn
20+
(
21+
"purpcode/mal-event-jailbreak-single-oss-16k",
22+
DEFAULT_SAMPLE_RATIO_MEDIUM,
23+
4096,
24+
),
25+
(
26+
"purpcode/mal-event-seed-attack-oss-24k",
27+
DEFAULT_SAMPLE_RATIO_MEDIUM,
28+
4096,
29+
),
30+
# vul code / single-turn
31+
("purpcode/vul2prompt-general-oss-26k", DEFAULT_SAMPLE_RATIO_MEDIUM, 4096),
32+
("purpcode/vul2prompt-benign2vul-oss-21k", DEFAULT_SAMPLE_RATIO_MEDIUM, 4096),
33+
("purpcode/vul2prompt-vul2vul-oss-21k", DEFAULT_SAMPLE_RATIO_MEDIUM, 2048),
34+
(
35+
"purpcode/vul2prompt-jailbreaking-oss-11k",
36+
DEFAULT_SAMPLE_RATIO_MEDIUM,
37+
2048,
38+
),
39+
# utility
40+
("purpcode/secqa_utility_train", DEFAULT_SAMPLE_RATIO_MEDIUM, 4096),
41+
("KodCode/KodCode-V1-SFT-R1", DEFAULT_SAMPLE_RATIO_MEDIUM, 8192),
42+
]
43+
multi_turn_datasets = [
44+
# vul code / multi-turn
45+
("purpcode/vul2prompt-multi-oss-5k", 1.0),
46+
# mal event / multi-turn
47+
("purpcode/mal-event-fitd-multi-turn-oss-2k", 1.0),
48+
]
49+
50+
datasets = single_turn_datasets + multi_turn_datasets
51+
run_distillation(datasets=datasets, **kwargs)
52+
53+
54+
if __name__ == "__main__":
55+
from fire import Fire
56+
57+
Fire(main)

0 commit comments

Comments
 (0)