Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
49 changes: 38 additions & 11 deletions manim/scene/scene.py
Original file line number Diff line number Diff line change
Expand Up @@ -218,6 +218,37 @@ def remove(self, *mobjects_to_remove: OpenGLMobject) -> Self:
self.mobjects = list_difference_update(self.mobjects, mob.get_family())
return self

def _replace(
self,
parent_mobjects: list[OpenGLMobject],
mobject: OpenGLMobject,
*replacements: OpenGLMobject,
) -> list[OpenGLMobject]:
"""Replace one Mobject in a parent Mobject with one or more other Mobjects,
preserving draw order.

Parameters
----------
pareparent_mobjectsnt
The parent mobjects list where the mobject is going to be replaced
mobject
The mobject to be replaced. Must be present in the scene.
replacements
One or more Mobjects which must not already be in the scene.
"""
if mobject in parent_mobjects:
index = parent_mobjects.index(mobject)
parent_mobjects = [
*parent_mobjects[:index],
*[
replacement
for replacement in replacements
if replacement not in parent_mobjects
],
*parent_mobjects[index + 1 :],
]
return parent_mobjects

def replace(self, mobject: OpenGLMobject, *replacements: OpenGLMobject):
"""Replace one Mobject in the scene with one or more other Mobjects,
preserving draw order.
Expand All @@ -234,17 +265,13 @@ def replace(self, mobject: OpenGLMobject, *replacements: OpenGLMobject):
One or more Mobjects which must not already be in the scene.

"""
if mobject in self.mobjects:
index = self.mobjects.index(mobject)
self.mobjects = [
*self.mobjects[:index],
*[
replacement
for replacement in replacements
if replacement not in self.mobjects
],
*self.mobjects[index + 1 :],
]
for ancestor in mobject.get_ancestors():
ancestor.submobjects = self._replace(
ancestor.submobjects, mobject, *replacements
)

self.mobjects = self._replace(self.mobjects, mobject, *replacements)

return self

def add_updater(self, func: Callable[[float], None]) -> None:
Expand Down
22 changes: 21 additions & 1 deletion tests/module/animation/test_transform.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from __future__ import annotations

from manim import Circle, Manager, ReplacementTransform, Scene, Square, VGroup
from manim import Circle, Manager, ReplacementTransform, Scene, Square, Triangle, VGroup


def test_no_duplicate_references():
Expand All @@ -27,3 +27,23 @@ def test_duplicate_references_in_group():
submobs = vg.submobjects
assert len(submobs) == 1
assert submobs[0] is sq


def test_duplicate_references_in_multiple_groups():
manager = Manager(Scene)
scene = manager.scene
c = Circle()
sq = Square()
tr = Triangle()
vg_1 = VGroup(c, sq)
vg_2 = VGroup(c, tr)
scene.add(vg_1, vg_2)

scene.play(ReplacementTransform(c, sq))
submobs_1 = vg_1.submobjects
submobs_2 = vg_2.submobjects
assert len(submobs_1) == 1
assert submobs_1[0] is sq
assert len(submobs_2) == 2
assert submobs_2[0] is sq
assert submobs_2[1] is tr