Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions docs/source/guides/batch_jobs.rst
Original file line number Diff line number Diff line change
Expand Up @@ -219,8 +219,18 @@ Build reusable base images for CPU and GPU workloads:

.. code-block:: bash

docker build -f docker/analysis-base/Dockerfile.cpu -t spikelab/analysis-base:cpu .
docker build -f docker/analysis-base/Dockerfile.gpu -t spikelab/analysis-base:gpu .
bash scripts/build_base_image.sh cpu spikelab/analysis-base:cpu
bash scripts/build_base_image.sh gpu spikelab/analysis-base:gpu

The base image bakes in the SpikeLab source via ``COPY src ./src`` and
``pip install -e .``. It is a frozen snapshot — published SpikeLab releases do
not update an existing image automatically. Rebuild whenever the library
source has changed and you need that change reflected on the cluster.

When iterating on a feature branch, build under a developer-scoped tag (e.g.,
``ghcr.io/<org>/spikelab-analysis-base:${USER}-$(git rev-parse --short HEAD)``)
and pass it explicitly via ``--image`` so concurrent developers do not clobber
each other's shared ``:cpu`` / ``:gpu`` tags.

Temporary images
^^^^^^^^^^^^^^^^
Expand All @@ -232,6 +242,10 @@ Build and push a temporary image for a single run:
bash scripts/build_temp_image.sh gpu ghcr.io/<org>/spikelab-analysis-temp:<tag>
bash scripts/push_temp_image.sh ghcr.io/<org>/spikelab-analysis-temp:<tag>

This layers analysis-time files on top of an existing ``analysis-base`` image
without rebuilding it. Use this when only the analysis script changed; if
``src/spikelab/`` itself changed, rebuild the base image first (see above).

Reference this tag in the ``ContainerSpec`` when creating your ``JobSpec``.


Expand Down
27 changes: 27 additions & 0 deletions scripts/build_base_image.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
#!/usr/bin/env bash
set -euo pipefail

if [[ $# -lt 2 ]]; then
echo "Usage: $0 <cpu|gpu> <image-tag>"
echo "Example: $0 cpu ghcr.io/acme/spikelab-analysis-base:dev-abc1234"
exit 1
fi

profile="$1"
image_tag="$2"

case "$profile" in
cpu) dockerfile="docker/analysis-base/Dockerfile.cpu" ;;
gpu) dockerfile="docker/analysis-base/Dockerfile.gpu" ;;
*)
echo "Error: profile must be 'cpu' or 'gpu', got '$profile'"
exit 1
;;
esac

docker build \
-f "${dockerfile}" \
-t "${image_tag}" \
.

echo "BUILT_IMAGE=${image_tag}"
26 changes: 26 additions & 0 deletions src/spikelab/batch_jobs/INSTRUCTIONS.md
Original file line number Diff line number Diff line change
Expand Up @@ -85,12 +85,38 @@ These scripts are in the SpikeLab repository under `scripts/` and `docker/`. The
- `python scripts/generate_job_config.py --image <image-tag> --profile <cpu|gpu> --output configs/batch-temp-job.yaml`
5. Confirm image is pullable from target cluster/namespace before deploy.

### When SpikeLab source has changed (developer iteration)

The `build_temp_image.sh` workflow above layers analysis code on top of an existing `analysis-base` image. It does **not** capture changes to `src/spikelab/` itself. If the user has modified the SpikeLab library (e.g., they are on a feature branch with new methods that the submitted script depends on), the `analysis-base` image must be rebuilt first — otherwise the running container exposes a stale API and the job will fail with `AttributeError` or run against outdated behavior.

In that case, rebuild and push a **developer-scoped base image** before submitting, and pass it explicitly via `--image`:

```bash
# From SpikeLab repo root. Use ${USER:-${USERNAME}} for Linux/Mac/Windows compatibility.
USER_TAG="ghcr.io/<org>/spikelab-analysis-base:${USER:-${USERNAME}}-$(git rev-parse --short HEAD)"

bash scripts/build_base_image.sh cpu "${USER_TAG}" # or 'gpu'
bash scripts/push_temp_image.sh "${USER_TAG}"

# Submit using the freshly built image
spikelab-batch-jobs deploy-job \
--profile <profile> \
--job-config <path> \
--image "${USER_TAG}"
```

