-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathreconstruct_auto_interp.py
More file actions
69 lines (53 loc) · 2.43 KB
/
reconstruct_auto_interp.py
File metadata and controls
69 lines (53 loc) · 2.43 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# Fix filelock issue: replace hard (OS-level) locks with soft lock-file based
# locks BEFORE `datasets` is imported, since it caches via filelock and hard
# locks can fail on networked filesystems.
# NOTE(review): assumes no other module grabbed filelock.FileLock earlier — confirm.
import filelock
filelock.FileLock = filelock.SoftFileLock

import os
import json
from pathlib import Path
from concurrent.futures import ProcessPoolExecutor, as_completed

from datasets import load_dataset

# Before running the file, run load_auto_interp.py
# this is annoyingly slow to reconstruct all files. The reason is that we are
# creating one file per feature, so around 200k files.
# should not be on a login node
# I am currently moving the auto-interp to using SQLite

# Change to working directory so shard files and output dirs resolve locally
os.chdir("/fast/fdraye/data/featflow/data")

# Find all parquet shard files. sorted() consumes the glob generator directly;
# the original wrapped it in a redundant list comprehension.
shard_files = sorted(Path(".").glob("shards_layer*_single.parquet"))
print(f"Found {len(shard_files)} parquet shard files: {[f.name for f in shard_files]}")
# Worker function: materialize one dataset record as a standalone JSON file.
def write_record(record, output_dir):
    """Write *record* to ``output_dir / record['filename']`` as pretty JSON.

    The record itself carries its target filename. Returns 1 so the caller
    can tally completed writes by summing results.
    """
    target = output_dir / record['filename']
    payload = json.dumps(record, ensure_ascii=False, indent=2)
    target.write_text(payload, encoding="utf-8")
    return 1
# Process each shard file: for every Parquet shard, explode its records into
# one JSON file per record under a per-layer directory.
for shard_file in shard_files:
    # Extract layer name from filename (e.g., "shards_layer0_single.parquet" -> "layer0")
    layer_name = shard_file.stem.replace("shards_", "").replace("_single", "")
    # Output folder for reconstructed JSONs
    recon_dir = Path(f"{layer_name}")
    # Skip if already exists and not empty — crude resume support.
    # NOTE(review): a partially-written directory is also skipped; confirm that
    # an interrupted run should not be re-done.
    if recon_dir.exists() and any(recon_dir.iterdir()):
        print(f" ✓ {recon_dir} already exists and has files, skipping...")
        continue
    recon_dir.mkdir(exist_ok=True)
    print(f"\nProcessing {shard_file} -> {recon_dir}")
    # Load all records from the single Parquet shard
    dataset = load_dataset(
        "parquet",
        data_files=str(shard_file),
        cache_dir="/fast/fdraye/data/featflow/data"
    )["train"]
    print(f" Loaded {len(dataset)} records")
    # Parallel writing: one process-pool task per record (~200k tiny files —
    # see the note at the top of the file about why this is slow).
    NUM_WORKERS = 32 # adjust for your system
    completed = 0
    with ProcessPoolExecutor(max_workers=NUM_WORKERS) as executor:
        # Submit everything up front, then count completions as they finish.
        futures = [executor.submit(write_record, record, recon_dir) for record in dataset]
        for f in as_completed(futures):
            completed += f.result()  # write_record returns 1 per file written
            if completed % 1000 == 0:
                print(f" Written {completed}/{len(dataset)} records")
    print(f" ✓ Reconstructed {completed} JSON files in {recon_dir}")
print(f"\n🎉 Done! Processed all {len(shard_files)} shard files.")