Commit 9e562ad

Merge pull request #3 from GT-LIT-Lab/documentation-v2
add documentation for simple models
2 parents 5a13784 + 7543e9c commit 9e562ad

7 files changed

Lines changed: 920 additions & 20 deletions

docs/assemblies_tutorial.rst

Lines changed: 231 additions & 0 deletions
@@ -0,0 +1,231 @@
Understanding Assemblies in LITcoder
====================================

An **Assembly** is the core data structure in LITcoder that organizes and manages brain imaging data, stimuli, and metadata for encoding model training. It is the foundation that everything else builds upon.

What is an Assembly?
--------------------

An assembly is a structured container that holds all the data needed to train encoding models:

- **Brain Data**: Recordings aligned with stimuli
- **Stimuli**: Text or audio stimuli presented during the experiment
- **Timing Information**: Precise timing of when each stimulus was presented
- **Split Indices**: Maps each word/stimulus to its corresponding TR (repetition time)
- **Metadata**: Story names, subject information, and experimental parameters

Think of an assembly as a well-organized database that contains everything needed to train a brain encoding model.
Assembly Structure
------------------

An assembly contains several key components:

**Stories**: List of story/run names
    Each story represents a continuous experimental session (e.g., listening to a story).

**Story Data**: Dictionary mapping story names to their data
    Contains brain data, stimuli, timing, and metadata for each story.

**Timing Information**:

- `tr_times`: when each TR (repetition time) occurred
- `data_times`: precise timing for each data point (word level)
- `split_indices`: maps each word to its corresponding TR

**Brain Data**:

- Preprocessed fMRI data aligned with stimuli
- Shape: (n_timepoints, n_voxels/vertices)
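To make the roles of `tr_times`, `data_times`, and `split_indices` concrete, here is a toy sketch in plain Python (not LITcoder code; the TR duration and word onsets are invented for illustration):

```python
# Hypothetical example: assign each word to the TR in which it was spoken.
tr_duration = 2.0                                    # seconds per TR (made up)
word_onsets = [0.3, 0.9, 1.8, 2.4, 3.1, 4.5, 5.9]    # word-level "data_times"

# Each word's split index is the TR bin its onset falls into.
split_indices = [int(onset // tr_duration) for onset in word_onsets]
print(split_indices)  # [0, 0, 0, 1, 1, 2, 2]
```

A real assembly stores these per story, computed from the experiment's actual timing files rather than from a fixed rule like this one.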
Working with Assemblies
-----------------------

Let's explore how to work with assemblies using the LeBel assembly:

.. code-block:: python

    from encoding.assembly.assembly_loader import load_assembly

    # Load the pre-packaged LeBel assembly
    assembly = load_assembly("assembly_lebel_uts03.pkl")

    # Basic information
    print(f"Assembly shape: {assembly.shape}")
    print(f"Stories: {assembly.stories}")
    print(f"Validation method: {assembly.get_validation_method()}")
Key Assembly Methods
--------------------

Here are the most important methods for working with assemblies:

**Data Access**:

- `get_stimuli()`: get text stimuli for each story
- `get_brain_data()`: get brain data for each story
- `get_split_indices()`: get the word-to-TR mapping
- `get_tr_times()`: get TR timing information
- `get_data_times()`: get precise word-level timing

**Story-Specific Data**:

- `get_temporal_baseline(story_name)`: get temporal baseline features
- `get_audio_path()`: get audio file paths (for speech models)
- `get_words()`: get individual words for each story
- `get_word_rates()`: get pre-computed word rates

**Metadata**:

- `get_validation_method()`: get the validation strategy ("inner" or "outer")
- `stories`: list of story names
- `story_data`: dictionary of story-specific data
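As a rough sketch of the interface these accessors present (a toy stand-in, not the real LITcoder class), the common pattern is that each accessor returns one entry per story:

```python
# Toy stand-in for an assembly; the real class holds far more data and methods.
class ToyAssembly:
    def __init__(self, story_data):
        self.story_data = story_data     # dict: story name -> story dict
        self.stories = list(story_data)  # story names, in order

    def get_stimuli(self):
        # One list of stimuli per story, in story order.
        return [self.story_data[s]["stimuli"] for s in self.stories]

    def get_split_indices(self):
        # One word-to-TR index list per story.
        return [self.story_data[s]["split_indices"] for s in self.stories]

toy = ToyAssembly({
    "story_a": {"stimuli": ["once", "upon", "a", "time"],
                "split_indices": [0, 0, 1, 1]},
})
print(toy.get_stimuli())  # [['once', 'upon', 'a', 'time']]
```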
Exploring Assembly Contents
---------------------------

Let's examine what's inside an assembly:

.. code-block:: python

    # Load assembly
    assembly = load_assembly("assembly_lebel_uts03.pkl")

    # Basic information
    print("=== Assembly Overview ===")
    print(f"Total presentations: {assembly.shape[0]}")
    print(f"Number of voxels/vertices: {assembly.shape[1]}")
    print(f"Stories: {assembly.stories}")
    print(f"Validation method: {assembly.get_validation_method()}")

    # Explore each story
    print("\n=== Story Details ===")
    for story in assembly.stories:
        story_data = assembly.story_data[story]
        print(f"\nStory: {story}")
        print(f"  Brain data shape: {story_data.brain_data.shape}")
        print(f"  Number of stimuli: {len(story_data.stimuli)}")
        print(f"  Split indices: {len(story_data.split_indices)} words")
        print(f"  TR times: {len(story_data.tr_times)} TRs")
        print(f"  Data times: {len(story_data.data_times)} words")

        # Show the first few stimuli
        print(f"  First 3 stimuli: {story_data.stimuli[:3]}")

        # Show split indices (these map words to TRs)
        print(f"  First 10 split indices: {story_data.split_indices[:10]}")
        print(f"  Last 10 split indices: {story_data.split_indices[-10:]}")
Understanding the Data Flow
---------------------------

Here's how data flows through an assembly:

1. **Stimuli Extraction**: text is processed into features (embeddings, word rates, etc.)
2. **Timing Alignment**: features are aligned with brain data using the timing information
3. **Downsampling**: high-resolution features are downsampled to match the brain data's TR
4. **FIR Delays**: temporal delays are applied to account for the hemodynamic response
5. **Train/Test Split**: data is split for proper evaluation
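Steps 3 and 4 can be sketched on toy data (plain Python, not LITcoder's actual implementation; the feature values and delays are invented):

```python
# Word-level feature values and their word-to-TR mapping (hypothetical).
word_features = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
split_indices = [0, 0, 1, 1, 2, 2]
n_trs = 3

# Step 3: downsample by averaging the words that fall within each TR.
sums, counts = [0.0] * n_trs, [0] * n_trs
for value, tr in zip(word_features, split_indices):
    sums[tr] += value
    counts[tr] += 1
downsampled = [s / c for s, c in zip(sums, counts)]
print(downsampled)  # [1.5, 3.5, 5.5]

# Step 4: FIR delays -- stack copies of the feature shifted back by k TRs,
# so the regression can capture the lagged hemodynamic response.
fir_delays = [1, 2]
delayed = [[downsampled[t - k] if t - k >= 0 else 0.0 for t in range(n_trs)]
           for k in fir_delays]
print(delayed)  # [[0.0, 1.5, 3.5], [0.0, 0.0, 1.5]]
```

LITcoder's `Downsampler` and FIR machinery handle these steps internally; this sketch only illustrates the shape of the transformation.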
Assembly Attributes
-------------------

An assembly has several key attributes:

**Shape**: (n_presentations, n_voxels/vertices)
    Total number of timepoints and brain regions.

**Stories**: List of story names
    Each story represents a continuous experimental session.

**Story Data**: Dictionary of story-specific data
    Contains all the data for each story.

**Coordinates**: Metadata about presentations
    Story IDs, stimulus IDs, etc.

**Validation Method**: "inner" or "outer"
    How the assembly handles train/test splits.
Working with Story Data
-----------------------

Each story in an assembly contains:

.. code-block:: python

    # Get data for a specific story
    story_name = assembly.stories[0]
    story_data = assembly.story_data[story_name]

    print(f"Story: {story_name}")
    print(f"  Brain data: {story_data.brain_data.shape}")
    print(f"  Stimuli: {len(story_data.stimuli)}")
    print(f"  Split indices: {len(story_data.split_indices)}")
    print(f"  TR times: {len(story_data.tr_times)}")
    print(f"  Data times: {len(story_data.data_times)}")

    # Access specific data
    brain_data = story_data.brain_data
    stimuli = story_data.stimuli
    split_indices = story_data.split_indices
    tr_times = story_data.tr_times
    data_times = story_data.data_times
Using Assemblies in Training
----------------------------

Here's how assemblies are used in the training pipeline:

.. code-block:: python

    from encoding.assembly.assembly_loader import load_assembly
    from encoding.features.factory import FeatureExtractorFactory
    from encoding.downsample.downsampling import Downsampler
    from encoding.models.nested_cv import NestedCVModel
    from encoding.trainer import AbstractTrainer

    # 1. Load assembly
    assembly = load_assembly("assembly_lebel_uts03.pkl")

    # 2. Create feature extractor
    extractor = FeatureExtractorFactory.create_extractor(
        modality="wordrate",
        model_name="wordrate",
        config={},
        cache_dir="cache",
    )

    # 3. Set up other components
    downsampler = Downsampler()
    model = NestedCVModel(model_name="ridge_regression")

    # 4. Configure training parameters
    fir_delays = [1, 2, 3, 4]
    trimming_config = {
        "train_features_start": 10,
        "train_features_end": -5,
        "train_targets_start": 0,
        "train_targets_end": None,
        "test_features_start": 50,
        "test_features_end": -5,
        "test_targets_start": 40,
        "test_targets_end": None,
    }

    # 5. Create trainer
    trainer = AbstractTrainer(
        assembly=assembly,
        feature_extractors=[extractor],
        downsampler=downsampler,
        model=model,
        fir_delays=fir_delays,
        trimming_config=trimming_config,
        use_train_test_split=True,
        logger_backend="wandb",
        wandb_project_name="lebel-tutorial",
        dataset_type="lebel",
        results_dir="results",
    )

    # 6. Train the model
    metrics = trainer.train()
    print(f"Median correlation: {metrics.get('median_score', float('nan')):.4f}")

This understanding of assemblies is crucial for effectively using LITcoder. The assembly serves as the foundation for all encoding model training, providing the structured interface between your experimental data and the machine learning pipeline.

docs/index.rst

Lines changed: 11 additions & 1 deletion

@@ -10,6 +10,16 @@ Welcome to litcoder's documentation!
    installation
    quickstart
 
+.. toctree::
+   :maxdepth: 2
+   :caption: Tutorials
+
+   assemblies_tutorial
+   tutorial_wordrate
+   tutorial_language_models
+   tutorial_speech
+   tutorial_embeddings
+
 .. toctree::
    :maxdepth: 2
    :caption: Guides
@@ -33,4 +43,4 @@ Indices and tables
 
 * :ref:`genindex`
 * :ref:`modindex`
-* :ref:`search`
+* :ref:`search`

docs/quickstart.rst

Lines changed: 59 additions & 19 deletions

@@ -1,43 +1,83 @@
 Quickstart
 ==========
 
-This minimal example wires an assembly, feature extractor, downsampler, model, and the trainer.
+This minimal example shows how to train an encoding model using the LeBel assembly with word rate features. This is the simplest and fastest way to get started with LITcoder.
 
 .. code-block:: python
 
-    from encoding.assembly.assemblies import NarrativesAssembly
+    from encoding.assembly.assembly_loader import load_assembly
     from encoding.features.factory import FeatureExtractorFactory
     from encoding.downsample.downsampling import Downsampler
-    from encoding.models.ridge_regression import RidgeRegressionModel
+    from encoding.models.nested_cv import NestedCVModel
     from encoding.trainer import AbstractTrainer
 
-    # 1) Load data (example: Narratives-style assembly)
-    assembly = NarrativesAssembly(assembly_path="/path/to/narratives.h5")
+    # 1) Load prepackaged assembly
+    assembly_path = "assembly_lebel_uts03.pkl"
+    assembly = load_assembly(assembly_path)
 
-    # 2) Configure features (e.g., language model embeddings)
-    extractor = FeatureExtractorFactory.create_language_model(
-        model_name="gpt2", context_type="fullcontext", last_token=True
+    # 2) Configure components (wordrate-only)
+    extractor = FeatureExtractorFactory.create_extractor(
+        modality="wordrate",
+        model_name="wordrate",
+        config={},
+        cache_dir="cache",
     )
 
-    # 3) Downsampler
-    downsampler = Downsampler(method="linear")
+    downsampler = Downsampler()
+    model = NestedCVModel(model_name="ridge_regression")
 
-    # 4) Model
-    model = RidgeRegressionModel(n_alphas=20)
+    # FIR, downsampling, and trimming match our LeBel defaults
+    fir_delays = [1, 2, 3, 4]
+    trimming_config = {
+        "train_features_start": 10, "train_features_end": -5,
+        "train_targets_start": 0, "train_targets_end": None,
+        "test_features_start": 50, "test_features_end": -5,
+        "test_targets_start": 40, "test_targets_end": None,
+    }
 
-    # 5) Trainer
+    downsample_config = {}
+
+    # 3) Train
     trainer = AbstractTrainer(
         assembly=assembly,
         feature_extractors=[extractor],
         downsampler=downsampler,
         model=model,
-        fir_delays=[0, 1, 2, 3, 4],
-        trimming_config={"features_start": 5, "targets_start": 5},
-        use_train_test_split=False,
-        dataset_type="narratives",
-        logger_backend="tensorboard",
+        fir_delays=fir_delays,
+        trimming_config=trimming_config,
+        use_train_test_split=True,
+        logger_backend="wandb",
+        wandb_project_name="lebel-wordrate",
+        dataset_type="lebel",
         results_dir="results",
+        downsample_config=downsample_config,
     )
 
     metrics = trainer.train()
-    print("Median correlation:", metrics["median_score"])
+    print({
+        "median_correlation": metrics.get("median_score", float("nan")),
+        "n_significant": metrics.get("n_significant"),
+    })
+
+Prerequisites
+-------------
+
+Before running this example, you need to:
+
+1. **Download the LeBel assembly**:
+
+   .. code-block:: bash
+
+      gdown 1q-XLPjvhd8doGFhYBmeOkcenS9Y59x64
+
+2. **Install LITcoder**:
+
+   .. code-block:: bash
+
+      git clone git@github.com:GT-LIT-Lab/litcoder_core.git
+      cd litcoder_core
+      conda create -n litcoder -y python=3.12.8
+      conda activate litcoder
+      conda install pip
+      pip install -e .
