Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
201 changes: 201 additions & 0 deletions Week 2 FFT and tokenization/preprocessing2final_Hirani_Shah.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,201 @@
{
"nbformat": 4,
"nbformat_minor": 0,
"metadata": {
"colab": {
"provenance": [],
"gpuType": "T4"
},
"kernelspec": {
"name": "python3",
"display_name": "Python 3"
},
"language_info": {
"name": "python"
},
"accelerator": "GPU"
},
"cells": [
{
"cell_type": "code",
"source": [
"\n",
"!pip install mne scipy numpy matplotlib torch --quiet\n",
"\n",
"import torch\n",
"import numpy as np\n",
"import mne\n",
"from scipy.signal import butter, filtfilt, spectrogram, stft\n",
"import matplotlib.pyplot as plt\n",
"\n",
"pth_path = '/content/eeg_signals_raw_with_mean_std.pth'\n",
"eeg_data = torch.load(pth_path)\n",
"print(\" loaded.\")\n",
"\n",
"\n",
"FS = 1000\n",
"FREQ_RANGE = (5, 95)\n",
"TRIM_MS = 0\n",
"CHANNELS = 128\n",
"TOKEN_GROUP = 4\n",
"PROJ_DIM = 1024\n",
"\n",
"\n",
"raw_lengths = [\n",
" seg['eeg'].shape[1]\n",
" for subj in eeg_data['dataset'].values()\n",
" for seg in subj if isinstance(seg, dict) and 'eeg' in seg\n",
"]\n",
"median_len = int(np.median(raw_lengths))\n",
"median_len -= median_len % 4\n",
"print(f\" Median EEG segment length: {median_len}\")\n",
"\n",
"\n",
"def bandpass_filter(signal, low, high, fs, order=5):\n",
" b, a = butter(order, [low / (fs/2), high / (fs/2)], btype='band')\n",
" return filtfilt(b, a, signal, axis=-1)\n",
"\n",
"def trim_and_pad(signal, target_len):\n",
" if signal.shape[1] < target_len:\n",
" pad_width = target_len - signal.shape[1]\n",
" signal = np.pad(signal, ((0, 0), (0, pad_width)), mode='constant')\n",
" return signal[:, :target_len]\n",
"\n",
"def normalize(signal):\n",
" mean = signal.mean(axis=1, keepdims=True)\n",
" std = signal.std(axis=1, keepdims=True) + 1e-8\n",
" return (signal - mean) / std\n",
"\n",
"def pad_channels(signal, target_channels=128):\n",
" c, t = signal.shape\n",
" if c < target_channels:\n",
" repeats = (target_channels + c - 1) // c\n",
" signal = np.tile(signal, (repeats, 1))[:target_channels, :]\n",
" return signal\n",
"\n",
"def temporal_tokenize(signal, group_size=4):\n",
" c, t = signal.shape\n",
" t_new = t // group_size\n",
" signal = signal[:, :t_new * group_size]\n",
" return signal.reshape(c, t_new, group_size).mean(axis=2)\n",
"\n",
"def linear_project(token_tensor, out_dim=1024):\n",
" W = np.random.randn(token_tensor.shape[1], out_dim)\n",
" return token_tensor @ W\n",
"\n",
"def plot_psd(signal, fs=1000, subject_id='', seg_idx=0):\n",
" info = mne.create_info(ch_names=[f\"ch{i}\" for i in range(signal.shape[0])], sfreq=fs, ch_types=\"eeg\")\n",
" raw = mne.io.RawArray(signal, info)\n",
" psds, freqs = mne.time_frequency.psd_array_welch(signal, sfreq=fs, fmin=1, fmax=100, n_fft=256)\n",
"\n",
" plt.figure(figsize=(10, 4))\n",
" plt.plot(freqs, psds[0], label='Channel 0')\n",
" plt.xlabel(\"Frequency (Hz)\")\n",
" plt.ylabel(\"Power Spectral Density\")\n",
" plt.title(f\"PSD (Subject {subject_id}, Segment {seg_idx})\")\n",
" plt.grid(True)\n",
" plt.legend()\n",
" plt.show()\n",
"\n",
"def plot_spectrogram(signal, fs=1000, subject_id='', seg_idx=0):\n",
" f, t, Zxx = stft(signal[0], fs=fs, nperseg=128)\n",
" power_db = 10 * np.log10(np.abs(Zxx)**2 + 1e-8)\n",
"\n",
" plt.figure(figsize=(10, 4))\n",
" plt.pcolormesh(t, f, power_db, shading='gouraud', cmap='viridis')\n",
" plt.ylabel('Frequency [Hz]')\n",
" plt.xlabel('Time [sec]')\n",
" plt.colorbar(label='Power [dB]')\n",
" plt.title(f\"Spectrogram (Subject {subject_id}, Segment {seg_idx}, Channel 0)\")\n",
" plt.tight_layout()\n",
" plt.show()\n",
"\n",
"\n",
"all_subject_embeddings = {}\n",
"\n",
"for subject_id, segments in eeg_data['dataset'].items():\n",
" print(f\" Processing Subject {subject_id} \")\n",
" subject_embeddings = []\n",
"\n",
" for seg_idx, segment in enumerate(segments):\n",
" if isinstance(segment, dict) and 'eeg' in segment:\n",
" signal = segment['eeg'].numpy()\n",
"\n",
"\n",
" filtered = bandpass_filter(signal, FREQ_RANGE[0], FREQ_RANGE[1], FS)\n",
" trimmed = trim_and_pad(filtered, target_len=median_len)\n",
" normalized = normalize(trimmed)\n",
" padded = pad_channels(normalized, target_channels=CHANNELS)\n",
" tokenized = temporal_tokenize(padded, group_size=TOKEN_GROUP)\n",
"\n",
"\n",
" token_tensor = tokenized.T\n",
" embedded = linear_project(token_tensor, out_dim=PROJ_DIM)\n",
" subject_embeddings.append(embedded)\n",
"\n",
"\n",
" if seg_idx == 0:\n",
" plot_psd(padded, fs=FS, subject_id=subject_id, seg_idx=seg_idx)\n",
" plot_spectrogram(padded, fs=FS, subject_id=subject_id, seg_idx=seg_idx)\n",
" else:\n",
" print(f\"Skipping segment {seg_idx} (no 'eeg' key)\")\n",
"\n",
" if subject_embeddings:\n",
" all_subject_embeddings[subject_id] = subject_embeddings\n",
" print(f\" Finished Subject {subject_id} with {len(subject_embeddings)} segt\")\n",
"\n",
"\n",
"torch.save(all_subject_embeddings, 'dreamdiffusion_eeg_embeddings.pth')\n",
"print(\"Embeddings saved \")\n",
"\n"
],
"metadata": {
"id": "uLSvFU5WJru0",
"outputId": "149ad767-0313-46fb-bc25-49a1a46a1152",
"colab": {
"base_uri": "https://localhost:8080/",
"height": 505
}
},
"execution_count": 1,
"outputs": [
{
"output_type": "stream",
"name": "stdout",
"text": [
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m88.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m33.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m29.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m25.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n",
"\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━\u001b[0m \u001b[32m442.3/664.8 MB\u001b[0m \u001b[31m136.3 MB/s\u001b[0m eta \u001b[36m0:00:02\u001b[0m\n",
"\u001b[?25h\u001b[31mERROR: Operation cancelled by user\u001b[0m\u001b[31m\n",
"\u001b[0m"
]
},
{
"output_type": "error",
"ename": "ModuleNotFoundError",
"evalue": "No module named 'mne'",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<ipython-input-1-972ae073a854>\u001b[0m in \u001b[0;36m<cell line: 0>\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mmne\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mscipy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msignal\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mbutter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfiltfilt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mspectrogram\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstft\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'mne'",
"",
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n"
],
"errorDetails": {
"actions": [
{
"action": "open_url",
"actionText": "Open Examples",
"url": "/notebooks/snippets/importing_libraries.ipynb"
}
]
}
}
]
}
]
}