Notes:
- The Dockerfile uses `COPY src ./src`, so **uncommitted edits in `src/spikelab/` are also baked into the image**. This is useful for fast iteration but can be surprising — confirm `git status` reflects the state you intend to ship.
- Use a developer-scoped tag (username + short SHA) rather than the shared `:cpu`/`:gpu` tags so concurrent developers do not clobber each other's images.
- The shared `ghcr.io/braingeneers/spikelab-analysis-base:cpu` / `:gpu` tags are static snapshots — they do **not** track new SpikeLab releases automatically. Always rebuild when the library source has changed locally.

## Fixed Workflow

1. **Preflight checks**
- Run `kubectl version --client`.
- Run `kubectl config current-context`.
- Validate registry/image tag exists and is pushed.
- If `git status` shows changes to `src/spikelab/`, the cluster-side image is stale relative to local code. Rebuild and push a developer-scoped base image before submitting (see "When SpikeLab source has changed" under Container Prep) and pass the resulting tag via `--image`.
- Optionally verify S3 access if asked by the user.
2. **Validate inputs**
- Ensure `--job-config` is present.
Expand Down
1 change: 1 addition & 0 deletions src/spikelab/spike_sorting/stim_sorting/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ def sort_stim_recording(

Parameters:
stim_recording: The stimulation recording. Can be:

- ``str`` or ``Path`` to a recording file (Maxwell .h5 or
NWB). Chunked path.
- A SpikeInterface ``BaseRecording`` object. Chunked path.
Expand Down
5 changes: 4 additions & 1 deletion src/spikelab/spike_sorting/stim_sorting/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,13 @@
"""

from pathlib import Path
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

import numpy as np

if TYPE_CHECKING:
from spikeinterface.core import BaseRecording


def preprocess_stim_artifacts(
recording,
Expand Down
1 change: 1 addition & 0 deletions src/spikelab/spike_sorting/stim_sorting/recentering.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,7 @@ def recenter_stim_times(
max_offset_ms (float): Radius of the search window around
each logged stim time, in milliseconds. Default 50.0.
peak_mode (str): Alignment target. One of:

* ``"abs_max"`` (default): largest ``|voltage|`` across
channels. Backward-compatible with the pre-``peak_mode``
API.
Expand Down
21 changes: 13 additions & 8 deletions src/spikelab/spikedata/plot_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2828,8 +2828,9 @@ def plot_prediction_probability_heatmap(
true label matches. Optionally subtracts the mean probability over
a set of baseline cycles to highlight changes across stim rounds.

Cell ``(i, j)`` of the heatmap = mean ``proba[i, samples in cycle j
where true == classes[i]]``.
Cell ``(i, j)`` of the heatmap is the mean of ``proba[i, s]`` taken
over samples ``s`` in cycle ``j`` whose true label equals
``classes[i]``.

Parameters:
probabilities (np.ndarray): Predicted probabilities, shape
Expand Down Expand Up @@ -2859,9 +2860,11 @@ def plot_prediction_probability_heatmap(
"P(correct)" or "ΔP vs baseline".

Returns:
result (dict): ``{"heatmap": (K, n_groups) array, "ax": ax,
"bar_ax": bar_ax or None, "groups": (n_groups,) array,
"classes": (K,) array}``.
result (dict): Mapping with keys ``"heatmap"`` (``(K, n_groups)``
array), ``"ax"`` (the heatmap axes), ``"bar_ax"`` (the bar
axes or ``None``), ``"groups"`` (``(n_groups,)`` array of
cycle indices), and ``"classes"`` (``(K,)`` array of class
labels).

Notes:
- Requires ``matplotlib``.
Expand Down Expand Up @@ -3061,9 +3064,11 @@ def plot_responsive_unit_map(
other_target_marker_size (float): Marker size for other-stim X.

Returns:
result (dict): ``{"ax": ax, "scatter": PathCollection,
"target_scatter": PathCollection,
"other_target_scatter": PathCollection or None}``.
result (dict): Mapping with keys ``"ax"`` (the plot axes),
``"scatter"`` (the units ``PathCollection``),
``"target_scatter"`` (the target marker ``PathCollection``),
and ``"other_target_scatter"`` (the other-stim
``PathCollection`` or ``None``).

Notes:
- Requires ``matplotlib``.
Expand Down
1 change: 1 addition & 0 deletions src/spikelab/spikedata/spikeslicestack.py
Original file line number Diff line number Diff line change
Expand Up @@ -488,6 +488,7 @@ def baseline_normalized_raster(
window relative to slice origin used to estimate the
per-slice baseline rate.
mode (str): Normalization mode:

- ``"subtract"`` (default) — counts above baseline expectation.
- ``"ratio"`` — counts / expected_counts (NaN where expected
is 0).
Expand Down
Loading