|
| 1 | +""" |
| 2 | +Helper script to convert a condensed spectra table (e.g. full_dataset.xlsx) and |
| 3 | +its metadata into CandyCrunch-ready training pickles. Adjust the paths in the |
| 4 | +configuration block before running. |
| 5 | +""" |
| 6 | + |
| 7 | +import ast |
| 8 | +import pickle |
| 9 | +from pathlib import Path |
| 10 | + |
| 11 | +import numpy as np |
| 12 | +import pandas as pd |
| 13 | +from sklearn.model_selection import GroupShuffleSplit |
| 14 | + |
| 15 | +from candycrunch.prediction import bin_intensities |
| 16 | + |
# ---------------------------------------------------------------------------
# Configuration: adjust these paths/settings before running (see module doc).
# ---------------------------------------------------------------------------
full_dataset_path = Path("full_dataset.xlsx")  # condensed spectra table (Excel)
metadata_path = Path("training/file_checklist_template.csv")  # per-file metadata checklist
legacy_train_path = None  # e.g. Path("prepared_datasets/train_legacy.pkl")
legacy_test_path = None  # e.g. Path("prepared_datasets/test_legacy.pkl")
output_dir = Path("prepared_datasets")  # every pickle is written under here
test_size = 0.15  # fraction of filename groups held out for the test split
random_state = 42  # seed so the group split is reproducible

# Integer encodings for the categorical metadata columns; attach_metadata
# falls back to an extra "unknown" code (2, or 3 for trap) for unmapped values.
MODE_MAP = {"negative": 0, "positive": 1}
LC_MAP = {"PGC": 0, "C18": 1}
MOD_MAP = {"reduced": 0, "permethylated": 1}
TRAP_MAP = {"linear": 0, "orbitrap": 1, "amazon": 2}

# Column order of the feature tuples serialised for CandyCrunch (see tupleify).
FEATURE_COLUMNS = [
    "binned_intensities",
    "mz_remainder",
    "reducing_mass",
    "glycan_type",
    "RT",
    "mode",
    "lc",
    "modification",
    "trap",
]
| 41 | + |
| 42 | + |
| 43 | + |
def safe_peak_parse(value):
    """Parse a raw spreadsheet cell into a ``{float(mz): float(intensity)}`` dict.

    Accepts either an already-parsed dict or the string repr of one (as
    stored in the condensed table). Returns ``None`` for anything that
    cannot be interpreted as a peak table (empty strings, non-dict
    literals, unparseable text, NaN cells, non-numeric keys/values).
    """
    if isinstance(value, dict):
        # Guard the float coercion: a stray non-numeric key/value should
        # drop the row, not crash the whole pipeline.
        try:
            return {float(k): float(v) for k, v in value.items()}
        except (TypeError, ValueError):
            return None
    if isinstance(value, str):
        stripped = value.strip()
        if not stripped:
            return None
        try:
            parsed = ast.literal_eval(stripped)
        except (SyntaxError, ValueError):
            return None
        if isinstance(parsed, dict):
            try:
                return {float(k): float(v) for k, v in parsed.items()}
            except (TypeError, ValueError):
                return None
    # Any other type (floats such as NaN, lists, ints, ...) is not a spectrum.
    return None
| 58 | + |
def normalise_spectrum(spec):
    """Scale a peak dict so its intensities sum to 1.

    Returns ``None`` when the spectrum is empty or its total intensity is
    not positive, so callers can drop the row.
    """
    denominator = float(sum(spec.values()))
    if denominator > 0:
        return {mz: value / denominator for mz, value in spec.items()}
    return None
