Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,25 @@ jobs:
with:
token: ${{ secrets.CODECOV_TOKEN }}

benchmarks:
runs-on: ubuntu-latest
env:
UV_NO_SYNC: "1"
steps:
- uses: actions/checkout@v4
- uses: astral-sh/setup-uv@v6
with:
python-version: "3.13"
enable-cache: true

- name: install
run: uv sync --no-dev --group test-codspeed

- name: Run benchmarks
uses: CodSpeedHQ/action@v3
with:
run: uv run pytest -W ignore --codspeed -v --color=yes

deploy:
name: Deploy
needs: test
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ repos:
- repo: https://github.com/astral-sh/ruff-pre-commit
rev: v0.12.2
hooks:
- id: ruff
- id: ruff-check
args: [--fix, --unsafe-fixes]
- id: ruff-format

Expand Down
1 change: 1 addition & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
[![Python Version](https://img.shields.io/pypi/pyversions/spatial-graph.svg?color=green)](https://python.org)
[![CI](https://github.com/funkelab/spatial_graph/actions/workflows/ci.yaml/badge.svg)](https://github.com/funkelab/spatial_graph/actions/workflows/ci.yaml)
[![codecov](https://codecov.io/gh/funkelab/spatial_graph/branch/main/graph/badge.svg)](https://codecov.io/gh/funkelab/spatial_graph)
[![CodSpeed](https://img.shields.io/endpoint?url=https://codspeed.io/badge.json)](https://codspeed.io/funkelab/spatial_graph)

`spatial_graph` provides a data structure for directed and undirected graphs,
where each node has an nD position (in time or space).
Expand Down
2 changes: 2 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,13 @@ dependencies = ["witty>=v0.2.1", "CT3>=3.3.3", "numpy", "setuptools>=75.8.0"]

[dependency-groups]
test = ["pytest>=8.3.5", "pytest-cov>=6.1.1"]
test-codspeed = [{ include-group = "test" }, "pytest-codspeed >=3.2.0"]
dev = [
{ include-group = "test" },
"ipython>=8.18.1",
"mypy>=1.15.0",
"pre-commit>=4.2.0",
"pytest-benchmark>=5.1.0", # specifically excluded from test group for ci
"ruff>=0.11.10",
]
docs = [
Expand Down
84 changes: 84 additions & 0 deletions tests/test_bench.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import sys

import numpy as np
import pytest

from spatial_graph import SpatialGraph

# either run this file directly or with pytest --codspeed
if all(x not in {"--codspeed", "tests/test_bench.py"} for x in sys.argv):
pytest.skip(
"use 'pytest tests/test_bench.py' to run benchmark", allow_module_level=True
)


def _make_graph(
ndims=3,
node_dtype="uint64",
node_attr_dtypes=None,
edge_attr_dtypes=None,
directed=False,
n_nodes=100_000,
):
"""Helper to create a SpatialGraph instance with default parameters."""
if node_attr_dtypes is None:
node_attr_dtypes = {"position": "double[3]"}
if edge_attr_dtypes is None:
edge_attr_dtypes = {"score": "float32"}

graph = SpatialGraph(
ndims=ndims,
node_dtype=node_dtype,
node_attr_dtypes=node_attr_dtypes,
edge_attr_dtypes=edge_attr_dtypes,
position_attr="position",
directed=directed,
)
nodes = np.arange(n_nodes, dtype="uint64")
positions = np.random.random((n_nodes, ndims))
graph.add_nodes(nodes, position=positions)

return graph


@pytest.mark.parametrize("num_queries", [100])
@pytest.mark.parametrize("k", [1000, 10000])
@pytest.mark.parametrize("n_nodes", [100_000, 1_000_000])
def test_bench_query_nearest_nodes(n_nodes: int, k: int, num_queries: int, benchmark):
"""Benchmark query_nearest_nodes."""
graph = _make_graph(n_nodes=n_nodes)
query_points = np.random.random((num_queries, 3))

def _run():
for i in range(num_queries):
# Query nearest nodes
closest, distances = graph.query_nearest_nodes(
query_points[i], k=k, return_distances=True
)
positions = graph.node_attrs[closest].position
return closest, distances, positions

closest, distances, positions = benchmark(_run)

# Verify results
assert len(distances) == len(closest)
assert positions.shape[1] == 3


@pytest.mark.parametrize("n_nodes", [100_000, 1_000_000])
def test_roi_query_performance(n_nodes, benchmark):
"""Benchmark ROI (region of interest) queries."""
large_graph = _make_graph(n_nodes=n_nodes)
# Define a ROI that should contain a reasonable number of nodes
roi = np.array([[0.25, 0.25, 0.25], [0.75, 0.75, 0.75]])

nodes_in_roi = benchmark(lambda: large_graph.query_nodes_in_roi(roi))

# Verify results
assert len(nodes_in_roi) > 0
assert len(nodes_in_roi) < n_nodes # Should be subset

# Verify nodes are actually in ROI
positions = large_graph.node_attrs[nodes_in_roi].position
assert np.all(positions >= roi[0])
assert np.all(positions <= roi[1])
Loading