From a707a04399bad7622e2e8131e4b798ac4037a6dc Mon Sep 17 00:00:00 2001 From: Jordao Bragantini Date: Fri, 12 Dec 2025 13:46:30 -0800 Subject: [PATCH] adding edge_list function --- src/tracksdata/graph/_base_graph.py | 6 ++++++ src/tracksdata/graph/_graph_view.py | 2 +- src/tracksdata/graph/_mapped_graph_mixin.py | 6 ++++++ src/tracksdata/graph/_rustworkx_graph.py | 8 +++++++- src/tracksdata/graph/_sql_graph.py | 10 ++++++++++ src/tracksdata/graph/_test/test_subgraph.py | 17 +++++++++++++++++ 6 files changed, 47 insertions(+), 2 deletions(-) diff --git a/src/tracksdata/graph/_base_graph.py b/src/tracksdata/graph/_base_graph.py index c43a839d..31f2bec7 100644 --- a/src/tracksdata/graph/_base_graph.py +++ b/src/tracksdata/graph/_base_graph.py @@ -1734,3 +1734,9 @@ def remove_metadata(self, key: str) -> None: graph.remove_metadata("shape") ``` """ + + @abc.abstractmethod + def edge_list(self) -> list[list[int, int]]: + """ + Get the edge list of the graph. + """ diff --git a/src/tracksdata/graph/_graph_view.py b/src/tracksdata/graph/_graph_view.py index 402b9cd8..e5d9aaf9 100644 --- a/src/tracksdata/graph/_graph_view.py +++ b/src/tracksdata/graph/_graph_view.py @@ -15,7 +15,7 @@ from tracksdata.utils._signal import is_signal_on -class GraphView(RustWorkXGraph, MappedGraphMixin): +class GraphView(MappedGraphMixin, RustWorkXGraph): """ A filtered view of a graph that maintains bidirectional mapping to the root graph. diff --git a/src/tracksdata/graph/_mapped_graph_mixin.py b/src/tracksdata/graph/_mapped_graph_mixin.py index ecd65cc9..da2fdd16 100644 --- a/src/tracksdata/graph/_mapped_graph_mixin.py +++ b/src/tracksdata/graph/_mapped_graph_mixin.py @@ -245,3 +245,9 @@ def _get_local_ids(self) -> list[int]: List of all local node IDs """ return list(self._local_to_external.keys()) + + def edge_list(self) -> list[list[int, int]]: + """ + Get the edge list of the graph. + """ + return self._vectorized_map_to_external(self.rx_graph.edge_list()).tolist() diff --git a/src/tracksdata/graph/_rustworkx_graph.py b/src/tracksdata/graph/_rustworkx_graph.py index 5d0a2a80..23615faf 100644 --- a/src/tracksdata/graph/_rustworkx_graph.py +++ b/src/tracksdata/graph/_rustworkx_graph.py @@ -1386,8 +1386,14 @@ def update_metadata(self, **kwargs) -> None: def remove_metadata(self, key: str) -> None: self._graph.attrs.pop(key, None) + def edge_list(self) -> list[list[int, int]]: + """ + Get the edge list of the graph. + """ + return self.rx_graph.edge_list() + -class IndexedRXGraph(RustWorkXGraph, MappedGraphMixin): +class IndexedRXGraph(MappedGraphMixin, RustWorkXGraph): """ A graph with arbitrary node indices. diff --git a/src/tracksdata/graph/_sql_graph.py b/src/tracksdata/graph/_sql_graph.py index 31ece3d0..ade87d7c 100644 --- a/src/tracksdata/graph/_sql_graph.py +++ b/src/tracksdata/graph/_sql_graph.py @@ -1749,3 +1749,13 @@ def remove_metadata(self, key: str) -> None: with Session(self._engine) as session: session.query(self.Metadata).filter(self.Metadata.key == key).delete() session.commit() + + def edge_list(self) -> list[list[int, int]]: + """ + Get the edge list of the graph. + """ + with Session(self._engine) as session: + return [ + [source_id, target_id] + for source_id, target_id in session.query(self.Edge.source_id, self.Edge.target_id).all() + ] diff --git a/src/tracksdata/graph/_test/test_subgraph.py b/src/tracksdata/graph/_test/test_subgraph.py index 66df7dfc..efef8baa 100644 --- a/src/tracksdata/graph/_test/test_subgraph.py +++ b/src/tracksdata/graph/_test/test_subgraph.py @@ -1254,3 +1254,20 @@ def test_picking_graph_mappings(graph_backend: BaseGraph, use_subgraph: bool) -> if isinstance(unpickled_graph, MappedGraphMixin): _evaluate_bidict_maps(unpickled_graph._local_to_external, unpickled_graph._external_to_local) + + +@parametrize_subgraph_tests +def test_edge_list(graph_backend: BaseGraph, use_subgraph: bool) -> None: + """Test edge_list functionality on both original graphs and subgraphs.""" + graph_with_data = create_test_graph(graph_backend, use_subgraph) + edge_list = set(map(tuple, graph_with_data.edge_list())) + expected_edge_list = set( + map( + tuple, + graph_with_data.edge_attrs() + .select([DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET]) + .to_numpy() + .tolist(), + ) + ) + assert edge_list == expected_edge_list