| 64 | + |
def process_peaks(df, min_mz=39.714, max_mz=3000.0, bin_num=2048):
    """Parse, normalise and bin the ``peak_d`` spectra of *df*.

    Rows whose spectrum cannot be parsed or sums to zero intensity are
    dropped. Adds two columns per surviving row: ``binned_intensities``
    and ``mz_remainder``, each a float32 array produced by
    ``bin_intensities`` over a linear m/z grid of *bin_num* frames.
    """
    frames = np.linspace(min_mz, max_mz, bin_num)
    parsed = df["peak_d"].map(safe_peak_parse)
    parsed = parsed.map(lambda spec: None if spec is None else normalise_spectrum(spec))
    keep_mask = parsed.notnull()
    df = df.loc[keep_mask].copy()
    parsed = parsed.loc[keep_mask]
    if df.empty:
        # zip(*...) over an empty sequence raises ValueError; return the
        # empty frame with the expected columns instead.
        df["binned_intensities"] = []
        df["mz_remainder"] = []
        return df
    binned, remainders = zip(*(bin_intensities(spec, frames) for spec in parsed))
    df["binned_intensities"] = [np.asarray(vec, dtype=np.float32) for vec in binned]
    df["mz_remainder"] = [np.asarray(vec, dtype=np.float32) for vec in remainders]
    return df
| 76 | + |
def process_retention_times(df):
    """Clean retention times and scale them per source file.

    Missing RTs default to 15, rows eluting at or before 2 are dropped,
    and each file's RTs are divided by ``max(file max RT, 30.0)``.
    """
    cleaned = df.copy()
    cleaned["RT"] = cleaned["RT"].fillna(15)
    cleaned = cleaned[cleaned["RT"] > 2]

    def scale(group):
        # Gradients shorter than 30 are scaled as if they ran 30 long.
        return group / max(group.max(), 30.0)

    cleaned["RT"] = cleaned.groupby("filename")["RT"].transform(scale)
    return cleaned
| 83 | + |
def infer_glycan_types(df):
    """Add an integer ``glycan_type`` column (0-3) classified from the glycan string."""
    type0_suffixes = ("GalNAc", "GalNAc6S", "GalNAcOS", "Fuc", "Man", "Gal")
    type2_suffixes = ("Glc", "GlcOS", "GlcNAc", "Ins")

    def classify(sequence):
        # The suffix check runs before the chitobiose substring check, so a
        # glycan matching both is classified as 0 — keep this order.
        if sequence.endswith(type0_suffixes):
            return 0
        if "GlcNAc(b1-4)GlcNAc" in sequence:
            return 1
        if sequence.endswith(type2_suffixes):
            return 2
        return 3

    out = df.copy()
    out["glycan_type"] = out["glycan"].map(classify)
    return out
| 96 | + |
def attach_metadata(df, checklist, *, mode_map=None, lc_map=None,
                    mod_map=None, trap_map=None):
    """Attach integer-encoded acquisition metadata to each spectrum row.

    Rows are matched to *checklist* on a case-insensitive, stripped
    GlycoPOST accession. Each categorical column (mode / LC_type /
    modification / trap) is encoded through its value->int map; unknown
    accessions, blank values, or unmapped values receive the fallback
    category (2, or 3 for ``trap``).

    The keyword-only ``*_map`` arguments default to the module-level maps
    and may be overridden to customise the encodings. Map keys are matched
    case-insensitively: metadata values are lower-cased before lookup, so
    the previous exact-match lookup could never hit ``LC_MAP``'s
    capitalised keys ("PGC"/"C18") — every row silently got the fallback.
    """
    mode_map = MODE_MAP if mode_map is None else mode_map
    lc_map = LC_MAP if lc_map is None else lc_map
    mod_map = MOD_MAP if mod_map is None else mod_map
    trap_map = TRAP_MAP if trap_map is None else trap_map

    meta = checklist.copy()
    meta.columns = [col.strip() for col in meta.columns]
    meta = meta.set_index("GlycoPOST_ID")
    meta.index = meta.index.astype(str).str.strip().str.lower()

    def build_dict(column):
        # A checklist missing this column simply contributes no matches,
        # which sends every row to the fallback category.
        if column not in meta.columns:
            return {}
        series = meta[column].fillna("").astype(str).str.lower().str.strip()
        return series.to_dict()

    def encode(ids, source_dict, mapping, fallback):
        # Fold map keys to lower case so matching is case-insensitive.
        folded = {str(key).lower(): code for key, code in mapping.items()}

        def mapper(gid):
            value = source_dict.get(gid, "")
            if not value:
                return fallback
            return folded.get(value.lower(), fallback)

        return ids.map(mapper).astype(int)

    df = df.copy()
    ids = df["GlycoPost_ID"].astype(str).str.strip().str.lower()
    df["mode"] = encode(ids, build_dict("mode"), mode_map, 2)
    df["lc"] = encode(ids, build_dict("LC_type"), lc_map, 2)
    df["modification"] = encode(ids, build_dict("modification"), mod_map, 2)
    df["trap"] = encode(ids, build_dict("trap"), trap_map, 3)
    return df
