Skip to content

Commit 5f3b970

Browse files
committed
poc teacher notebook
1 parent 03fa77f commit 5f3b970

4 files changed

Lines changed: 186 additions & 1 deletion

File tree

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ venv312/
2828
venv311/
2929
# Jupyter Notebook
3030
.ipynb_checkpoints
31-
31+
tmp/
32+
tmp*/
3233
# Distribution / packaging
3334
.Python
3435
env/

notebooks/02_baseline_models.ipynb

Whitespace-only changes.

notebooks/03_synthetic_data_generation.ipynb

Whitespace-only changes.

notebooks/teacher.ipynb

Lines changed: 184 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,184 @@
1+
{
2+
"cells": [
3+
{
4+
"cell_type": "markdown",
5+
"metadata": {},
6+
"source": [
7+
"## Training the Teacher Model\n",
8+
"The first step in the pipeline is to train a teacher model on the SST-2 dataset. This model will be used to classify the synthetic data generated by the generator. \n"
9+
]
10+
},
11+
{
12+
"cell_type": "code",
13+
"execution_count": 22,
14+
"metadata": {},
15+
"outputs": [
16+
{
17+
"data": {
18+
"application/vnd.jupyter.widget-view+json": {
19+
"model_id": "0389b731c3af4704bc3f0564bd5f6417",
20+
"version_major": 2,
21+
"version_minor": 0
22+
},
23+
"text/plain": [
24+
"Map: 0%| | 0/256 [00:00<?, ? examples/s]"
25+
]
26+
},
27+
"metadata": {},
28+
"output_type": "display_data"
29+
}
30+
],
31+
"source": [
32+
import torch
from datasets import load_dataset
from transformers import AutoModelForSequenceClassification, AutoTokenizer, DataCollatorWithPadding

# Load a small SST-2 slice (256 examples) for the proof-of-concept teacher run.
# Renamed from `datasets` to `dataset`: the original name shadowed the imported
# `datasets` package, which makes later `datasets.<anything>` calls ambiguous.
dataset = load_dataset("glue", "sst2", split="train[:256]")
# HF Trainer expects the label column to be named "labels".
dataset = dataset.rename_column("label", "labels")

tokenizer = AutoTokenizer.from_pretrained("prajjwal1/bert-small", use_fast=True)

def tokenize_function(examples):
    """Tokenize a batch of SST-2 sentences, truncating to 32 tokens."""
    return tokenizer(examples["sentence"], truncation=True, max_length=32)

tokenized_datasets = dataset.map(tokenize_function, batched=True)
# Keep only the tensors the model consumes, as torch tensors.
tokenized_datasets.set_format(type="torch", columns=["input_ids", "attention_mask", "labels"])
47+
]
48+
},
49+
{
50+
"cell_type": "code",
51+
"execution_count": 24,
52+
"metadata": {},
53+
"outputs": [
54+
{
55+
"name": "stderr",
56+
"output_type": "stream",
57+
"text": [
58+
"Some weights of BertForSequenceClassification were not initialized from the model checkpoint at prajjwal1/bert-small and are newly initialized: ['classifier.bias', 'classifier.weight']\n",
59+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
60+
]
61+
},
62+
{
63+
"name": "stdout",
64+
"output_type": "stream",
65+
"text": [
66+
"Model loaded on mps\n"
67+
]
68+
}
69+
],
70+
"source": [
71+
# Instantiate the teacher with a freshly-initialized 2-way classification head
# (hence the "newly initialized: classifier.*" warning from transformers).
model = AutoModelForSequenceClassification.from_pretrained("prajjwal1/bert-small")

# Device preference: CUDA, then Apple-Silicon MPS, then CPU fallback.
if torch.cuda.is_available():
    device = "cuda"
elif torch.backends.mps.is_available():
    device = "mps"
else:
    device = "cpu"

model.to(device)
print(f"Model loaded on {device}")
75+
]
76+
},
77+
{
78+
"cell_type": "code",
79+
"execution_count": 26,
80+
"metadata": {},
81+
"outputs": [],
82+
"source": [
83+
from transformers import Trainer, TrainingArguments

args = TrainingArguments(
    output_dir="tmp/teacher_poc",
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    num_train_epochs=1,
    learning_rate=2e-5,
    eval_strategy="steps",
    eval_steps=50,
)

# FIX: the original evaluated on a shuffled subset of the *training* data
# (`tokenized_datasets.shuffle(seed=0).select(range(64))`), so validation loss
# measured memorization, not generalization. Carve out a disjoint hold-out
# instead; seed=0 keeps the split reproducible.
split = tokenized_datasets.train_test_split(test_size=64, seed=0)

trainer = Trainer(
    model=model,
    args=args,
    train_dataset=split["train"],   # 192 training examples
    eval_dataset=split["test"],     # 64 held-out eval examples
    data_collator=DataCollatorWithPadding(tokenizer),  # dynamic per-batch padding
)
102+
]
103+
},
104+
{
105+
"cell_type": "code",
106+
"execution_count": 27,
107+
"metadata": {},
108+
"outputs": [
109+
{
110+
"data": {
111+
"text/html": [
112+
"\n",
113+
" <div>\n",
114+
" \n",
115+
" <progress value='128' max='128' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
116+
" [128/128 00:18, Epoch 1/1]\n",
117+
" </div>\n",
118+
" <table border=\"1\" class=\"dataframe\">\n",
119+
" <thead>\n",
120+
" <tr style=\"text-align: left;\">\n",
121+
" <th>Step</th>\n",
122+
" <th>Training Loss</th>\n",
123+
" <th>Validation Loss</th>\n",
124+
" </tr>\n",
125+
" </thead>\n",
126+
" <tbody>\n",
127+
" <tr>\n",
128+
" <td>50</td>\n",
129+
" <td>No log</td>\n",
130+
" <td>0.669309</td>\n",
131+
" </tr>\n",
132+
" <tr>\n",
133+
" <td>100</td>\n",
134+
" <td>No log</td>\n",
135+
" <td>0.646186</td>\n",
136+
" </tr>\n",
137+
" </tbody>\n",
138+
"</table><p>"
139+
],
140+
"text/plain": [
141+
"<IPython.core.display.HTML object>"
142+
]
143+
},
144+
"metadata": {},
145+
"output_type": "display_data"
146+
},
147+
{
148+
"data": {
149+
"text/plain": [
150+
"TrainOutput(global_step=128, training_loss=0.6919310092926025, metrics={'train_runtime': 24.1107, 'train_samples_per_second': 10.618, 'train_steps_per_second': 5.309, 'total_flos': 349921897560.0, 'train_loss': 0.6919310092926025, 'epoch': 1.0})"
151+
]
152+
},
153+
"execution_count": 27,
154+
"metadata": {},
155+
"output_type": "execute_result"
156+
}
157+
],
158+
"source": [
159+
"trainer.train()"
160+
]
161+
}
162+
],
163+
"metadata": {
164+
"kernelspec": {
165+
"display_name": "Python3.11 (sentisynth)",
166+
"language": "python",
167+
    "name": "sentisynth"
168+
},
169+
"language_info": {
170+
"codemirror_mode": {
171+
"name": "ipython",
172+
"version": 3
173+
},
174+
"file_extension": ".py",
175+
"mimetype": "text/x-python",
176+
"name": "python",
177+
"nbconvert_exporter": "python",
178+
"pygments_lexer": "ipython3",
179+
"version": "3.11.2"
180+
}
181+
},
182+
"nbformat": 4,
183+
"nbformat_minor": 2
184+
}

0 commit comments

Comments
 (0)