diff --git a/Week 2 FFT and tokenization/preprocessing2final_Hirani_Shah.ipynb b/Week 2 FFT and tokenization/preprocessing2final_Hirani_Shah.ipynb new file mode 100644 index 0000000..d3ee944 --- /dev/null +++ b/Week 2 FFT and tokenization/preprocessing2final_Hirani_Shah.ipynb @@ -0,0 +1,201 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "code", + "source": [ + "\n", + "!pip install mne scipy numpy matplotlib torch --quiet\n", + "\n", + "import torch\n", + "import numpy as np\n", + "import mne\n", + "from scipy.signal import butter, filtfilt, spectrogram, stft\n", + "import matplotlib.pyplot as plt\n", + "\n", + "pth_path = '/content/eeg_signals_raw_with_mean_std.pth'\n", + "eeg_data = torch.load(pth_path)\n", + "print(\" loaded.\")\n", + "\n", + "\n", + "FS = 1000\n", + "FREQ_RANGE = (5, 95)\n", + "TRIM_MS = 0\n", + "CHANNELS = 128\n", + "TOKEN_GROUP = 4\n", + "PROJ_DIM = 1024\n", + "\n", + "\n", + "raw_lengths = [\n", + " seg['eeg'].shape[1]\n", + " for subj in eeg_data['dataset'].values()\n", + " for seg in subj if isinstance(seg, dict) and 'eeg' in seg\n", + "]\n", + "median_len = int(np.median(raw_lengths))\n", + "median_len -= median_len % 4\n", + "print(f\" Median EEG segment length: {median_len}\")\n", + "\n", + "\n", + "def bandpass_filter(signal, low, high, fs, order=5):\n", + " b, a = butter(order, [low / (fs/2), high / (fs/2)], btype='band')\n", + " return filtfilt(b, a, signal, axis=-1)\n", + "\n", + "def trim_and_pad(signal, target_len):\n", + " if signal.shape[1] < target_len:\n", + " pad_width = target_len - signal.shape[1]\n", + " signal = np.pad(signal, ((0, 0), (0, pad_width)), mode='constant')\n", + " return signal[:, :target_len]\n", + "\n", + "def normalize(signal):\n", + " mean = signal.mean(axis=1, keepdims=True)\n", + " std = signal.std(axis=1, keepdims=True) + 1e-8\n", + " return (signal - mean) / std\n", + "\n", + "def pad_channels(signal, target_channels=128):\n", + " c, t = signal.shape\n", + " if c < target_channels:\n", + " repeats = (target_channels + c - 1) // c\n", + " signal = np.tile(signal, (repeats, 1))[:target_channels, :]\n", + " return signal\n", + "\n", + "def temporal_tokenize(signal, group_size=4):\n", + " c, t = signal.shape\n", + " t_new = t // group_size\n", + " signal = signal[:, :t_new * group_size]\n", + " return signal.reshape(c, t_new, group_size).mean(axis=2)\n", + "\n", + "def linear_project(token_tensor, out_dim=1024):\n", + " W = np.random.randn(token_tensor.shape[1], out_dim)\n", + " return token_tensor @ W\n", + "\n", + "def plot_psd(signal, fs=1000, subject_id='', seg_idx=0):\n", + " info = mne.create_info(ch_names=[f\"ch{i}\" for i in range(signal.shape[0])], sfreq=fs, ch_types=\"eeg\")\n", + " raw = mne.io.RawArray(signal, info)\n", + " psds, freqs = mne.time_frequency.psd_array_welch(signal, sfreq=fs, fmin=1, fmax=100, n_fft=256)\n", + "\n", + " plt.figure(figsize=(10, 4))\n", + " plt.plot(freqs, psds[0], label='Channel 0')\n", + " plt.xlabel(\"Frequency (Hz)\")\n", + " plt.ylabel(\"Power Spectral Density\")\n", + " plt.title(f\"PSD (Subject {subject_id}, Segment {seg_idx})\")\n", + " plt.grid(True)\n", + " plt.legend()\n", + " plt.show()\n", + "\n", + "def plot_spectrogram(signal, fs=1000, subject_id='', seg_idx=0):\n", + " f, t, Zxx = stft(signal[0], fs=fs, nperseg=128)\n", + " power_db = 10 * np.log10(np.abs(Zxx)**2 + 1e-8)\n", + "\n", + " plt.figure(figsize=(10, 4))\n", + " plt.pcolormesh(t, f, power_db, shading='gouraud', cmap='viridis')\n", + " plt.ylabel('Frequency [Hz]')\n", + " plt.xlabel('Time [sec]')\n", + " plt.colorbar(label='Power [dB]')\n", + " plt.title(f\"Spectrogram (Subject {subject_id}, Segment {seg_idx}, Channel 0)\")\n", + " plt.tight_layout()\n", + " plt.show()\n", + "\n", + "\n", + "all_subject_embeddings = {}\n", + "\n", + "for subject_id, segments in eeg_data['dataset'].items():\n", + " print(f\" Processing Subject {subject_id} \")\n", + " subject_embeddings = []\n", + "\n", + " for seg_idx, segment in enumerate(segments):\n", + " if isinstance(segment, dict) and 'eeg' in segment:\n", + " signal = segment['eeg'].numpy()\n", + "\n", + "\n", + " filtered = bandpass_filter(signal, FREQ_RANGE[0], FREQ_RANGE[1], FS)\n", + " trimmed = trim_and_pad(filtered, target_len=median_len)\n", + " normalized = normalize(trimmed)\n", + " padded = pad_channels(normalized, target_channels=CHANNELS)\n", + " tokenized = temporal_tokenize(padded, group_size=TOKEN_GROUP)\n", + "\n", + "\n", + " token_tensor = tokenized.T\n", + " embedded = linear_project(token_tensor, out_dim=PROJ_DIM)\n", + " subject_embeddings.append(embedded)\n", + "\n", + "\n", + " if seg_idx == 0:\n", + " plot_psd(padded, fs=FS, subject_id=subject_id, seg_idx=seg_idx)\n", + " plot_spectrogram(padded, fs=FS, subject_id=subject_id, seg_idx=seg_idx)\n", + " else:\n", + " print(f\"Skipping segment {seg_idx} (no 'eeg' key)\")\n", + "\n", + " if subject_embeddings:\n", + " all_subject_embeddings[subject_id] = subject_embeddings\n", + " print(f\" Finished Subject {subject_id} with {len(subject_embeddings)} segt\")\n", + "\n", + "\n", + "torch.save(all_subject_embeddings, 'dreamdiffusion_eeg_embeddings.pth')\n", + "print(\"Embeddings saved \")\n", + "\n" + ], + "metadata": { + "id": "uLSvFU5WJru0", + "outputId": "149ad767-0313-46fb-bc25-49a1a46a1152", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 505 + } + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m7.4/7.4 MB\u001b[0m \u001b[31m88.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m363.4/363.4 MB\u001b[0m \u001b[31m4.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m13.8/13.8 MB\u001b[0m \u001b[31m33.6 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m24.6/24.6 MB\u001b[0m \u001b[31m29.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m883.7/883.7 kB\u001b[0m \u001b[31m25.5 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[2K \u001b[91m━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m\u001b[91m╸\u001b[0m\u001b[90m━━━━━━━━━━━━━\u001b[0m \u001b[32m442.3/664.8 MB\u001b[0m \u001b[31m136.3 MB/s\u001b[0m eta \u001b[36m0:00:02\u001b[0m\n", + "\u001b[?25h\u001b[31mERROR: Operation cancelled by user\u001b[0m\u001b[31m\n", + "\u001b[0m" + ] + }, + { + "output_type": "error", + "ename": "ModuleNotFoundError", + "evalue": "No module named 'mne'", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)", + "\u001b[0;32m\u001b[0m in \u001b[0;36m\u001b[0;34m()\u001b[0m\n\u001b[1;32m 3\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mtorch\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mnumpy\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 5\u001b[0;31m \u001b[0;32mimport\u001b[0m \u001b[0mmne\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 6\u001b[0m \u001b[0;32mfrom\u001b[0m \u001b[0mscipy\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msignal\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mbutter\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mfiltfilt\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mspectrogram\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mstft\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[0;32mimport\u001b[0m \u001b[0mmatplotlib\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpyplot\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mplt\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'mne'", + "", + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0;32m\nNOTE: If your import is failing due to a missing package, you can\nmanually install dependencies using either !pip or !apt.\n\nTo view examples of installing some common dependencies, click the\n\"Open Examples\" button below.\n\u001b[0;31m---------------------------------------------------------------------------\u001b[0m\n" + ], + "errorDetails": { + "actions": [ + { + "action": "open_url", + "actionText": "Open Examples", + "url": "/notebooks/snippets/importing_libraries.ipynb" + } + ] + } + } + ] + } + ] +} \ No newline at end of file