Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 6 additions & 0 deletions src/tracksdata/graph/_base_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1734,3 +1734,9 @@ def remove_metadata(self, key: str) -> None:
graph.remove_metadata("shape")
```
"""

@abc.abstractmethod
def edge_list(self) -> list[list[int, int]]:
"""
Get the edge list of the graph.
"""
2 changes: 1 addition & 1 deletion src/tracksdata/graph/_graph_view.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
from tracksdata.utils._signal import is_signal_on


class GraphView(RustWorkXGraph, MappedGraphMixin):
class GraphView(MappedGraphMixin, RustWorkXGraph):
"""
A filtered view of a graph that maintains bidirectional mapping to the root graph.

Expand Down
6 changes: 6 additions & 0 deletions src/tracksdata/graph/_mapped_graph_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,3 +245,9 @@ def _get_local_ids(self) -> list[int]:
List of all local node IDs
"""
return list(self._local_to_external.keys())

def edge_list(self) -> list[list[int, int]]:
"""
Get the edge list of the graph.
"""
return self._vectorized_map_to_external(self.rx_graph.edge_list()).tolist()
8 changes: 7 additions & 1 deletion src/tracksdata/graph/_rustworkx_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1386,8 +1386,14 @@ def update_metadata(self, **kwargs) -> None:
def remove_metadata(self, key: str) -> None:
self._graph.attrs.pop(key, None)

def edge_list(self) -> list[list[int, int]]:
"""
Get the edge list of the graph.
"""
return self.rx_graph.edge_list()


class IndexedRXGraph(RustWorkXGraph, MappedGraphMixin):
class IndexedRXGraph(MappedGraphMixin, RustWorkXGraph):
"""
A graph with arbitrary node indices.

Expand Down
10 changes: 10 additions & 0 deletions src/tracksdata/graph/_sql_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1749,3 +1749,13 @@ def remove_metadata(self, key: str) -> None:
with Session(self._engine) as session:
session.query(self.Metadata).filter(self.Metadata.key == key).delete()
session.commit()

def edge_list(self) -> list[list[int, int]]:
"""
Get the edge list of the graph.
"""
with Session(self._engine) as session:
return [
[source_id, target_id]
for source_id, target_id in session.query(self.Edge.source_id, self.Edge.target_id).all()
]
17 changes: 17 additions & 0 deletions src/tracksdata/graph/_test/test_subgraph.py
Original file line number Diff line number Diff line change
Expand Up @@ -1254,3 +1254,20 @@ def test_picking_graph_mappings(graph_backend: BaseGraph, use_subgraph: bool) ->

if isinstance(unpickled_graph, MappedGraphMixin):
_evaluate_bidict_maps(unpickled_graph._local_to_external, unpickled_graph._external_to_local)


@parametrize_subgraph_tests
def test_edge_list(graph_backend: BaseGraph, use_subgraph: bool) -> None:
"""Test edge_list functionality on both original graphs and subgraphs."""
graph_with_data = create_test_graph(graph_backend, use_subgraph)
edge_list = set(map(tuple, graph_with_data.edge_list()))
expected_edge_list = set(
map(
tuple,
graph_with_data.edge_attrs()
.select([DEFAULT_ATTR_KEYS.EDGE_SOURCE, DEFAULT_ATTR_KEYS.EDGE_TARGET])
.to_numpy()
.tolist(),
)
)
assert edge_list == expected_edge_list
Loading