| 134 | + |
def process_full_dataset(full_df, checklist):
    """Run the preprocessing pipeline: RTs, glycan types, peak binning, metadata."""
    staged = process_retention_times(full_df)
    staged = infer_glycan_types(staged)
    staged = process_peaks(staged)
    return attach_metadata(staged, checklist)
| 141 | + |
def downcast_numeric(df):
    """Return a copy of *df* with numeric columns downcast to smaller dtypes.

    Integer columns are downcast towards unsigned types and float columns
    towards the smallest float dtype pandas allows; everything else is
    left untouched.
    """
    shrunk = df.copy()
    for include, target in (
        (["int", "uint", "int64", "uint64"], "unsigned"),
        (["float", "float64"], "float"),
    ):
        for column in shrunk.select_dtypes(include=include).columns:
            shrunk[column] = pd.to_numeric(shrunk[column], downcast=target)
    return shrunk
| 151 | + |
def tupleify(df, columns):
    """Return the rows of the selected *columns* of *df* as plain tuples, in order."""
    subset = df.loc[:, list(columns)]
    return list(subset.itertuples(index=False, name=None))
| 154 | + |
| 155 | + |
# --------------------------- script entry point ----------------------------
print("Loading condensed spectra")
full_df = pd.read_excel(full_dataset_path)
meta_df = pd.read_csv(metadata_path)
processed = process_full_dataset(full_df, meta_df)

# Optionally merge previously prepared datasets so old and new spectra share
# one label vocabulary and one train/test split.
frames = [processed]
if legacy_train_path:
    frames.append(pd.read_pickle(legacy_train_path))
if legacy_test_path:
    frames.append(pd.read_pickle(legacy_test_path))
combined = pd.concat(frames, ignore_index=True)

# Replace glycan strings with indices into the sorted vocabulary; the
# vocabulary itself is pickled below so indices can be mapped back.
glycans = sorted(set(combined["glycan"]))
glycan_to_idx = {g: i for i, g in enumerate(glycans)}
combined["glycan"] = combined["glycan"].map(glycan_to_idx)

# Group the split by filename so all spectra from one run land on the same
# side of the train/test boundary.
print("Splitting train/test by filename")
splitter = GroupShuffleSplit(test_size=test_size, n_splits=1, random_state=random_state)
train_idx, test_idx = next(splitter.split(combined, groups=combined["filename"]))
train_df = combined.iloc[train_idx].reset_index(drop=True)
test_df = combined.iloc[test_idx].reset_index(drop=True)

# Shrink numeric dtypes before pickling to keep the files small.
train_df = downcast_numeric(train_df)
test_df = downcast_numeric(test_df)

print("Writing intermediate dataframes")
output_dir.mkdir(parents=True, exist_ok=True)
train_df.to_pickle(output_dir / "train_second.pkl")
test_df.to_pickle(output_dir / "test_second.pkl")

# CandyCrunch training consumes plain feature tuples (FEATURE_COLUMNS order)
# plus parallel integer label lists.
print("Serialising tuples for CandyCrunch")
X_train = tupleify(train_df, FEATURE_COLUMNS)
X_test = tupleify(test_df, FEATURE_COLUMNS)
y_train = train_df["glycan"].tolist()
y_test = test_df["glycan"].tolist()

with open(output_dir / "X_train.pkl", "wb") as fh:
    pickle.dump(X_train, fh)
with open(output_dir / "X_test.pkl", "wb") as fh:
    pickle.dump(X_test, fh)
with open(output_dir / "y_train.pkl", "wb") as fh:
    pickle.dump(y_train, fh)
with open(output_dir / "y_test.pkl", "wb") as fh:
    pickle.dump(y_test, fh)
with open(output_dir / "glycans.pkl", "wb") as fh:
    pickle.dump(glycans, fh)

print(f"Saved processed datasets to {output_dir.resolve()}")