# -*- coding: utf-8 -*-
"""Wake2Vec_morpheme-expansion.ipynb
Automatically generated by Colab.
Original file is located at
https://colab.research.google.com/drive/1YxzvmFZN2PQfgwhi2bHbtrfisacZGGgb
# Wake2Vec Morpheme Expansion Pipeline
This notebook documents a controlled procedure for integrating Joyce-style neologisms into a compact GPT-type language model through morphology-aware token expansion. It curates a small lexicon of prefixes and suffixes, generates synthetic candidates, and then extends the tokenizer to admit previously split neologisms as single tokens. New embeddings are initialised by morphemic composition, using the rule \(E(\text{word}) = \alpha\,E(\text{prefix}) + (1 - 2\alpha)\,E(\text{root}) + \alpha\,E(\text{suffix}) + \varepsilon\), where \(\alpha\) is a fixed weight and \(\varepsilon\) is small Gaussian noise that prevents identical vectors (a toy sketch of the rule follows this header cell). Training proceeds in two stages: an embedding-only warm-up on a mixture of synthetic lines and *Finnegans Wake* text, followed by a short full-model fine-tune under conservative schedules suitable for a T4 environment.
The pipeline reports top-five neighbor overlap for the newly introduced tokens before and after training, tracks shifts in embedding norms, provides a t-SNE projection of the new tokens against pre-training neighbor centroids, and saves JSON snapshots of neighborhoods at each stage. These diagnostics are intended to show coherent integration of the new forms into the embedding space rather than collapse or runaway drift, and to make the procedure straightforward to reproduce on modest hardware.
**Config**
Base model: `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T`. Composition weight \(\alpha = 0.25\). Maximum sequence length set to 1024 to respect T4 memory limits. Batching uses `per_device_train_batch_size = 1` with `gradient_accumulation_steps = 8`, attention implementation set to `eager`, and `use_cache = False`. Phase 1 trains input embeddings and the tied output head only; Phase 2 unfreezes all parameters with a warm-up ratio of 0.10 and light weight decay. All runs write plots and machine-readable artifacts to `runs/<RUN_ID>/` and generate a brief HTML report.
---
## Run controls
- **BASE_MODEL:** `TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T`
- **α (composition weight):** `0.25` (can tune)
- **Max seq length:** `1024` (T4-safe; raise only if VRAM allows)
- **Batching:** `per_device_train_batch_size=1`, `gradient_accumulation_steps=8`
- **Attn impl:** `eager` (avoid SDPA spikes on T4)
- **Two phases:**
- **Phase 1:** embeddings + lm_head only, Adafactor/8-bit Adam, 1 epoch
- **Phase 2:** full model, short run, warmup 0.10
## Inputs
- `data/FW_TEXT.txt` — *Finnegans Wake* plain text (slice for demo)
- `data/morpheme_data.json` or `data/morphemes.csv`
Either file provides two maps:
- `prefixes`: `{ prefix → [example words…] }`
- `suffixes`: `{ suffix → [example words…] }`
## Outputs (per run)
- `runs/<RUN_ID>/metrics/`
- `pre_morpheme_snapshot.json`
- `morpheme_comparison_p1.json` *(midpoint, after Phase 1)*
- `morpheme_comparison.json` *(final, after Phase 2)*
- `summary_stats_p1.json`, `summary_stats.json`
- `runs/<RUN_ID>/plots/`
- `hist_overlap_top5(_p1).png`, `hist_norm_change(_p1).png`
- `scatter_norm_vs_overlap.png`, `tsne_newtokens_vs_precentroids.png`
- `reports/Wake2Vec_Report.html`
## Quickstart
1. **Reset & install** deps (Colab-friendly).
2. **Load data** (prefers JSON).
3. **Generate** synthetic forms (prefix + root + suffix).
4. **Expand tokenizer** (add new tokens); compose embeddings with α-rule; tie head.
5. **Phase 1**: train embeddings only. Saves midpoint snapshot.
6. **Phase 2**: unfreeze and short fine-tune.
7. **Diagnostics**: compute overlap@5, norm deltas, t-SNE; write HTML report.
## Diagnostics (what “good” looks like)
- **Top-5 neighbor overlap (pre→post):** ~3–4/5 indicates coherent integration (not collapse).
- **Norm shift (Δ‖E‖):** small positive mean (slight energy increase from training).
- **Qualitative neighbors:** morpheme-aligned (e.g., `presounder` ≈ `resound`, `ensounder`, …).
- **Tokenization:** most synthetic forms now **single IDs**.
## Repro & env
- `RUN_ID = "t4_<unix>"` auto-stamped; seeds fixed at 42.
- Tested on Colab T4 with: `transformers 4.57.1`, `datasets 2.21.0`, `pyarrow 22.0.0`.
- T4 guardrails: `MAX_LEN=1024`, `gradient_checkpointing=True`, attention=`eager`, batch=1 + accum=8.
## Troubleshooting (T4)
- **CUDA OOM** → lower `MAX_LEN` to 768/512; keep batch=1; accum=8–16; ensure `use_cache=False`; `PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True`.
- **Version noise** → uninstall RAPIDS/TF; pin `transformers 4.57.1`, `datasets 2.21.0`, `pyarrow 22.0.0`.
---
*Wake2Vec tests morphology-aware token expansion to integrate Joyce-style neologisms into a small language model without destabilising the embedding space. We curate a prefix/suffix lexicon, generate synthetic forms, initialise new vectors by morpheme composition, and train in two phases. Evaluation reports neighbor-overlap@5, embedding-norm shifts, and qualitative neighborhoods, with JSON snapshots for reproducibility.*
"""
!pip -q install --no-cache-dir --upgrade-strategy eager \
"transformers==4.57.1" "datasets==2.21.0" "accelerate==1.0.1" \
"peft==0.12.0" "bitsandbytes==0.43.3" \
"huggingface-hub>=0.34,<1.0" \
"pyarrow==22.0.0" "numpy==2.0.2" "pandas==2.2.2" "requests==2.32.4" \
"matplotlib>=3.8" "scikit-learn>=1.5" "umap-learn" "faiss-cpu" "wordfreq" "Unidecode"
import os; os.kill(os.getpid(), 9)  # restart the runtime so the pinned packages load cleanly
"""Imports, seeds, run IDs, paths"""
import numpy as np, torch, transformers, datasets, pyarrow as pa, json, time, random, gc, os
from pathlib import Path
from google.colab import drive
print("Transformers:", transformers.__version__)
print("Datasets :", datasets.__version__)
print("PyArrow :", pa.__version__)
print("Torch :", torch.__version__)
print("CUDA :", torch.version.cuda)
drive.mount('/content/drive', force_remount=True)
# project paths & run layout
PROJECT = "wake2vec"
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
RUN_ID = f"t4_{int(time.time())}"
ROOT = Path("/content")
PERSIST = Path("/content/drive/MyDrive")/PROJECT
RUN_DIR = ROOT/"runs"/RUN_ID
METRICS_DIR = RUN_DIR/"metrics"; PLOTS_DIR = RUN_DIR/"plots"; REPORTS_DIR = ROOT/"reports"
ADAPT_DIR = RUN_DIR/"phase2_lora"/"final_adapters"; TOK_SAVE = ADAPT_DIR # keep tokenizer here too
for d in [RUN_DIR, METRICS_DIR, PLOTS_DIR, REPORTS_DIR, ADAPT_DIR, PERSIST/"runs", PERSIST/"adapters", PERSIST/"reports", PERSIST/"archives", PERSIST/"notebooks"]:
d.mkdir(parents=True, exist_ok=True)
# clean & seeds
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
torch.backends.cuda.matmul.allow_tf32 = True
def set_seed(s=42):
random.seed(s); np.random.seed(s); torch.manual_seed(s); torch.cuda.manual_seed_all(s)
set_seed(42)
print("RUN_ID:", RUN_ID)
"""Load data"""
import pandas as pd, json, re
from collections import defaultdict
from pathlib import Path
MORPH_CSV = Path("/content/morphemes.csv")
assert MORPH_CSV.exists(), f"Not found: {MORPH_CSV}"
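# Expected CSV layout (hypothetical sample; the real file may carry more
# example columns):
#   type,morpheme,example1,example2,...
#   prefix,re,resound,rejoice
#   suffix,ness,darkness,brightness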
df = pd.read_csv(MORPH_CSV, dtype=str, keep_default_na=False)
# normalize cols
df.columns = [c.strip().lower() for c in df.columns]
required = {"type","morpheme"}
missing = required - set(df.columns)
if missing:
raise ValueError(f"CSV is missing columns: {missing}. Expected at least: {required} plus example1..exampleN")
# collect example cols
ex_cols = [c for c in df.columns if c.startswith("example")]
ex_cols.sort(key=lambda s: (len(s), s))
morph = {"prefixes": defaultdict(list), "suffixes": defaultdict(list)}
skipped = 0
for _, r in df.iterrows():
kind = r["type"].strip().lower()
piece = r["morpheme"].strip()
if kind not in ("prefix","suffix") or not piece:
skipped += 1
continue
examples = []
for c in ex_cols:
val = str(r[c]).strip()
if val and val.lower() != "nan":
examples.append(val)
if kind == "prefix":
morph["prefixes"][piece].extend(examples)
else:
morph["suffixes"][piece].extend(examples)
# dedupe + sort
for d in (morph["prefixes"], morph["suffixes"]):
for k in list(d.keys()):
d[k] = sorted(set([w for w in d[k] if w]))
prefixes = list(morph["prefixes"].keys())
suffixes = list(morph["suffixes"].keys())
print(f"[morph] prefixes: {len(prefixes)} | suffixes: {len(suffixes)}")
print(f"[morph] prefix examples total: {sum(len(v) for v in morph['prefixes'].values())} | "
f"suffix examples total: {sum(len(v) for v in morph['suffixes'].values())} | skipped rows: {skipped}")
out_dir = (PERSIST/"runs"/RUN_ID)
out_dir.mkdir(parents=True, exist_ok=True)
(out_dir/"morpheme_data.json").write_text(json.dumps(morph, indent=2), encoding="utf-8")
# expose variables used downstream
print("Sample prefixes:", prefixes[:5])
print("Sample suffixes:", suffixes[:5])
"""Tok expansion + composed init"""
import re, json, math, numpy as np, torch
from pathlib import Path
from transformers import AutoTokenizer, AutoModelForCausalLM
assert 'morph' in globals() and 'prefixes' in globals() and 'suffixes' in globals(), "Run the morpheme loader first."
# build synthetic lines if a previous cell didn't define them
if 'synthetic_lines' not in globals() or not synthetic_lines:
import random
ROOTS = ["river","thunder","word","sound","dance","queen","storm","tree","night","sun","rain","book"]
random.seed(13)
synthetic_lines = []
for _ in range(600):
p = random.choice(prefixes) if prefixes else ""
r = random.choice(ROOTS)
s = random.choice(suffixes) if suffixes else ""
synthetic_lines.append(f"The {p+r+s} rose and fell as if the {r} had learned to {s or 'sing'}.\n")
(PERSIST/"runs"/RUN_ID).mkdir(parents=True, exist_ok=True)
(PERSIST/"runs"/RUN_ID/"synthetic_lines.txt").write_text("".join(synthetic_lines), encoding="utf-8")
# tokenizer
BASE_MODEL = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tok = AutoTokenizer.from_pretrained(BASE_MODEL, use_fast=True)
if tok.pad_token is None:
tok.pad_token = tok.eos_token
# synthetic + raw morph pieces
new_tokens = set()
for line in synthetic_lines:
for m in re.finditer(r"\b([a-zA-Z][a-zA-Z\-]{2,})\b", line):
new_tokens.add(m.group(1).lower())
for p in prefixes: new_tokens.add(p)
for s in suffixes: new_tokens.add(s)
# filter unknowns only
new_tokens = [t for t in sorted(new_tokens) if tok.convert_tokens_to_ids(t) == tok.unk_token_id]
TOK_SAVE = (RUN_DIR/"phase2_lora"/"final_adapters")
TOK_SAVE.mkdir(parents=True, exist_ok=True)
added = tok.add_tokens(new_tokens, special_tokens=False)
tok.save_pretrained(str(TOK_SAVE))
json.dump(new_tokens, open(METRICS_DIR/"new_tokens.json","w"))
print(f"[tokenizer] added {added} tokens → saved at {TOK_SAVE}")
# load model, resize without mean-resizing, tie head
model = AutoModelForCausalLM.from_pretrained(BASE_MODEL, torch_dtype=torch.float32, device_map="auto")
old_size = model.get_input_embeddings().weight.shape[0]
model.resize_token_embeddings(len(tok), mean_resizing=False)
emb = model.get_input_embeddings().weight
device, dtype = emb.device, emb.dtype
# helper: only return real token embeddings
def emb_for_token(t: str):
tid = tok.convert_tokens_to_ids(t)
if tid >= 0 and tid != tok.unk_token_id:
return emb[tid].detach().clone()
return None
def mean_emb_for_words(words):
vecs = []
for w in words:
tid = tok.convert_tokens_to_ids(w)
if tid >= 0 and tid != tok.unk_token_id:
vecs.append(emb[tid].detach().clone())
if vecs:
return torch.stack(vecs, dim=0).mean(dim=0)
# fallback: small random around global std
std = emb.detach().std().item()
return torch.randn(emb.shape[1], device=device, dtype=dtype) * (0.1 * std)
alpha = 0.25
E_comp = []
new_ids = []
# prebuild sorted lists for greedy matches
_pref_sorted = sorted(prefixes, key=len, reverse=True)
_suf_sorted = sorted(suffixes, key=len, reverse=True)
for t in new_tokens:
tid = tok.convert_tokens_to_ids(t)
if tid < old_size:
continue
# greedy longest prefix/suffix match
p = next((pp for pp in _pref_sorted if t.startswith(pp)), "")
s = next((ss for ss in _suf_sorted if t.endswith(ss)), "")
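    # strip the matched affixes to recover the root; the slice end is None
    # (i.e. end of string) when no suffix matched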
core = t[len(p):len(t)-len(s) if s else None]
Ep = mean_emb_for_words(morph["prefixes"].get(p, [p])) if p else torch.zeros(emb.shape[1], device=device, dtype=dtype)
tmp = emb_for_token(core)
Er = tmp if tmp is not None else mean_emb_for_words([core])
Es = mean_emb_for_words(morph["suffixes"].get(s, [s])) if s else torch.zeros(emb.shape[1], device=device, dtype=dtype)
comp = alpha*Ep + (1 - 2*alpha)*Er + alpha*Es
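    # ε: Gaussian noise at 1% of the global embedding std, so identical
    # compositions don't collapse onto the same vector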
comp = comp + 0.01 * emb.detach().std().item() * torch.randn_like(comp)
with torch.no_grad():
emb[tid] = comp
new_ids.append(tid)
E_comp.append(comp.detach().cpu().numpy())
new_ids = torch.tensor(new_ids, dtype=torch.long, device=device)
np.save(METRICS_DIR/"E_comp_newtokens.npy", np.stack(E_comp))
print(f"[init] wrote composed embeddings for {len(new_ids)} new rows. Vocab {old_size} → {len(tok)}")
# tie output head
model.lm_head.weight = model.get_input_embeddings().weight
model.config.pad_token_id = tok.pad_token_id
model.config.use_cache = False
model.config._attn_implementation = "eager"
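# Sanity check (sketch): an added synthetic form should now encode to a single ID.
if new_tokens:
    probe = new_tokens[0]
    print(f"[check] {probe!r} -> {tok.encode(probe, add_special_tokens=False)}")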
"""build dataset (FW + Synthetic)"""
assert tok.pad_token_id is not None
assert model.lm_head.weight.data_ptr() == model.get_input_embeddings().weight.data_ptr()
print("new tokens:", len(new_ids), "| vocab size:", len(tok))
"""# P1 Embeddings-only warm-up"""
!pip -q uninstall -y accelerate || true
!pip -q install --no-cache-dir "accelerate==1.2.1"
import os; os.kill(os.getpid(), 9)  # restart the runtime again so the accelerate pin takes effect
FW_TEXT = Path("/content/FW_TEXT.txt")
assert FW_TEXT.exists(), f"Not found: {FW_TEXT}"
fw_text = FW_TEXT.read_text(encoding="utf-8")
print(f"Loaded {len(fw_text)} chars from {FW_TEXT}")
from datasets import Dataset
# FW blocks
MAX_LEN = 384
def chunk_text(txt, max_chars=1200):
parts = []
buf = []
n = 0
for line in txt.splitlines():
if not line.strip(): continue
buf.append(line)
n += len(line)+1
if n >= max_chars:
parts.append("\n".join(buf)+"\n")
buf, n = [], 0
if buf: parts.append("\n".join(buf)+"\n")
return parts
fw_blocks = chunk_text(fw_text) if fw_text else []
mix = fw_blocks + synthetic_lines
random.shuffle(mix)
def to_ids(s):
return tok.encode(s, add_special_tokens=False)[:MAX_LEN]
enc = [{"input_ids": to_ids(s)} for s in mix if to_ids(s)]
# small split
split = int(0.9*len(enc))
train_ds = Dataset.from_list(enc[:split])
valid_ds = Dataset.from_list(enc[split:]) if split < len(enc) else Dataset.from_list(enc[:100])
def dc(features):
    # right-pad input_ids with the pad token; pad labels with -100 so the
    # cross-entropy loss ignores padded positions
    import torch
    maxlen = max(len(x["input_ids"]) for x in features)
    input_ids = []
    labels = []
    for f in features:
        ids = f["input_ids"]
        n_pad = maxlen - len(ids)
        input_ids.append(ids + [tok.pad_token_id]*n_pad)
        labels.append(ids + [-100]*n_pad)
    return {"input_ids": torch.tensor(input_ids), "labels": torch.tensor(labels)}
print("train blocks:", len(train_ds), "valid:", len(valid_ds))
import inspect, accelerate, transformers
from accelerate import Accelerator
print("Transformers:", transformers.__version__)
print("Accelerate :", accelerate.__version__, accelerate.__file__)
print("unwrap_model sig:", inspect.signature(Accelerator.unwrap_model))
import inspect
from accelerate import Accelerator
_orig_unwrap = Accelerator.unwrap_model
def _unwrap_model_compat(self, model, *args, **kwargs):
# Transformers>=4.56 may pass keep_torch_compile; older Accelerate doesn't accept it.
kwargs.pop("keep_torch_compile", None)
# Map/ensure keep_fp32_wrapper with sensible default
keep_fp32_wrapper = kwargs.pop("keep_fp32_wrapper", True)
# If caller passed it positionally, respect that
if args:
# first positional after model would be keep_fp32_wrapper
keep_fp32_wrapper = bool(args[0])
return _orig_unwrap(self, model, keep_fp32_wrapper)
Accelerator.unwrap_model = _unwrap_model_compat
print("Patched accelerate.Accelerator.unwrap_model for keep_torch_compile compatibility.")
print("new sig (logical): (self, model, keep_fp32_wrapper: bool = True)")
from transformers import Trainer, TrainingArguments
import numpy as np, torch, json
from sklearn.metrics.pairwise import cosine_similarity
from pathlib import Path
# freeze all except embeddings + head
for p in model.parameters(): p.requires_grad = False
E = model.get_input_embeddings().weight; E.requires_grad_(True)
for p in model.lm_head.parameters(): p.requires_grad = True
from transformers import TrainingArguments
args1 = TrainingArguments(
output_dir=str(RUN_DIR/"phase1"),
per_device_train_batch_size=1,
gradient_accumulation_steps=16,
max_steps=2000,
learning_rate=5e-4,
warmup_ratio=0.0,
logging_strategy="steps", logging_steps=50,
save_strategy="no",
eval_strategy="no",
gradient_checkpointing=True,
fp16=False,
report_to="none",
optim="adafactor",
)
trainer1 = Trainer(model=model, args=args1, train_dataset=train_ds, data_collator=dc)
out1 = trainer1.train()
print(out1)
# snapshot vs composed init
with torch.no_grad():
W1 = model.get_input_embeddings().weight.detach().cpu().numpy()
sim1 = cosine_similarity(W1[new_ids.cpu()], W1)
top5_1 = np.argsort(-sim1, axis=1)[:,1:6]
E_comp_np = np.load(METRICS_DIR/"E_comp_newtokens.npy")
sim0 = cosine_similarity(E_comp_np, W1)
top5_0 = np.argsort(-sim0, axis=1)[:,1:6]
def overlap5(a,b): return len(set(a.tolist()) & set(b.tolist()))
overlaps1 = np.array([overlap5(top5_0[i], top5_1[i]) for i in range(len(new_ids))])
pre_norms = np.linalg.norm(W1[new_ids.cpu()], axis=1)
(Path(METRICS_DIR/"summary_stats_p1.json")).write_text(
json.dumps({"phase":"phase1","compared_tokens":int(len(new_ids)),
"mean_top5_overlap":float(overlaps1.mean()),"mean_norm_delta":0.0}, indent=2)
)
(Path(METRICS_DIR/"morpheme_comparison_p1.json")).write_text(
json.dumps({"top5_pre":top5_0.tolist(),"top5_p1":top5_1.tolist(),
"overlap@5":overlaps1.tolist()}, indent=2)
)
np.save(METRICS_DIR/"pre_norms.npy", pre_norms)
print("P1 mean overlap@5:", overlaps1.mean())
import os, json, time, pathlib, numpy as np, torch
from datasets import load_from_disk
from transformers import (
AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer,
DataCollatorForLanguageModeling, TrainerCallback
)
# set up
DRIVE_ROOT = pathlib.Path("/content/drive/MyDrive/wake2vec")
RUN_ID = f"t4_{int(time.time())}"
(DRIVE_ROOT/"runs"/RUN_ID).mkdir(parents=True, exist_ok=True)
RUN_DIR = pathlib.Path("/content/runs")/RUN_ID
if not RUN_DIR.exists():
RUN_DIR.parent.mkdir(parents=True, exist_ok=True)
os.symlink(str(DRIVE_ROOT/"runs"/RUN_ID), str(RUN_DIR))
METRICS_DIR = RUN_DIR/"metrics"; METRICS_DIR.mkdir(parents=True, exist_ok=True)
# callbacks
class LossStreamer(TrainerCallback):
def __init__(self, log_every=50, window=200, out_json=None):
self.log_every, self.window, self.out_json = log_every, window, out_json
self.buf, self.recs = [], []
def on_log(self, args, state, control, logs=None, **kw):
if not logs or "loss" not in logs: return
s = int(state.global_step or 0); L = float(logs["loss"])
self.buf.append(L); self.recs.append({"step": s, "loss": L, "lr": logs.get("learning_rate")})
if s % self.log_every == 0 and s > 0:
w = self.buf[-self.window:]; ma = sum(w)/len(w)
print(f"[P1 {s}] train_loss={L:.4f} ma({len(w)})={ma:.4f}")
if self.out_json: open(self.out_json,"w").write(json.dumps(self.recs, indent=2))
loss_cb = LossStreamer(out_json=str(METRICS_DIR/"phase1_loss_log.json"))
# Snapshot new rows every 200 steps
NEW_IDS_PATH = DRIVE_ROOT/"new_ids.npy"
new_ids = np.load(NEW_IDS_PATH) if NEW_IDS_PATH.exists() else None
class EmbedSnapshot(TrainerCallback):
def __init__(self, run_dir, new_ids, every=200):
self.run_dir, self.new_ids, self.every = pathlib.Path(run_dir), new_ids, every
def on_step_end(self, args, state, control, **kw):
if self.new_ids is None: return
s = int(state.global_step or 0)
if s>0 and s % self.every == 0:
            m = kw.get("model")
if m is None: return
with torch.no_grad():
E = m.get_input_embeddings().weight.detach().cpu().numpy()
np.save(self.run_dir/"metrics"/f"E_postP1_step{s}.npy", E[self.new_ids])
print(f"[SNAP] new-row embeddings @ {s}")
snap_cb = EmbedSnapshot(RUN_DIR, new_ids, every=200)
# model / tokenizer
BASE = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
tok = AutoTokenizer.from_pretrained(BASE, use_fast=True)
if tok.pad_token_id is None: tok.pad_token = tok.eos_token or "</s>"
model = AutoModelForCausalLM.from_pretrained(BASE, torch_dtype=torch.float32, device_map="auto")
model.config.use_cache = False
with torch.no_grad():
model.get_output_embeddings().weight = model.get_input_embeddings().weight
train_ds = load_from_disk(str(DRIVE_ROOT/"datasets"/"train_ds"))
valid_ds = load_from_disk(str(DRIVE_ROOT/"datasets"/"valid_ds"))
# use a tiny shard for quick evals
valid_ds_small = valid_ds.select(range(min(1000, len(valid_ds))))
collator = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
# training args
args = TrainingArguments(
output_dir=str(RUN_DIR),
seed=42,
per_device_train_batch_size=1,
gradient_accumulation_steps=16,
max_steps=1100,
learning_rate=5e-4,
warmup_ratio=0.0,
optim="adafactor",
logging_steps=50,
save_steps=100,
save_total_limit=12,
    eval_strategy="steps",
eval_steps=200,
gradient_checkpointing=True,
fp16=False, bf16=False,
report_to=["none"],
max_grad_norm=1.0,
)
trainer = Trainer(
model=model,
args=args,
train_dataset=train_ds,
eval_dataset=valid_ds_small,
data_collator=collator,
callbacks=[loss_cb, snap_cb],
)
open(RUN_DIR/"run_manifest.json","w").write(json.dumps({
"run_id": RUN_ID, "base": BASE, "started": time.time(),
"max_steps": 1100, "grad_accum": 16, "optim": "adafactor",
"eval_strategy": "steps", "eval_steps": 200, "valid_shard": len(valid_ds_small)
}, indent=2))
trainer.train()
trainer.save_model(str(RUN_DIR/"checkpoint-final"))
tok.save_pretrained(str(RUN_DIR/"checkpoint-final"))
print("[DONE] P1 complete →", RUN_DIR/"checkpoint-final")
import os, glob, json, pathlib, time, re
from google.colab import drive
# Mount Drive (idempotent)
try:
drive.mount('/content/drive')
except Exception:
pass
from pathlib import Path
PERSIST = Path("/content/drive/MyDrive/wake2vec")
PERSIST.mkdir(parents=True, exist_ok=True)
# Prefer last local run, else last Drive run
local_runs = sorted(Path("/content/runs").glob("*"), key=lambda p: p.stat().st_mtime) if Path("/content/runs").exists() else []
drive_runs = sorted((PERSIST / "runs").glob("*"), key=lambda p: p.stat().st_mtime) if (PERSIST / "runs").exists() else []
if local_runs:
RUN_DIR = local_runs[-1]
RUN_ID = RUN_DIR.name
SOURCE = "local"
else:
if drive_runs:
RUN_DIR = drive_runs[-1]
RUN_ID = RUN_DIR.name
SOURCE = "drive"
else:
raise SystemExit("No prior run dirs found in /content/runs or Drive /wake2vec/runs.")
print(f"[INFO] Using RUN_ID={RUN_ID} from {SOURCE}: {RUN_DIR}")
(PERSIST / "runs" / RUN_ID).mkdir(parents=True, exist_ok=True)
if SOURCE == "local":
os.system(f'rsync -a --ignore-existing "{RUN_DIR}/" "{PERSIST}/runs/{RUN_ID}/"')
else:
    # symlink into /content for fast local IO while the files persist on Drive
Path("/content/runs").mkdir(parents=True, exist_ok=True)
target = Path("/content/runs") / RUN_ID
if not target.exists():
os.symlink(str(RUN_DIR), str(target))
RUN_DIR = target
print(f"[INFO] Symlinked drive run → {RUN_DIR}")
# Create common dirs
METRICS_DIR = RUN_DIR / "metrics"
METRICS_DIR.mkdir(parents=True, exist_ok=True)
# latest checkpoint dir
ckpts = sorted(glob.glob(str(RUN_DIR / "checkpoint-*")), key=lambda p: int(re.findall(r"checkpoint-(\d+)", p)[0]) if re.findall(r"checkpoint-(\d+)", p) else -1)
CKPT = ckpts[-1] if ckpts else None
print(f"[INFO] Latest checkpoint: {CKPT if CKPT else 'NONE'}")
manifest = {
"run_id": RUN_ID,
"resumed_at": time.time(),
"source": SOURCE,
"latest_ckpt": CKPT,
}
(RUN_DIR / "resume_manifest.json").write_text(json.dumps(manifest, indent=2))
print("[OK] resume_manifest.json written.")
# P1 FINALIZE: compute overlap@5 (composed-init vs post-P1), norm stats, plots
import os, json, math, glob, pathlib, time, shutil
import numpy as np
import torch
import matplotlib.pyplot as plt
from transformers import AutoModelForCausalLM, AutoTokenizer
# run meta
ROOT = pathlib.Path("/content")
RUNS = ROOT / "runs"
RUN_ID = sorted([p.name for p in RUNS.glob("*")])[-1] # last run
RUN_DIR = RUNS / RUN_ID
METRICS_DIR = RUN_DIR / "metrics"
PLOTS_DIR = RUN_DIR / "plots"
PERSIST = pathlib.Path("/content/drive/MyDrive/wake2vec")
ADAPT_DIR = RUN_DIR / "phase2_lora" / "final_adapters"
for d in (METRICS_DIR, PLOTS_DIR, ADAPT_DIR, PERSIST): d.mkdir(parents=True, exist_ok=True)
manifest_path = RUN_DIR / "run_manifest.json"
if manifest_path.exists():
manifest = json.loads(manifest_path.read_text())
else:
manifest = {"run_id": RUN_ID, "time": time.time()}
manifest_path.write_text(json.dumps(manifest, indent=2))
# Locate the latest saved checkpoint ('checkpoint-<step>'); fall back to the run dir itself
ckpts = sorted(glob.glob(str(RUN_DIR / "checkpoint-*")), key=lambda p: int(p.rsplit("-",1)[-1]))
ckpt_dir = ckpts[-1] if ckpts else str(RUN_DIR)
base_model_name = "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T"
print(f"[INFO] Loading model/tokenizer from: {ckpt_dir}")
tok = AutoTokenizer.from_pretrained(ckpt_dir, use_fast=True)
if tok.pad_token_id is None:
tok.pad_token = tok.eos_token or "</s>"
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype=torch.float32, device_map="auto")
model.config.use_cache = False
with torch.no_grad():
model.get_output_embeddings().weight = model.get_input_embeddings().weight
emb = model.get_input_embeddings().weight.detach().cpu().numpy() # [V, d]
V, d = emb.shape
print(f"[INFO] vocab={V}, dim={d}")
# Load new_ids and composed init vectors (optional)
new_ids_path = PERSIST / "new_ids.npy"
E_comp_path = PERSIST / "E_comp.npy" # composed init vectors aligned to new_ids
new_ids = np.load(new_ids_path) if new_ids_path.exists() else None
E_comp = np.load(E_comp_path) if E_comp_path.exists() else None
if new_ids is not None:
print(f"[INFO] new_ids loaded: {new_ids.shape[0]}")
else:
print("[WARN] new_ids.npy not found; will compute only global stats/plots.")
# top-k neighbors by cosine
def topk_neighbors(vecs, mat, k=5, mask_self=None):
# vecs: [m, d], mat: [V, d]
# returns indices [m, k]
va = vecs / (np.linalg.norm(vecs, axis=1, keepdims=True) + 1e-9)
ma = mat / (np.linalg.norm(mat, axis=1, keepdims=True) + 1e-9)
sims = va @ ma.T # [m, V]
if mask_self is not None:
sims[np.arange(sims.shape[0])[:,None], mask_self[:,None]] = -1e9
    nbrs = np.argpartition(-sims, kth=k-1, axis=1)[:, :k]
# sort each row’s top-k
part_vals = np.take_along_axis(sims, nbrs, axis=1)
order = np.argsort(-part_vals, axis=1)
return np.take_along_axis(nbrs, order, axis=1)
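# Usage (as below): topk_neighbors(E_comp, emb, k=5) returns, per composed
# vector, the ids of its five nearest vocab rows by cosine similarity.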
# Overlap@5: composed-init vs post-P1
overlap5 = None
if (new_ids is not None) and (E_comp is not None) and (E_comp.shape[0] == new_ids.shape[0]):
# neighbors of composed vectors vs neighbors of current learned new embeddings
E_new_post = emb[new_ids] # [n_new, d] current rows after P1
k = 5
nbr_comp = topk_neighbors(E_comp, emb, k=k) # composed vec, top-k in current vocab
nbr_post = topk_neighbors(E_new_post, emb, k=k) # learned row, top-k in current vocab
# overlap fraction per token
ov = []
for a, b in zip(nbr_comp, nbr_post):
ov.append(len(set(a.tolist()) & set(b.tolist())) / k)
overlap5 = np.array(ov)
np.save(METRICS_DIR / "p1_overlap_at5.npy", overlap5)
print(f"[OK] P1 overlap@5: mean={overlap5.mean():.4f} | median={np.median(overlap5):.4f}")
else:
print("[WARN] Skipping overlap@5 (missing E_comp or new_ids).")
# Norm stats
norms = np.linalg.norm(emb, axis=1)
np.save(METRICS_DIR / "postP1_norms.npy", norms)
if new_ids is not None:
norms_new = norms[new_ids]
np.save(METRICS_DIR / "postP1_norms_new.npy", norms_new)
# Plots
plt.figure(figsize=(7,4.5))
plt.hist(norms, bins=60)
plt.title("Post-P1 embedding norms (all vocab)")
plt.xlabel("‖E‖"); plt.ylabel("count")
plt.tight_layout()
plt.savefig(PLOTS_DIR / "hist_postP1_norms.png", dpi=160)
if (overlap5 is not None):
plt.figure(figsize=(7,4.5))
plt.hist(overlap5, bins=np.linspace(0,1,11), align="left", rwidth=0.9)
plt.xticks(np.linspace(0,1,11))
plt.title("P1 overlap@5 (composed-init vs post-P1)")
plt.xlabel("overlap@5"); plt.ylabel("count")
plt.tight_layout()
plt.savefig(PLOTS_DIR / "hist_overlap_top5_P1.png", dpi=160)
# Persist quick stats
stats = {
"run_id": RUN_ID,
"vocab": int(V),
"dim": int(d),
"n_new": int(new_ids.shape[0]) if new_ids is not None else None,
"p1_overlap5_mean": float(overlap5.mean()) if overlap5 is not None else None,
"p1_overlap5_median": float(np.median(overlap5)) if overlap5 is not None else None,
"timestamp": time.time(),
}
json.dump(stats, open(METRICS_DIR / "p1_stats.json","w"), indent=2)
print("[DONE] P1 metrics saved →", METRICS_DIR)
import shutil, tarfile
def snapshot_to_drive(tag):
# copy artifacts
d_run = PERSIST/'runs'/RUN_ID/tag
d_run.mkdir(parents=True, exist_ok=True)
if METRICS_DIR.exists(): shutil.copytree(METRICS_DIR, d_run/'metrics', dirs_exist_ok=True)
if PLOTS_DIR.exists(): shutil.copytree(PLOTS_DIR, d_run/'plots', dirs_exist_ok=True)
if (REPORTS_DIR/'Wake2Vec_Report.html').exists():
(PERSIST/'reports').mkdir(parents=True, exist_ok=True)
shutil.copy(REPORTS_DIR/'Wake2Vec_Report.html', PERSIST/'reports'/f'Wake2Vec_Report_{RUN_ID}_{tag}.html')
# adapters/tokenizer
    src_ad = ADAPT_DIR  # local adapters written by Phase 2
    if src_ad.exists():
        shutil.copytree(src_ad, PERSIST/'adapters'/RUN_ID/'final_adapters', dirs_exist_ok=True)
# tarball of the local run folder for belt+braces
(PERSIST/'archives').mkdir(parents=True, exist_ok=True)
with tarfile.open(PERSIST/'archives'/f"{RUN_ID}_{tag}.tar.gz", "w:gz") as tar:
tar.add(str(RUN_DIR), arcname=f"runs/{RUN_ID}")
print(f"[snapshot] saved → Drive under runs/{RUN_ID}/{tag} and archives/")
snapshot_to_drive("phase1")
snapshot_to_drive("phase2")
snapshot_to_drive("phase3")
"""# P2 LoRA boost"""
# p2 lora
import os, json, pathlib, time
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, DataCollatorForLanguageModeling, TrainerCallback
from datasets import load_from_disk
RUNS = pathlib.Path("/content/runs")
RUN_ID = sorted([p.name for p in RUNS.glob("*")])[-1]
RUN_DIR = RUNS / RUN_ID
METRICS_DIR = RUN_DIR / "metrics"
ADAPT_DIR = RUN_DIR / "phase2_lora" / "final_adapters"
PERSIST = pathlib.Path("/content/drive/MyDrive/wake2vec")
for d in (METRICS_DIR, ADAPT_DIR, PERSIST): d.mkdir(parents=True, exist_ok=True)
# loss
class LossStreamer(TrainerCallback):
def __init__(self, log_every=50, window=200, out_json=None):
self.log_every = log_every
self.window = window
self.buf = []
self.recs = []
self.out_json = out_json
def on_log(self, args, state, control, logs=None, **kwargs):
if not logs or "loss" not in logs:
return
step = int(state.global_step or 0)
loss = float(logs["loss"])
self.buf.append(loss)
self.recs.append({"step": step, "loss": loss, "lr": logs.get("learning_rate", None)})
if step % self.log_every == 0 and step > 0:
w = self.buf[-self.window:]
ma = sum(w)/len(w)
lr = logs.get("learning_rate", None)
print(f"[P2 step {step}] loss={loss:.4f} ma({len(w)})={ma:.4f}" + (f" lr={lr:.2e}" if lr else ""))
if self.out_json:
with open(self.out_json, "w") as f: json.dump(self.recs, f, indent=2)
loss_cb2 = LossStreamer(log_every=50, window=200, out_json=str(METRICS_DIR/"phase2_loss_log.json"))
# imports
def try_import_peft():
try:
import peft
return peft
except Exception:
return None
peft = try_import_peft()
if peft is None:
print("[WARN] peft not available — will train base model (no adapters) and use Adafactor.")
# checkpoint for p2 start
ckpts = sorted([p for p in RUN_DIR.glob("checkpoint-*")], key=lambda p: int(p.name.split("-")[-1]))
ckpt_dir = str(ckpts[-1] if ckpts else RUN_DIR)
print(f"[INFO] P2 loading from: {ckpt_dir}")
tok = AutoTokenizer.from_pretrained(ckpt_dir, use_fast=True)
if tok.pad_token_id is None:
tok.pad_token = tok.eos_token or "</s>"
model = AutoModelForCausalLM.from_pretrained(ckpt_dir, torch_dtype=torch.float32, device_map="auto")
model.config.use_cache = False
with torch.no_grad():
model.get_output_embeddings().weight = model.get_input_embeddings().weight # tie
# datasets
train_path = RUN_DIR / "train_ds"
valid_path = RUN_DIR / "valid_ds"
train_ds = load_from_disk(str(train_path)) if train_path.exists() else None
valid_ds = load_from_disk(str(valid_path)) if valid_path.exists() else None
assert (train_ds is not None) and (valid_ds is not None), "Tokenized datasets not found in RUN_DIR."
dc = DataCollatorForLanguageModeling(tokenizer=tok, mlm=False)
# lora
use_optim = "adamw_bnb_8bit"
if peft is not None:
from peft import LoraConfig, get_peft_model, TaskType
lcfg = LoraConfig(
task_type=TaskType.CAUSAL_LM,
r=8, lora_alpha=16, lora_dropout=0.05,
target_modules=["q_proj","k_proj","v_proj","o_proj","gate_proj","up_proj","down_proj"],
bias="none"
)
model = get_peft_model(model, lcfg)
try:
import bitsandbytes as bnb # noqa: F401
except Exception:
print("[WARN] bitsandbytes missing — falling back to Adafactor.")
use_optim = "adafactor"
else:
use_optim = "adafactor"
# training args
args2 = TrainingArguments(
output_dir=str(RUN_DIR/"phase2_lora"),
per_device_train_batch_size=1,
gradient_accumulation_steps=16,
num_train_epochs=1,
learning_rate=2e-5,
warmup_ratio=0.10,
logging_strategy="steps",
logging_steps=50,
save_strategy="epoch",
    eval_strategy="epoch",
gradient_checkpointing=True,
fp16=False, bf16=False,
report_to=["none"],
optim=use_optim,
max_grad_norm=1.0,
)
trainer2 = Trainer(
model=model,
args=args2,
train_dataset=train_ds,
eval_dataset=valid_ds,
data_collator=dc,
callbacks=[loss_cb2]
)
out2 = trainer2.train()
print(out2)
# save
if peft is not None and hasattr(model, "save_pretrained"):
model.save_pretrained(str(ADAPT_DIR), safe_serialization=True)
else:
(ADAPT_DIR / "full_model").mkdir(parents=True, exist_ok=True)
model.save_pretrained(str(ADAPT_DIR / "full_model"), safe_serialization=True)
tok.save_pretrained(str(ADAPT_DIR))
drive_target = PERSIST / "adapters" / RUN_ID / "final_adapters"
drive_target.mkdir(parents=True, exist_ok=True)
os.system(f'cp -r "{ADAPT_DIR}"/* "{drive_target}"')
print("[DONE] Phase-2 saved →", ADAPT_DIR, "and mirrored →", drive_target)
"""# P3 embedding alignment (new rows only, LM CE + anchors)
"""
from sklearn.metrics.pairwise import cosine_similarity
import torch.nn as nn
with torch.no_grad():
W2 = model.get_input_embeddings().weight.detach().cpu().numpy()
sim_pre = cosine_similarity(W2[new_ids.cpu()], W2)
top5_pre = np.argsort(-sim_pre, axis=1)[:,1:6]
# freeze all but embeddings+head
for p in model.parameters(): p.requires_grad = False
E = model.get_input_embeddings().weight; E.requires_grad_(True)
for p in model.lm_head.parameters(): p.requires_grad = True
# targets (built on E's device, so E must exist first)
centroids = torch.tensor(W2[top5_pre].mean(axis=1), dtype=torch.float32, device=E.device)
pre_norms = torch.tensor(np.load(METRICS_DIR/"pre_norms.npy")[:len(new_ids)], dtype=torch.float32, device=E.device)
E_comp = torch.tensor(np.load(METRICS_DIR/"E_comp_newtokens.npy")[:len(new_ids)], dtype=torch.float32, device=E.device)
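# Auxiliary loss weights: keep new rows anchored to their composed init (ANCHOR),
# pull them toward their current neighbor centroids (CENTROID), and hold their
# norms near the values recorded after Phase 1 (NORM).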
LMB_ANCHOR, LMB_CENTROID, LMB_NORM = 1e-3, 1e-3, 5e-4
id_to_row = {int(t.item()): r for r,t in enumerate(new_ids.cpu())}
def batch_rows(input_ids):
ids = torch.unique(input_ids).tolist()
rows = [id_to_row[i] for i in ids if i in id_to_row]
return torch.tensor(rows, dtype=torch.long, device=E.device) if rows else None
from transformers import Trainer
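# Trainer that adds the three regularisers above to the LM cross-entropy,
# applied only to new-token rows that actually appear in the current batch.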
class MorphAlignTrainer(Trainer):
def compute_loss(self, model, inputs, return_outputs=False, num_items_in_batch=None):
outputs = model(**inputs); loss = outputs.loss
rows = batch_rows(inputs["input_ids"])
if rows is not None:
E_rows = E[new_ids[rows].to(E.device)]
l_cent = (1 - nn.functional.cosine_similarity(E_rows, centroids[rows], dim=1)).mean()
l_norm = nn.functional.mse_loss(E_rows.norm(dim=1), pre_norms[rows])
l_anch = nn.functional.mse_loss(E_rows, E_comp[rows])
loss = loss + LMB_CENTROID*l_cent + LMB_NORM*l_norm + LMB_ANCHOR*l_anch
return (loss, outputs) if return_outputs else loss
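# Variant that zeroes embedding gradients everywhere except the new rows,
# so only the added tokens move during this alignment phase.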
class NewRowMaskTrainer(MorphAlignTrainer):
def training_step(self, model, inputs, num_items_in_batch=None):
out = super().training_step(model, inputs, num_items_in_batch)
if E.grad is not None:
mask = torch.zeros_like(E.grad, dtype=torch.bool)
mask[new_ids.to(E.device)] = True
E.grad = torch.where(mask, E.grad, torch.zeros_like(E.grad))
return out
from transformers import TrainerCallback
import torch, math, json, numpy as np
class P3LiveLogger(TrainerCallback):
def __init__(self, new_ids, E, centroids, pre_norms, E_comp, log_every=50, out_json=METRICS_DIR/"phase3_live_log.json"):
self.new_ids = new_ids
self.E = E
self.centroids = centroids
self.pre_norms = pre_norms
self.E_comp = E_comp
self.log_every = log_every
self.out_json = out_json
self.records = []
@torch.no_grad()
def _metrics_snapshot(self, step):
rows = self.new_ids.to(self.E.device)
E_rows = self.E[rows]
# terms
cos_cent = torch.nn.functional.cosine_similarity(E_rows, self.centroids, dim=1)
l_centroid = (1 - cos_cent).mean().item()
l_norm = torch.nn.functional.mse_loss(E_rows.norm(dim=1), self.pre_norms).item()
l_anchor = torch.nn.functional.mse_loss(E_rows, self.E_comp).item()
# health stats
mean_norm = E_rows.norm(dim=1).mean().item()
grad_norm = (self.E.grad[rows].norm().item() if self.E.grad is not None else float("nan"))
        # light overlap probe: mean cosine with the centroids, a cheap proxy
        # for full top-k neighbor overlap
mean_cos_cent = cos_cent.mean().item()
self.records.append({
"step": int(step),
"l_centroid": l_centroid,
"l_norm": l_norm,
"l_anchor": l_anchor,
"mean_norm": mean_norm,
"grad_norm": grad_norm,
"mean_cos_centroid": mean_cos_cent,
})
def on_step_end(self, args, state, control, **kwargs):
if state.global_step % self.log_every == 0 and state.global_step > 0:
self._metrics_snapshot(state.global_step)
return control