From 559abea9af21ba9c83bcd9472c95f739a3985d40 Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Fri, 12 Dec 2025 10:11:17 -0800 Subject: [PATCH 1/2] fixing non-default attr T usage --- src/tracksdata/graph/_rustworkx_graph.py | 10 +++++----- src/tracksdata/graph/_sql_graph.py | 12 ++++++------ src/tracksdata/graph/_test/test_graph_backends.py | 7 +++++++ src/tracksdata/graph/filters/_spatial_filter.py | 2 +- src/tracksdata/metrics/_visualize.py | 4 ++-- src/tracksdata/nodes/_random.py | 3 ++- 6 files changed, 23 insertions(+), 15 deletions(-) diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 5d0a2a80..d60df86a 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -417,11 +417,11 @@ def add_node( if validate_keys: self._validate_attributes(attrs, self.node_attr_keys(), "node") - if "t" not in attrs: - raise ValueError(f"Node attributes must have a 't' key. Got {attrs.keys()}") + if DEFAULT_ATTR_KEYS.T not in attrs: + raise ValueError(f"Node attributes must have a '{DEFAULT_ATTR_KEYS.T}' key. Got {attrs.keys()}") node_id = self.rx_graph.add_node(attrs) - self._time_to_nodes.setdefault(attrs["t"], []).append(node_id) + self._time_to_nodes.setdefault(attrs[DEFAULT_ATTR_KEYS.T], []).append(node_id) self.node_added.emit_fast(node_id) return node_id @@ -449,7 +449,7 @@ def bulk_add_nodes(self, nodes: list[dict[str, Any]], indices: list[int] | None node_indices = list(self.rx_graph.add_nodes_from(nodes)) for node, index in zip(nodes, node_indices, strict=True): - self._time_to_nodes.setdefault(node["t"], []).append(index) + self._time_to_nodes.setdefault(node[DEFAULT_ATTR_KEYS.T], []).append(index) # checking if it has connections to reduce overhead if is_signal_on(self.node_added): @@ -481,7 +481,7 @@ def remove_node(self, node_id: int) -> None: self.node_removed.emit_fast(node_id) # Get the time value before removing the node - t = self.rx_graph[node_id]["t"] + t = self.rx_graph[node_id][DEFAULT_ATTR_KEYS.T] # Remove the node from the graph (this also removes all connected edges) self.rx_graph.remove_node(node_id) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 31ece3d0..370e5c6c 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -637,10 +637,10 @@ def add_node( if validate_keys: self._validate_attributes(attrs, self.node_attr_keys(), "node") - if "t" not in attrs: - raise ValueError(f"Node attributes must have a 't' key. Got {attrs.keys()}") + if DEFAULT_ATTR_KEYS.T not in attrs: + raise ValueError(f"Node attributes must have a '{DEFAULT_ATTR_KEYS.T}' key. Got {attrs.keys()}") - time = attrs["t"] + time = attrs[DEFAULT_ATTR_KEYS.T] if index is None: default_node_id = (time * self.node_id_time_multiplier) - 1 @@ -712,7 +712,7 @@ def bulk_add_nodes( node_ids = [] for i, node in enumerate(nodes): - time = node["t"] + time = node[DEFAULT_ATTR_KEYS.T] if indices is None: default_node_id = (time * self.node_id_time_multiplier) - 1 @@ -1500,8 +1500,8 @@ def update_node_attrs( attrs: dict[str, Any], node_ids: Sequence[int] | None = None, ) -> None: - if "t" in attrs: - raise ValueError("Node attribute 't' cannot be updated.") + if DEFAULT_ATTR_KEYS.T in attrs: + raise ValueError(f"Node attribute '{DEFAULT_ATTR_KEYS.T}' cannot be updated.") self._update_table(self.Node, node_ids, DEFAULT_ATTR_KEYS.NODE_ID, attrs) diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 89237645..5025d571 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1641,6 +1641,13 @@ def test_summary(graph_backend: BaseGraph) -> None: assert "Number of edges" in summary +def test_changing_default_attr_keys(graph_backend: BaseGraph) -> None: + DEFAULT_ATTR_KEYS.T = "frame" + graph_backend.add_node({"frame": 0}) + node_attrs = graph_backend.node_attrs() + assert "frame" in node_attrs.columns + + def test_spatial_filter_basic(graph_backend: BaseGraph) -> None: graph_backend.add_node_attr_key("x", 0.0) graph_backend.add_node_attr_key("y", 0.0) diff --git a/src/tracksdata/graph/filters/_spatial_filter.py b/src/tracksdata/graph/filters/_spatial_filter.py index b34c0633..91ed1d12 100644 --- a/src/tracksdata/graph/filters/_spatial_filter.py +++ b/src/tracksdata/graph/filters/_spatial_filter.py @@ -145,7 +145,7 @@ def __init__( attr_keys: list[str] | None = None, ) -> None: if attr_keys is None: - attr_keys = ["t", "z", "y", "x"] + attr_keys = [DEFAULT_ATTR_KEYS.T, "z", "y", "x"] valid_keys = set(graph.node_attr_keys()) attr_keys = list(filter(lambda x: x in valid_keys, attr_keys)) diff --git a/src/tracksdata/metrics/_visualize.py b/src/tracksdata/metrics/_visualize.py index 5349c8a3..eb16c279 100644 --- a/src/tracksdata/metrics/_visualize.py +++ b/src/tracksdata/metrics/_visualize.py @@ -136,9 +136,9 @@ def visualize_matches( viewer = napari.Viewer() if "z" in input_graph.node_attr_keys(): - pos = ["t", "z", "y", "x"] + pos = [DEFAULT_ATTR_KEYS.T, "z", "y", "x"] else: - pos = ["t", "y", "x"] + pos = [DEFAULT_ATTR_KEYS.T, "y", "x"] node_attrs = input_graph.node_attrs() ref_node_attrs = ref_graph.node_attrs() diff --git a/src/tracksdata/nodes/_random.py b/src/tracksdata/nodes/_random.py index 01465ee7..6fdc5fbc 100644 --- a/src/tracksdata/nodes/_random.py +++ b/src/tracksdata/nodes/_random.py @@ -3,6 +3,7 @@ import numpy as np +from tracksdata.constants import DEFAULT_ATTR_KEYS from tracksdata.graph._base_graph import BaseGraph from tracksdata.nodes._base_nodes import BaseNodesOperator from tracksdata.utils._multiprocessing import multiprocessing_apply @@ -169,4 +170,4 @@ def _nodes_per_time( size=(n_nodes_at_t, len(self.spatial_cols)), ).tolist() - return [{"t": t, **dict(zip(self.spatial_cols, c, strict=True)), **kwargs} for c in coords] + return [{DEFAULT_ATTR_KEYS.T: t, **dict(zip(self.spatial_cols, c, strict=True)), **kwargs} for c in coords] From 68dade03ff7fa7de59f9c328c4cbd8f10d3462ce Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Fri, 12 Dec 2025 11:34:20 -0800 Subject: [PATCH 2/2] wip fixing sqlite class --- src/tracksdata/graph/_sql_graph.py | 52 +++++++++++++------ .../graph/_test/test_graph_backends.py | 15 +++++- 2 files changed, 48 insertions(+), 19 deletions(-) diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 370e5c6c..590bf464 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -507,29 +507,46 @@ class Base(DeclarativeBase): self.Base = Base return - class Node(Base): - __tablename__ = "Node" - - # Use node_id as sole primary key for simpler updates - node_id = sa.Column(sa.BigInteger, primary_key=True, unique=True) - - # Add t as a regular column - # NOTE might want to use as index for fast time-based queries - t = sa.Column(sa.Integer, nullable=False) + Node = type( + "Node", + (Base,), + { + "__tablename__": "Node", + # Use node_id as sole primary key for simpler updates + DEFAULT_ATTR_KEYS.NODE_ID: sa.Column(sa.BigInteger, primary_key=True, unique=True), + # Add t as a regular column + # NOTE might want to use as index for fast time-based queries + DEFAULT_ATTR_KEYS.T: sa.Column(sa.Integer, nullable=False), + }, + ) node_tb_name = Node.__tablename__ - class Edge(Base): - __tablename__ = "Edge" - edge_id = sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True) - source_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True) - target_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True) + Edge = type( + "Edge", + (Base,), + { + "__tablename__": "Edge", + DEFAULT_ATTR_KEYS.EDGE_ID: sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True), + DEFAULT_ATTR_KEYS.EDGE_SOURCE: sa.Column( + sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True + ), + DEFAULT_ATTR_KEYS.EDGE_TARGET: sa.Column( + sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True + ), + }, + ) class Overlap(Base): __tablename__ = "Overlap" + overlap_id = sa.Column(sa.Integer, primary_key=True, unique=True, autoincrement=True) - source_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True) - target_id = sa.Column(sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.node_id"), index=True) + source_id = sa.Column( + sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True + ) + target_id = sa.Column( + sa.BigInteger, sa.ForeignKey(f"{node_tb_name}.{DEFAULT_ATTR_KEYS.NODE_ID}"), index=True + ) class Metadata(Base): __tablename__ = "Metadata" @@ -570,8 +587,9 @@ def _update_max_id_per_time(self) -> None: point and updates the internal cache to ensure newly created nodes have unique IDs. """ + t_column = getattr(self.Node, DEFAULT_ATTR_KEYS.T) with Session(self._engine) as session: - stmt = sa.select(self.Node.t, sa.func.max(self.Node.node_id)).group_by(self.Node.t) + stmt = sa.select(t_column, sa.func.max(getattr(self.Node, DEFAULT_ATTR_KEYS.NODE_ID))).group_by(t_column) self._max_id_per_time = {int(time): int(max_id) for time, max_id in session.execute(stmt).all()} def filter( diff --git a/src/tracksdata/graph/_test/test_graph_backends.py b/src/tracksdata/graph/_test/test_graph_backends.py index 5025d571..11fa269a 100644 --- a/src/tracksdata/graph/_test/test_graph_backends.py +++ b/src/tracksdata/graph/_test/test_graph_backends.py @@ -1643,10 +1643,21 @@ def test_summary(graph_backend: BaseGraph) -> None: def test_changing_default_attr_keys(graph_backend: BaseGraph) -> None: DEFAULT_ATTR_KEYS.T = "frame" - graph_backend.add_node({"frame": 0}) - node_attrs = graph_backend.node_attrs() + if isinstance(graph_backend, SQLGraph): + kwargs = {"drivername": "sqlite", "database": ":memory:"} + else: + kwargs = {} + + # must be a new graph because `graph_backend` already has a `t` attribute initialized + new_graph = type(graph_backend)(**kwargs) + new_graph.add_node({"frame": 0}) + node_attrs = new_graph.node_attrs() assert "frame" in node_attrs.columns + # this must be undone otherwise other tests will fail + DEFAULT_ATTR_KEYS.T = "t" + assert DEFAULT_ATTR_KEYS.T == "t" + def test_spatial_filter_basic(graph_backend: BaseGraph) -> None: graph_backend.add_node_attr_key("x", 0.0)