Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions manim/mobject/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"DiGraph",
]

import contextlib
import itertools as it
from collections.abc import Hashable, Iterable, Sequence
from copy import copy
Expand All @@ -18,7 +19,7 @@
if TYPE_CHECKING:
from typing import TypeAlias

from manim.scene.scene import Scene
from manim.animation.scene_buffer import SceneBuffer
from manim.typing import Point3D, Point3DLike

NxGraph: TypeAlias = nx.classes.graph.Graph | nx.classes.digraph.DiGraph
Expand Down Expand Up @@ -908,10 +909,12 @@ def _add_vertices_animation(self, *args, anim_args=None, **kwargs):

vertex_mobjects = self._create_vertices(*args, **kwargs)

def on_finish(scene: Scene):
def on_finish(buf: SceneBuffer | None):
for v in vertex_mobjects:
scene.remove(v[-1])
self._add_created_vertex(*v)
if buf is not None:
with contextlib.suppress(Exception):
buf.replace(v[-1])

return AnimationGroup(
*(animation(v[-1], **anim_args) for v in vertex_mobjects),
Expand Down Expand Up @@ -1565,6 +1568,9 @@ def update_edges(self, graph):
def __repr__(self: Graph) -> str:
return f"Undirected graph on {len(self.vertices)} vertices and {len(self.edges)} edges"

def __str__(self: Graph) -> str:
return self.__repr__()


class DiGraph(GenericGraph):
"""A directed graph.
Expand Down Expand Up @@ -1782,3 +1788,6 @@ def update_edges(self, graph):

def __repr__(self: DiGraph) -> str:
return f"Directed graph on {len(self.vertices)} vertices and {len(self.edges)} edges"

def __str__(self: DiGraph) -> str:
return self.__repr__()
13 changes: 7 additions & 6 deletions tests/module/mobject/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import numpy as np
import pytest

from manim import DiGraph, Graph, LabeledLine, Manager, Scene, Text, tempconfig
Expand Down Expand Up @@ -134,9 +135,9 @@ def test_custom_graph_layout_dict():
[1, 2, 3], [(1, 2), (2, 3)], layout={1: [0, 0, 0], 2: [1, 1, 0], 3: [1, -1, 0]}
)
assert str(G) == "Undirected graph on 3 vertices and 2 edges"
assert all(G.vertices[1].get_center() == [0, 0, 0])
assert all(G.vertices[2].get_center() == [1, 1, 0])
assert all(G.vertices[3].get_center() == [1, -1, 0])
np.testing.assert_allclose(G.vertices[1].get_center(), [0, 0, 0])
np.testing.assert_allclose(G.vertices[2].get_center(), [1, 1, 0])
np.testing.assert_allclose(G.vertices[3].get_center(), [1, -1, 0])


def test_graph_layouts():
Expand Down Expand Up @@ -165,9 +166,9 @@ def layout_func(graph, scale):
return {vertex: [vertex, vertex, 0] for vertex in graph}

G = Graph([1, 2, 3], [(1, 2), (2, 3)], layout=layout_func)
assert all(G.vertices[1].get_center() == [1, 1, 0])
assert all(G.vertices[2].get_center() == [2, 2, 0])
assert all(G.vertices[3].get_center() == [3, 3, 0])
np.testing.assert_allclose(G.vertices[1].get_center(), [1, 1, 0])
np.testing.assert_allclose(G.vertices[2].get_center(), [2, 2, 0])
np.testing.assert_allclose(G.vertices[3].get_center(), [3, 3, 0])


def test_custom_graph_layout_function_with_kwargs():
Expand Down