Skip to content

Commit 2895eeb

Browse files
committed
Added training pickle builder
1 parent 6053e23 commit 2895eeb

2 files changed

Lines changed: 206 additions & 0 deletions

File tree

training/build_training_pickles.py

Lines changed: 203 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,203 @@
1+
"""
2+
Helper script to convert a condensed spectra table (e.g. full_dataset.xlsx) and
3+
its metadata into CandyCrunch-ready training pickles. Adjust the paths in the
4+
configuration block before running.
5+
"""
6+
7+
import ast
8+
import pickle
9+
from pathlib import Path
10+
11+
import numpy as np
12+
import pandas as pd
13+
from sklearn.model_selection import GroupShuffleSplit
14+
15+
from candycrunch.prediction import bin_intensities
16+
17+
# --- Configuration ----------------------------------------------------------
# Edit these before running; nothing below reads any other environment state.
full_dataset_path = Path("full_dataset.xlsx")  # condensed spectra table (Excel)
metadata_path = Path("training/file_checklist_template.csv")  # per-file acquisition checklist
# Optional previously-prepared datasets to merge into the new training pool.
legacy_train_path = None  # e.g. Path("prepared_datasets/train_legacy.pkl")
legacy_test_path = None  # e.g. Path("prepared_datasets/test_legacy.pkl")
output_dir = Path("prepared_datasets")  # all pickles are written here
test_size = 0.15  # fraction of filename groups held out for the test split
random_state = 42  # seed for the group shuffle split (reproducibility)
24+
25+
# Integer encodings for the acquisition metadata attached in attach_metadata().
# Metadata values are lower-cased before lookup, so every key here must be
# lower-case; unmatched/missing values get a fallback code at lookup time.
MODE_MAP = {"negative": 0, "positive": 1}
# BUG FIX: keys were "PGC"/"C18" (upper case) while lookups lower-case the
# metadata value first, so every LC lookup silently fell through to the
# fallback code.  Keys are now lower-case so matches can actually occur.
LC_MAP = {"pgc": 0, "c18": 1}
MOD_MAP = {"reduced": 0, "permethylated": 1}
TRAP_MAP = {"linear": 0, "orbitrap": 1, "amazon": 2}

# Column order of the per-spectrum feature tuples serialised for CandyCrunch.
FEATURE_COLUMNS = [
    "binned_intensities",
    "mz_remainder",
    "reducing_mass",
    "glycan_type",
    "RT",
    "mode",
    "lc",
    "modification",
    "trap",
]
41+
42+
43+
44+
def safe_peak_parse(value):
    """Parse a peak dictionary into ``{float(mz): float(intensity)}``.

    Accepts either an actual dict or its string repr (as stored in the
    spreadsheet).  Returns ``None`` for anything malformed: empty strings,
    non-dict literals, unparsable strings, or entries whose keys/values
    cannot be converted to float.

    BUG FIX: the original raised ``ValueError``/``TypeError`` when a parsed
    dict contained non-numeric keys or values, despite being the "safe"
    parser whose ``None`` results are filtered out downstream.  Conversion
    failures now also yield ``None``.
    """
    def _to_float_dict(candidate):
        # Convert keys and values to float, rejecting the whole spectrum
        # if any entry is not numeric.
        try:
            return {float(k): float(v) for k, v in candidate.items()}
        except (TypeError, ValueError):
            return None

    if isinstance(value, dict):
        return _to_float_dict(value)
    if isinstance(value, str):
        stripped = value.strip()
        if not stripped:
            return None
        try:
            parsed = ast.literal_eval(stripped)
        except (SyntaxError, ValueError):
            return None
        if isinstance(parsed, dict):
            return _to_float_dict(parsed)
    # Any other input type (NaN floats from pandas, lists, ...) is rejected.
    return None
58+
59+
def normalise_spectrum(spec):
    """Rescale a ``{mz: intensity}`` spectrum so intensities sum to 1.

    Returns ``None`` when the total intensity is not positive (empty or
    all-zero spectra cannot be normalised).
    """
    mass = float(sum(spec.values()))
    if mass > 0:
        return {peak: height / mass for peak, height in spec.items()}
    return None
64+
65+
def process_peaks(df, min_mz=39.714, max_mz=3000.0, bin_num=2048):
    """Parse, normalise and bin every spectrum in ``df["peak_d"]``.

    Rows whose spectra fail parsing or normalisation are dropped; surviving
    rows gain float32 ``binned_intensities`` and ``mz_remainder`` columns
    produced by ``candycrunch.prediction.bin_intensities`` over a linear
    m/z grid of ``bin_num`` bins between ``min_mz`` and ``max_mz``.
    Returns a filtered copy; the input frame is not modified.
    """
    frames = np.linspace(min_mz, max_mz, bin_num)
    parsed = df["peak_d"].map(safe_peak_parse)
    parsed = parsed.map(lambda spec: None if spec is None else normalise_spectrum(spec))
    keep_mask = parsed.notnull()
    df = df.loc[keep_mask].copy()
    parsed = parsed.loc[keep_mask]
    # BUG FIX: zip(*(...)) raises "not enough values to unpack" on an empty
    # sequence, so short-circuit when no spectrum survived parsing.
    if parsed.empty:
        df["binned_intensities"] = []
        df["mz_remainder"] = []
        return df
    binned, remainders = zip(*(bin_intensities(spec, frames) for spec in parsed))
    df["binned_intensities"] = [np.asarray(vec, dtype=np.float32) for vec in binned]
    df["mz_remainder"] = [np.asarray(vec, dtype=np.float32) for vec in remainders]
    return df
