diff --git a/manim/scene/scene.py b/manim/scene/scene.py index 44282df581..c8c6bc9362 100644 --- a/manim/scene/scene.py +++ b/manim/scene/scene.py @@ -218,6 +218,37 @@ def remove(self, *mobjects_to_remove: OpenGLMobject) -> Self: self.mobjects = list_difference_update(self.mobjects, mob.get_family()) return self + def _replace( + self, + parent_mobjects: list[OpenGLMobject], + mobject: OpenGLMobject, + *replacements: OpenGLMobject, + ) -> list[OpenGLMobject]: + """Replace one Mobject in a parent Mobject with one or more other Mobjects, + preserving draw order. + + Parameters + ---------- + pareparent_mobjectsnt + The parent mobjects list where the mobject is going to be replaced + mobject + The mobject to be replaced. Must be present in the scene. + replacements + One or more Mobjects which must not already be in the scene. + """ + if mobject in parent_mobjects: + index = parent_mobjects.index(mobject) + parent_mobjects = [ + *parent_mobjects[:index], + *[ + replacement + for replacement in replacements + if replacement not in parent_mobjects + ], + *parent_mobjects[index + 1 :], + ] + return parent_mobjects + def replace(self, mobject: OpenGLMobject, *replacements: OpenGLMobject): """Replace one Mobject in the scene with one or more other Mobjects, preserving draw order. @@ -234,17 +265,13 @@ def replace(self, mobject: OpenGLMobject, *replacements: OpenGLMobject): One or more Mobjects which must not already be in the scene. """ - if mobject in self.mobjects: - index = self.mobjects.index(mobject) - self.mobjects = [ - *self.mobjects[:index], - *[ - replacement - for replacement in replacements - if replacement not in self.mobjects - ], - *self.mobjects[index + 1 :], - ] + for ancestor in mobject.get_ancestors(): + ancestor.submobjects = self._replace( + ancestor.submobjects, mobject, *replacements + ) + + self.mobjects = self._replace(self.mobjects, mobject, *replacements) + return self def add_updater(self, func: Callable[[float], None]) -> None: diff --git a/tests/module/animation/test_transform.py b/tests/module/animation/test_transform.py index 1e452857fd..6600c3632e 100644 --- a/tests/module/animation/test_transform.py +++ b/tests/module/animation/test_transform.py @@ -1,6 +1,6 @@ from __future__ import annotations -from manim import Circle, Manager, ReplacementTransform, Scene, Square, VGroup +from manim import Circle, Manager, ReplacementTransform, Scene, Square, Triangle, VGroup def test_no_duplicate_references(): @@ -27,3 +27,23 @@ def test_duplicate_references_in_group(): submobs = vg.submobjects assert len(submobs) == 1 assert submobs[0] is sq + + +def test_duplicate_references_in_multiple_groups(): + manager = Manager(Scene) + scene = manager.scene + c = Circle() + sq = Square() + tr = Triangle() + vg_1 = VGroup(c, sq) + vg_2 = VGroup(c, tr) + scene.add(vg_1, vg_2) + + scene.play(ReplacementTransform(c, sq)) + submobs_1 = vg_1.submobjects + submobs_2 = vg_2.submobjects + assert len(submobs_1) == 1 + assert submobs_1[0] is sq + assert len(submobs_2) == 2 + assert submobs_2[0] is sq + assert submobs_2[1] is tr