Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
177 changes: 169 additions & 8 deletions manim/mobject/value_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,10 @@

from __future__ import annotations

__all__ = ["ValueTracker", "ComplexValueTracker"]
__all__ = ["ValueTracker", "ComplexValueTracker", "ThreeDValueTracker"]

from typing import TYPE_CHECKING, Any
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, cast

import numpy as np

Expand Down Expand Up @@ -85,8 +86,16 @@ def get_value(self) -> float:
value: float = self.points[0, 0]
return value

def set_value(self, value: float) -> Self:
"""Sets a new scalar value to the ValueTracker."""
def set_value(self, value: float | int | str) -> Self:
value = float(value)
if not np.isreal(value):
raise TypeError(
f"ValueTracker only accepts real numbers — use ComplexValueTracker for having 2 ValueTrackers simultaneously, got {value}"
)
if not np.isfinite(value):
raise ValueError(
f"Value must be finite — no nan or inf allowed, got {value}"
)
self.points[0, 0] = value
return self

Expand Down Expand Up @@ -235,8 +244,160 @@ def get_value(self) -> complex: # type: ignore [override]
"""Get the current value of this ComplexValueTracker as a complex number."""
return complex(*self.points[0, :2])

def set_value(self, value: complex | float) -> Self:
"""Sets a new complex value to the ComplexValueTracker."""
z = complex(value)
self.points[0, :2] = (z.real, z.imag)
def set_value(
self,
value: complex | float | int | str | Sequence[float | int] | np.ndarray = 0
+ 0j,
mode: str = "rectangular", # "rectangular" or "polar"
angle_unit: str = "radians", # "radians" or "degrees" — only used when mode="polar"
) -> Self:
"""
Sets a new complex value to the ComplexValueTracker.

Parameters
----------
value : complex | float | int | str | Sequence[float | int] | np.ndarray
The value to set. It can be:
- a complex number: 2+3j
- a float or int: 5.0 or 5
- a valid numeric string: "23" or "2+3j"
- a sequence of exactly 2 real numbers: (2, 3), [2, 3], np.array([2, 3])
- if mode="rectangular": interpreted as (x, y)
- if mode="polar": interpreted as (r, theta)
- theta can be in radians or degrees, specified by angle_unit
mode : str
"rectangular" (default) or "polar".
Only relevant when value is a sequence.
angle_unit : str
"radians" (default) or "degrees".
Only relevant when mode="polar".
If "degrees", theta is converted to radians internally.

Examples
--------
set_value(2+3j) # rectangular complex
set_value((2, 3)) # rectangular sequence
set_value((1, 90), mode="polar", angle_unit="degrees") # polar, degrees
set_value((1, np.pi/2), mode="polar") # polar, radians
"""
# validate mode
if mode not in ("rectangular", "polar"):
raise ValueError(f"mode must be 'rectangular' or 'polar', got '{mode}'")

# validate angle_unit
if angle_unit not in ("radians", "degrees"):
raise ValueError(
f"angle_unit must be 'radians' or 'degrees', got '{angle_unit}'"
)

if isinstance(value, (list, tuple, np.ndarray)):
# length check
if len(value) != 2:
raise ValueError(f"Expected exactly 2 numbers, got {len(value)}")
# check for type of number provided and finiteness check
if not all(np.isreal(v) and np.isfinite(v) for v in value):
raise TypeError(
"Elements must be real and finite numbers — no NAN(Not a Number) or infinity is allowed"
)
a, b = value

if mode == "polar":
r, theta = a, b
if r < 0:
raise ValueError(
f"Radius r must be non-negative in polar form, got {r}"
)
# convert degrees to radians if needed
if angle_unit == "degrees":
theta = np.deg2rad(theta)
x = r * np.cos(theta)
y = r * np.sin(theta)
else: # rectangular
x, y = a, b

else:
value = cast(complex | float | int | str, value)
z = complex(value) # handles complex, float, int, valid strings
# check real and imag parts individually for finiteness
if not np.isfinite(z.real):
raise ValueError(f"Real part must be finite, got {z.real}")
if not np.isfinite(z.imag):
raise ValueError(f"Imaginary part must be finite, got {z.imag}")
x, y = z.real, z.imag

self.points[0, :2] = (x, y)
return self


class ThreeDValueTracker(ValueTracker):
"""
A ValueTracker that tracks 3 numeric values simultaneously, equivalent to using 3 ValueTrackers at once.
Useful when working in 3D Scenes.
Accepts list, tuple, ndarray, int/float or numpy integer/floating as input.

Arrays of length < 3 are zero-padded on the right. Example:
If only 1 number is provided, it is stored as [float(number), 0., 0.].
If only 2 numbers are provided, it is stored as [float(number1), float(number2), 0.].

Arrays of length > 3 raise a ValueError.

Example of values
--------
tracker = ThreeDValueTracker([1, 2, 3]) # OK
tracker = ThreeDValueTracker([1, 2]) # stored as [1., 2., 0.]
tracker = ThreeDValueTracker(5) # stored as [5., 0., 0.]

class testThreeDValueTracker(ThreeDScene):
def construct(self):
self.set_camera_orientation(**self.default_angled_camera_orientation_kwargs)
axes = ThreeDAxes().add_coordinates()
for axis,color in zip(axes.get_axes(),[RED, GREEN, BLUE]):
axis.set_color(color)
self.add(axes)
position = ThreeDValueTracker([-3,0,0])
s = Sphere(radius = 0.1).set_color(GOLD)
s.move_to(axes.c2p(position.get_value()))
self.add(s)
self.begin_ambient_camera_rotation(rate=1.5)
self.wait(2)
s.add_updater(lambda m: m.move_to(axes.c2p(position.get_value())))
self.play(position.animate(run_time = 2).set_value([0,3,4]))
self.wait()
self.play(position.animate(run_time = 2).set_value([-2,0,-4]))
self.wait()
self.play(position.animate(run_time = 2).set_value([2,0,0]))
self.wait()
"""

def _validate(self, value: list | tuple | np.ndarray | int | float) -> np.ndarray:
"""
Converts input to a float numpy array of shape (3,).

Accepts:
- int, float, np.integer, np.floating → [value, 0., 0.]
- list, tuple, ndarray of length <= 3 → zero padded if length < 3

Raises:
ValueError: if input is non-numeric or length > 3
"""
if isinstance(value, (int, float, np.integer, np.floating)):
return np.array([float(value), 0.0, 0.0])
try:
value = np.asarray(value, dtype=float).flatten()
except (TypeError, ValueError) as err:
raise ValueError(
"Value must be numeric — list, tuple, ndarray, int, or float"
) from err
if len(value) > 3:
raise ValueError(f"Expected length at most 3, got length {len(value)}")
value = np.pad(value, (0, 3 - len(value))) if len(value) < 3 else value
return value

def get_value(self) -> np.ndarray:
"""Returns a copy of the current value."""
return self.points[0, :3].copy()

def set_value(self, value: list | tuple | np.ndarray | int | float) -> Self:
"""Sets a new 3D vector value to the tracker."""
self.points[0, :3] = self._validate(value)
return self