76+
77+
def process_retention_times(df, default_rt=15, min_rt=2, min_span=30.0):
    """Clean and normalise the ``RT`` (retention time) column.

    Missing RTs are filled with ``default_rt``, rows eluting at or before
    ``min_rt`` minutes are dropped (presumably void-volume noise -- TODO
    confirm), and RTs are scaled per source file to roughly [0, 1] by
    dividing by the file's maximum RT, floored at ``min_span`` so very
    short runs are not stretched to full scale.

    The previously hard-coded constants (15, 2, 30.0) are now keyword
    parameters with the same defaults, so existing callers are unaffected.

    Parameters
    ----------
    df : DataFrame with ``RT`` and ``filename`` columns.
    default_rt : fill value for missing retention times.
    min_rt : exclusive lower bound; rows with ``RT <= min_rt`` are removed.
    min_span : minimum denominator for per-file normalisation.

    Returns a filtered copy; the input frame is not modified.
    """
    df = df.copy()
    df["RT"] = df["RT"].fillna(default_rt)
    df = df[df["RT"] > min_rt]
    df["RT"] = df.groupby("filename")["RT"].transform(
        lambda rt: rt / max(rt.max(), min_span)
    )
    return df
83+
84+
def infer_glycan_types(df):
    """Add an integer ``glycan_type`` column derived from each glycan string.

    Codes: 0 for sequences ending in the first suffix group, 1 for those
    containing the chitobiose core ``GlcNAc(b1-4)GlcNAc``, 2 for sequences
    ending in the second suffix group, and 3 otherwise.  Checks are applied
    in that order, so the ending test wins over the core test.
    Returns an augmented copy; the input frame is not modified.
    """
    group0_endings = ("GalNAc", "GalNAc6S", "GalNAcOS", "Fuc", "Man", "Gal")
    group2_endings = ("Glc", "GlcOS", "GlcNAc", "Ins")
    core_motif = "GlcNAc(b1-4)GlcNAc"

    def classify(sequence):
        if sequence.endswith(group0_endings):
            return 0
        if core_motif in sequence:
            return 1
        if sequence.endswith(group2_endings):
            return 2
        return 3

    out = df.copy()
    out["glycan_type"] = out["glycan"].map(classify)
    return out
96+
97+
def attach_metadata(df, checklist):
    """Attach per-file acquisition metadata as integer-coded columns.

    Joins ``df`` to ``checklist`` on the GlycoPOST identifier and adds
    ``mode``, ``lc``, ``modification`` and ``trap`` columns encoded via the
    module-level *_MAP dictionaries.  ``checklist`` must have a
    "GlycoPOST_ID" column; the metadata columns themselves are optional --
    a missing column simply yields the fallback code for every row.
    Returns an augmented copy of ``df``; the input frame is not modified.
    """
    meta = checklist.copy()
    # Header cells may carry stray whitespace from the spreadsheet export.
    meta.columns = [col.strip() for col in meta.columns]
    meta = meta.set_index("GlycoPOST_ID")
    # IDs are matched case-insensitively after stripping whitespace.
    meta.index = meta.index.astype(str).str.strip().str.lower()

    def build_dict(column):
        # {lower-cased ID -> lower-cased metadata value}; empty dict when
        # the checklist lacks this column (every lookup then falls back).
        if column not in meta.columns:
            return {}
        series = meta[column].fillna("").astype(str).str.lower().str.strip()
        return series.to_dict()

    mode_dict = build_dict("mode")
    lc_dict = build_dict("LC_type")
    mod_dict = build_dict("modification")
    trap_dict = build_dict("trap")

    def lookup(mapping, value, fallback):
        # Empty string means "no metadata recorded" -> fallback code.
        # NOTE: the value is lower-cased before the get(), so ``mapping``
        # keys must be lower-case for a match ever to succeed.
        if not value:
            return fallback
        return mapping.get(value.lower(), fallback)

    def map_series(ids, source_dict, mapping, fallback):
        # Encode each row's ID via source_dict + mapping; IDs absent from
        # the checklist get ``fallback``.
        def mapper(gid):
            if gid in source_dict:
                return lookup(mapping, source_dict[gid], fallback)
            return fallback
        lowered = ids.fillna("").astype(str).str.strip().str.lower()
        return lowered.map(mapper).astype(int)

    df = df.copy()
    # NOTE(review): column is spelled "GlycoPost_ID" here but "GlycoPOST_ID"
    # in the checklist header -- confirm the spectra table uses this casing.
    ids = df["GlycoPost_ID"].astype(str).str.strip()
    # Fallback codes: 2 for the three fields with two known values, 3 for
    # trap (which has three known values).
    df["mode"] = map_series(ids, mode_dict, MODE_MAP, 2)
    df["lc"] = map_series(ids, lc_dict, LC_MAP, 2)
    df["modification"] = map_series(ids, mod_dict, MOD_MAP, 2)
    df["trap"] = map_series(ids, trap_dict, TRAP_MAP, 3)
    return df
