A Python library replicating core functionality from WarpLib, a C# framework for cryo-electron tomography (cryo-ET) data processing. Built on PyTorch for GPU acceleration and automatic differentiation.
warpylib provides:
- Cubic B-spline grid interpolation — spatially varying parameter fields with smooth interpolation
- CTF modelling — contrast transfer function calculation with full WarpLib parameter compatibility
- Tilt series geometry — 3D↔2D coordinate transforms, angle handling, metadata I/O
- Subtomogram reconstruction — weighted backprojection with CTF correction
- Image operations — Fourier-space shifting, resizing, normalization, bandpass filtering
The library preserves WarpLib's coordinate conventions and unit system (Ångströms, degrees, µm) for compatibility with Warp-processed datasets. If you're looking for a library that reads Warp metadata and reproduces Warp's exact tilt series model behavior, this is it!
pip install warpylibFor development:
git clone <repo>
cd warpylib
pip install -e ".[dev]"Requirements: Python ≥ 3.8
| Package | Version | Purpose |
|---|---|---|
torch |
≥ 2.0.0 | Core computation, GPU support, autograd |
torch-cubic-spline-grids |
≥ 0.0.10 | B-spline interpolation |
torch-projectors |
— | Fourier-space backprojection |
torch-subpixel-crop |
≥ 0.1.1 | Image extraction at fractional pixel coordinates |
numpy |
≥ 1.20.0 | Array utilities |
lxml |
≥ 4.9.0 | XML metadata I/O |
starfile |
≥ 0.5.0 | RELION STAR file I/O |
mrcfile |
— | MRC file I/O |
pillow |
— | Image loading/saving |
With torch-projectors, you're responsible for sourcing the correct package version for your CUDA environment if you want to use the GPU. Unfortunately, pip can't do this automatically.
import torch
from warpylib import TiltSeries, CTF, CubicGrid, euler_to_matrix
# Load a tilt series from WarpLib XML metadata
ts = TiltSeries(path="path/to/tiltseries.xml")
ts.load_meta("path/to/tiltseries.xml")
# Load tilt images resampled to a target pixel size
tilt_data = ts.load_images(original_pixel_size=10.0, target_pixel_size=5.0)
# Get per-tilt 2D positions for a set of particle coordinates
coords_3d = torch.tensor([[512.0, 512.0, 256.0]]) # (N, 3) in Ångströms
coords_2d = ts.get_position_in_all_tilts(coords_3d) # (N, n_tilts, 3)
# Reconstruct subtomograms
subtomos = ts.reconstruct_subvolumes(
tilt_data=tilt_data,
coords=coords_2d,
pixel_size=5.0,
size=64,
)
# subtomos: (N, 64, 64, 64)The central class for cryo-ET tilt series. Stores geometry, CTF parameters, and motion corrections as spatially varying cubic B-spline grids.
from warpylib import TiltSeries
# From WarpLib XML file
ts = TiltSeries(path="tiltseries.xml")
ts.load_meta("tiltseries.xml")
# From a RELION tomo STAR file
ts = TiltSeries(n_tilts=41)
ts.initialize_from_tomo_star("tomo.star")
# Save back to XML
ts.save_meta("tiltseries_modified.xml")ts.angles # (n_tilts,) tilt angles in degrees
ts.dose # (n_tilts,) accumulated electron dose
ts.use_tilt # (n_tilts,) bool mask — which tilts to include
ts.tilt_axis_angles # (n_tilts,) in-plane rotation of the tilt axis
ts.tilt_movie_paths # list of paths to per-tilt image files
# Spatially varying grids (CubicGrid objects)
ts.grid_ctf_defocus # defocus as a function of XY position and tilt
ts.grid_movement_x # motion correction X shifts
ts.grid_movement_y # motion correction Y shifts
ts.grid_volume_warp_x # 3D volume deformation field (X component)All coordinates are in Ångströms. Volume coordinates use the WarpLib convention (X, Y, Z).
import torch
# 3D particle positions: (N, 3)
particle_positions = torch.tensor([
[1024.0, 1024.0, 512.0],
[2048.0, 768.0, 300.0],
])
# Project to all tilts: (N, n_tilts, 3) — X, Y in Å, Z is defocus in µm
coords_2d = ts.get_position_in_all_tilts(particle_positions)
# Project to a single tilt: (N, 3)
coords_tilt5 = ts.get_positions_in_one_tilt(particle_positions, tilt_id=5)Particle orientations (ZYZ Euler angles) can be transformed into the reference frame of each tilt:
# Particle orientations in the volume frame: (N, 3) — rot, tilt, psi in radians
orientations = torch.zeros(len(particle_positions), 3)
# Transform to each tilt's frame: (N, n_tilts, 3)
tilt_orientations = ts.get_angle_in_all_tilts(orientations)
# For a single tilt: (N, 3)
tilt5_orientations = ts.get_angles_in_one_tilt(orientations, tilt_id=5)# Generate CTFs for all particles at their per-tilt positions
# coords has shape (N, n_tilts, 3) from get_position_in_all_tilts
ctfs = ts.get_ctfs_for_particles(
coords=coords_2d,
pixel_size=5.0,
)
# ctfs is a CTF object with batched parameters of shape (N, n_tilts)# Load all tilt images, optionally resampling
tilt_data = ts.load_images(
original_pixel_size=10.0, # pixel size of raw images in Å/px
target_pixel_size=5.0, # desired output pixel size
)
# tilt_data: (n_tilts, H, W) float tensor
# Just load image dimensions without reading data
ts.load_image_dimensions(original_pixel_size=10.0)subtomos = ts.reconstruct_subvolumes(
tilt_data=tilt_data, # (n_tilts, H, W)
coords=coords_2d, # (N, n_tilts, 3) from get_position_in_all_tilts
pixel_size=5.0, # Å/px
size=64, # output box size (isotropic, must be even)
oversampling=2.0, # internal oversampling for aliasing control
apply_ctf=True, # multiply by CTF in Fourier space
ctf_weighted=True, # weight by CTF² for Wiener-filter-like correction
correct_attenuation=True, # apply sinc² correction for linear interpolation
tilt_ids=None, # optionally restrict to a subset of tilts
)
# subtomos: (N, 64, 64, 64)tomogram = ts.reconstruct_full(
tilt_data=tilt_data,
output_size=(1000, 1000, 200), # (X, Y, Z) voxels
pixel_size=5.0,
oversampling=1.0,
)
# tomogram: (Z, Y, X) float tensorReconstruct the CTF modulation as a 3D volume (useful for CTF refinement):
ctf_vols = ts.reconstruct_subvolume_ctfs(
coords=coords_2d,
pixel_size=5.0,
size=64,
)
# ctf_vols: (N, 64, 64, 64)# Import per-tilt shifts and angles from external alignment files
ts.import_alignments(
shifts=torch.zeros(ts.n_tilts, 2), # (n_tilts, 2) XY shifts in Å
angles=torch.zeros(ts.n_tilts), # (n_tilts,) in-plane rotation in degrees
)# Shift a tilt's grid fields and propagate changes to motion/deformation grids
ts.apply_tilt_shift_and_propagate(tilt_id=3, shift_x=10.0, shift_y=-5.0)
# Apply a 3D shift to the volume origin and all per-tilt positions
ts.apply_tomogram_shift_3d(shift=torch.tensor([10.0, 0.0, -5.0]))Models the contrast transfer function of a cryo-EM experiment. All parameters match WarpLib's conventions and units.
from warpylib import CTF
import torch
ctf = CTF()
# Microscope parameters
ctf.voltage = 300.0 # kV
ctf.cs = 2.7 # spherical aberration (mm)
ctf.pixel_size = 1.5 # Å/px
# Defocus (underfocus positive, in µm)
ctf.defocus = 2.0
ctf.defocus_delta = 0.1 # astigmatism magnitude (µm)
ctf.defocus_angle = 45.0 # astigmatism angle (degrees)
# Amplitude contrast and phase shift
ctf.amplitude = 0.07
ctf.phase_shift = 0.0 # in units of π
# B-factor (dose weighting)
ctf.bfactor = -50.0 # Ų
# Compute 2D CTF in real-frequency half (rfft) format
ctf_2d = ctf.get_2d(size=(256, 256))
# ctf_2d: (256, 129) — use torch.fft.irfftn to go back to real space
# Compute radial 1D profile
ctf_1d = ctf.get_1d(width=128)
# ctf_1d: (128,)Batched parameters: Any scalar parameter can be replaced with a tensor to compute CTFs for multiple particles simultaneously:
ctf = CTF()
ctf.voltage = 300.0
ctf.cs = 2.7
ctf.pixel_size = 1.5
# Per-particle defocus values
ctf.defocus = torch.tensor([1.5, 2.0, 2.5, 3.0]) # (4,)
ctf_2d = ctf.get_2d(size=(256, 256))
# ctf_2d: (4, 256, 129)Options for get_2d / get_1d:
ctf.get_2d(
size=(256, 256),
amp_squared=False, # if True, return |CTF|² (power spectrum)
ignore_bfactor=False, # skip dose-weighting B-factor
ignore_scale=False, # skip amplitude scale factor
ignore_below_res=0.0, # taper CTF to zero below this resolution (Å)
ignore_transition_res=0.0, # transition width for the above taper (Å)
)Cubic B-spline interpolation grids for spatially varying parameters. Matches WarpLib's CubicGrid implementation, supporting gradients for optimization.
from warpylib import CubicGrid
import torch
# Create a 10×10×10 grid with manually specified values
values = torch.zeros(10 * 10 * 10) # flat, in einspline layout
grid = CubicGrid(
dimensions=(10, 10, 10),
values=values,
margins=(0.0, 0.0, 0.0),
centered_spacing=True,
)
# Interpolate at normalized coordinates in [0, 1]³
# Input shape: (N, 3), coords in (x, y, z) order
coords = torch.rand(100, 3)
interpolated = grid.get_interpolated(coords) # (100,)
# Evaluate on a regular grid
grid_values = grid.get_interpolated_grid(
value_grid=(20, 20, 20), # output resolution
border=(0.0, 0.0, 0.0),
)
# grid_values: (8000,) — flattened 20×20×20 in einspline layoutConvenience constructors:
# Uniform-value grid
grid = CubicGrid.from_scalar(dimensions=(5, 5, 5), value=1.0)
# Linear gradient along one axis
grid = CubicGrid(
dimensions=(10, 10, 10),
gradient_direction=0, # 0=X, 1=Y, 2=Z
value_min=0.0,
value_max=1.0,
)Resampling and slicing:
# Resize to new resolution (bilinear resampling of control points)
grid_fine = grid.resize(new_size=(20, 20, 20))
# Extract 2D slices
slice_xy = grid.get_slice_xy(index=5) # (X*Y,)
slice_xz = grid.get_slice_xz(index=5)
slice_yz = grid.get_slice_yz(index=5)
# Reduce dimensionality
grid_2d = grid.collapse_z() # average over Z
grid_1d = grid.collapse_xy() # average over X and YDevice placement:
grid_gpu = grid.to(torch.device("cuda"))XML serialization (WarpLib-compatible format):
from lxml import etree
root = etree.Element("TiltSeries")
grid.save_to_xml(root)
# Load back
grid_loaded = CubicGrid.load_from_xml(root)Functions for converting between rotation matrices and Euler angles, supporting batched inputs.
from warpylib import euler_to_matrix, matrix_to_euler
from warpylib import euler_xyz_extrinsic_to_matrix, matrix_to_euler_xyz_extrinsic
from warpylib import rotate_x, rotate_y, rotate_z
import torch
import math
# ZYZ Euler convention (cryo-EM standard: rot, tilt, psi — all in radians)
angles = torch.tensor([[0.0, math.radians(30.0), math.radians(45.0)]) # (1, 3)
R = euler_to_matrix(angles) # (1, 3, 3)
angles_back = matrix_to_euler(R) # (1, 3)
# XYZ extrinsic convention
angles_xyz = torch.tensor([[0.1, 0.2, 0.3]])
R_xyz = euler_xyz_extrinsic_to_matrix(angles_xyz) # (1, 3, 3)
# Batched: convert 1000 rotation matrices at once
R_batch = torch.rand(1000, 3, 3) # not valid rotations, just for shape demo
# In practice these would be orthonormal matrices
angles_batch = matrix_to_euler(R_batch) # (1000, 3)
# Elementary rotations (angle in radians, batched)
angle = torch.tensor([math.radians(45.0)])
Rx = rotate_x(angle) # (1, 3, 3)
Ry = rotate_y(angle)
Rz = rotate_z(angle)Callable torch modules that solve for cubic B-spline coefficients and evaluate them. Useful as differentiable interpolators within neural networks or optimization loops.
from warpylib import InterpolatingBSpline1d, InterpolatingBSpline2d
import torch
# 1D: fit a spline through 20 data points and evaluate at 100 query points
spline = InterpolatingBSpline1d()
data = torch.sin(torch.linspace(0, 2 * torch.pi, 20)) # (20,)
query = torch.linspace(0, 1, 100) # normalized coords in [0, 1]
values = spline(data, query) # (100,)
# 2D: fit through a (32, 32) grid, evaluate at arbitrary (x, y) coords
spline2d = InterpolatingBSpline2d()
data2d = torch.rand(32, 32)
coords = torch.rand(500, 2) # (N, 2), normalized to [0, 1]²
values2d = spline2d(data2d, coords) # (500,)Both modules preserve gradients, so they can be used inside torch.autograd computations.
Low-level image processing operations in real and Fourier space.
import torch
import warpylib.ops as opsimage = torch.rand(256, 256)
# Real-space resizing (Fourier-domain crop/pad then IFFT)
small = ops.resize(image, new_size=(128, 128))
large = ops.resize(image, new_size=(512, 512))
# Already in Fourier space (full FFT)
image_ft = torch.fft.fftn(image)
small_ft = ops.resize_ft(image_ft, new_size=(128, 128))
# In half-space (rFFT)
image_rft = torch.fft.rfftn(image)
small_rft = ops.resize_rft(image_rft, new_size=(128, 128))
# Non-integer scaling
image_2x = ops.rescale(image, scale_factor=2.0)# Zero-mean, unit-variance — optionally restricted to a circular region
image_norm = ops.norm(image, dimensionality=2)
image_norm = ops.norm(image, dimensionality=2, diameter=200, mode="inner")
# In Fourier space
image_rft = torch.fft.rfftn(image)
image_rft_norm = ops.norm_rft(image_rft, dimensionality=2)Linear interpolation suppresses high-frequency signal by a sinc² envelope. This corrects for that:
# 2D correction map
correction = ops.get_sinc2_correction(size=(256, 256), oversampling=2.0)
image_corrected = torch.fft.irfftn(
torch.fft.rfftn(image) * ops.get_sinc2_correction_rft((256, 256), oversampling=2.0)
)
# 3D
correction_3d = ops.get_sinc2_correction(size=(64, 64, 64))# Rectangular mask (zeros outside a centred box)
masked = ops.mask_rectangular(volume, size=(50, 50, 50))
# Fit and remove a plane from a tomogram slice
plane_normal, plane_offset = ops.fit_plane(slice_2d)
flat_slice = ops.subtract_plane(slice_2d)Convenience function applying normalization, masking, and optional bandpass to a tilt stack:
tilt_data = torch.rand(41, 3712, 3712) # (n_tilts, H, W)
processed = ops.preprocess_tilt_data(
tilt_data,
normalize=True,
mask_radius=0.9, # fraction of half-width to use as circular mask
bandpass_highpass=500.0, # high-pass resolution cutoff in Å (None to skip)
bandpass_lowpass=10.0, # low-pass resolution cutoff in Å (None to skip)
pixel_size=5.0, # needed when bandpass is specified
)Method binding pattern: TiltSeries methods are defined in focused submodules (positions.py, ctf.py, etc.) and attached to the class in __init__.py. This keeps related code together while exposing a unified API.
Coordinate conventions: Volume coordinates follow WarpLib's (X, Y, Z) convention in Ångströms. PyTorch tensors internally use (Z, Y, X) indexing; the library handles conversion transparently.
Grid layout: CubicGrid values are stored in einspline layout [(z*Y+y)*X+x] to match WarpLib's serialization format.
Batching: CTF parameters, Euler angles, and coordinate transforms all support arbitrary leading batch dimensions via PyTorch broadcasting.
Gradients: CubicGrid, InterpolatingBSpline1d/2d, CTF calculations, and Fourier-space operations all preserve gradients for use in differentiable optimization.
pytest tests/MIT