134+
135+
def process_full_dataset(full_df, checklist):
    """Run the full preprocessing pipeline over the condensed spectra table.

    Stages, in order: retention-time cleaning/normalisation, glycan-type
    inference, spectrum parsing/binning, then acquisition-metadata encoding
    from ``checklist``.  Returns the fully processed frame.
    """
    stage = process_retention_times(full_df)
    stage = infer_glycan_types(stage)
    stage = process_peaks(stage)
    return attach_metadata(stage, checklist)
141+
142+
def downcast_numeric(df):
    """Return a copy of ``df`` with numeric columns downcast to save memory.

    Integer columns are downcast toward unsigned types and float columns
    toward float32, where the values allow it (``pd.to_numeric`` leaves a
    column unchanged when downcasting would lose information).
    """
    out = df.copy()
    plans = (
        ("unsigned", out.select_dtypes(include=["int", "uint", "int64", "uint64"]).columns),
        ("float", out.select_dtypes(include=["float", "float64"]).columns),
    )
    for target, columns in plans:
        for name in columns:
            out[name] = pd.to_numeric(out[name], downcast=target)
    return out
151+
152+
def tupleify(df, columns):
    """Materialise the selected ``columns`` of ``df`` as a list of plain row
    tuples (no index, no namedtuple wrapper), in the given column order."""
    subset = df[list(columns)]
    return [tuple(row) for row in subset.itertuples(index=False, name=None)]
154+
155+
156+
# ---------------------------------------------------------------------------
# Script body: runs at import time.  NOTE(review): consider guarding with
# ``if __name__ == "__main__":`` so the helpers above can be imported
# without triggering the full pipeline.
# ---------------------------------------------------------------------------
print("Loading condensed spectra")
full_df = pd.read_excel(full_dataset_path)
meta_df = pd.read_csv(metadata_path)
processed = process_full_dataset(full_df, meta_df)

# Optionally merge in previously prepared datasets so retraining can pool
# legacy spectra with the new table.
frames = [processed]
if legacy_train_path:
    frames.append(pd.read_pickle(legacy_train_path))
if legacy_test_path:
    frames.append(pd.read_pickle(legacy_test_path))
combined = pd.concat(frames, ignore_index=True)

# Replace glycan strings with integer class indices; the sorted vocabulary
# is pickled at the end so indices can be decoded back to strings.
glycans = sorted(set(combined["glycan"]))
glycan_to_idx = {g: i for i, g in enumerate(glycans)}
combined["glycan"] = combined["glycan"].map(glycan_to_idx)

print("Splitting train/test by filename")
# Group by source file so no file contributes spectra to both splits.
splitter = GroupShuffleSplit(test_size=test_size, n_splits=1, random_state=random_state)
train_idx, test_idx = next(splitter.split(combined, groups=combined["filename"]))
train_df = combined.iloc[train_idx].reset_index(drop=True)
test_df = combined.iloc[test_idx].reset_index(drop=True)

# Shrink numeric dtypes before pickling to keep the files small.
train_df = downcast_numeric(train_df)
test_df = downcast_numeric(test_df)

print("Writing intermediate dataframes")
output_dir.mkdir(parents=True, exist_ok=True)
train_df.to_pickle(output_dir / "train_second.pkl")
test_df.to_pickle(output_dir / "test_second.pkl")

print("Serialising tuples for CandyCrunch")
# X_* are lists of per-spectrum feature tuples in FEATURE_COLUMNS order;
# y_* are the matching integer glycan labels.
X_train = tupleify(train_df, FEATURE_COLUMNS)
X_test = tupleify(test_df, FEATURE_COLUMNS)
y_train = train_df["glycan"].tolist()
y_test = test_df["glycan"].tolist()

with open(output_dir / "X_train.pkl", "wb") as fh:
    pickle.dump(X_train, fh)
with open(output_dir / "X_test.pkl", "wb") as fh:
    pickle.dump(X_test, fh)
with open(output_dir / "y_train.pkl", "wb") as fh:
    pickle.dump(y_train, fh)
with open(output_dir / "y_test.pkl", "wb") as fh:
    pickle.dump(y_test, fh)
with open(output_dir / "glycans.pkl", "wb") as fh:
    pickle.dump(glycans, fh)

print(f"Saved processed datasets to {output_dir.resolve()}")
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
GlycoPOST_ID,LC_type,mode,modification,trap
2+
GPST_TEMPLATE_001,PGC,negative,reduced,linear
3+
GPST_TEMPLATE_002,C18,positive,permethylated,orbitrap

0 commit comments

Comments
 (0)