Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Creating threads to update visualization asynchronously #2656

Open
wants to merge 25 commits into
base: main
Choose a base branch
from
Open
Changes from 24 commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
4722447
Optimize UI responsiveness by Offloading model execution to separate …
HMNS19 Jan 19, 2025
ac55736
adding thread without implementing render_interval
HMNS19 Jan 22, 2025
4182285
visualisation thread
HMNS19 Jan 24, 2025
2561df8
render _interval functioning
HMNS19 Jan 25, 2025
00c3c93
Update solara_viz.py
HMNS19 Jan 25, 2025
a68b4fd
Update solara_viz.py
HMNS19 Jan 25, 2025
8e72a05
terminating threads
HMNS19 Jan 25, 2025
a08bdcb
removing asynchornous task to avoid cancelledError exception
HMNS19 Jan 25, 2025
fd17584
bidirectional synchronization
HMNS19 Jan 25, 2025
d1c7965
takes care of valueError and CancelledError
HMNS19 Jan 26, 2025
8b8cd19
final changes
HMNS19 Jan 26, 2025
1c276aa
Display message for adjusting play interval when using threads
HMNS19 Jan 26, 2025
fec0861
updating SimulatorConrtoller
HMNS19 Jan 27, 2025
775eaa9
Update solara_viz.py
HMNS19 Jan 27, 2025
1ef0cb9
Merge branch 'projectmesa:main' into issue-2604
HMNS19 Jan 27, 2025
5e5b2cf
solving the 'non rendering of plots while using threads' bug
HMNS19 Jan 28, 2025
b1ffeff
Fix code indentation
HMNS19 Jan 28, 2025
0fcf2b2
Fix code indentation
HMNS19 Jan 28, 2025
791d70b
threading for simulator
HMNS19 Jan 31, 2025
6cd0f22
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 31, 2025
3ba3593
Change from use_thread to use_task in SimulatorController for better …
HMNS19 Feb 1, 2025
2d4e7d7
Merge branch 'main' into issue-2604
HMNS19 Feb 6, 2025
e5579ad
ensuring event objects dont reset during re-renders and removing asyn…
HMNS19 Feb 15, 2025
da8d67b
Merge branch 'main' into issue-2604
HMNS19 Feb 15, 2025
d771f7f
removing lambda
HMNS19 Feb 18, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
129 changes: 110 additions & 19 deletions mesa/visualization/solara_viz.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@

import asyncio
import inspect
import threading
import time
from collections.abc import Callable
from typing import TYPE_CHECKING, Literal

Expand Down Expand Up @@ -57,6 +59,7 @@ def SolaraViz(
simulator: Simulator | None = None,
model_params=None,
name: str | None = None,
use_threads: bool = False,
):
"""Solara visualization component.

Expand All @@ -76,6 +79,8 @@ def SolaraViz(
This controls the speed of the model's automatic stepping. Defaults to 100 ms.
render_interval (int, optional): Controls how often plots are updated during a simulation,
allowing users to skip intermediate steps and update graphs less frequently.
use_threads: Flag for indicating whether to utilize multi-threading for model execution.
When checked, the model will utilize multiple threads,adjust based on system capabilities.
simulator: A simulator that controls the model (optional)
model_params (dict, optional): Parameters for (re-)instantiating a model.
Can include user-adjustable parameters and fixed parameters. Defaults to None.
Expand Down Expand Up @@ -114,6 +119,7 @@ def SolaraViz(
reactive_model_parameters = solara.use_reactive({})
reactive_play_interval = solara.use_reactive(play_interval)
reactive_render_interval = solara.use_reactive(render_interval)
reactive_use_threads = solara.use_reactive(use_threads)
with solara.AppBar():
solara.AppBarTitle(name if name else model.value.__class__.__name__)
solara.lab.ThemeToggle()
Expand All @@ -136,12 +142,21 @@ def SolaraViz(
max=100,
step=2,
)
if reactive_use_threads.value:
solara.Text("Increase play interval to avoid skipping plots")

solara.Checkbox(
label="Use Threads",
value=reactive_use_threads,
on_value=lambda v: reactive_use_threads.set(v),
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@HMNS19 Thanks for the PR!

I will leave the more focused review to @quaquel and @Corvince, however if possible I like to avoid lambda functions in Mesa as they are generally viewed as less readable, harder to read in the traceback and create issues with serialization. As this is part of the visual I appreciate the last 2 are less of a concern

However a simple function is also fairly easy to add

def set_reactive_use_threads(value):
   reactive_use_threads.set(value)

solara.Checkbox(
   label="Use Threads",
   value=reactive_use_threads,
   on_value=set_reactive_use_threads,
)

See PEP 8

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the feedback! I’ve replaced the lambda with a named function as you suggested. Let me know if there’s anything else to improve. Also curious—would lambda functions in the play_interval and render_interval sliders have similar concerns?

)
if not isinstance(simulator, Simulator):
ModelController(
model,
model_parameters=reactive_model_parameters,
play_interval=reactive_play_interval,
render_interval=reactive_render_interval,
use_threads=reactive_use_threads,
)
else:
SimulatorController(
Expand All @@ -150,6 +165,7 @@ def SolaraViz(
model_parameters=reactive_model_parameters,
play_interval=reactive_play_interval,
render_interval=reactive_render_interval,
use_threads=reactive_use_threads,
)
with solara.Card("Model Parameters"):
ModelCreator(
Expand Down Expand Up @@ -211,6 +227,7 @@ def ModelController(
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int | solara.Reactive[int] = 100,
render_interval: int | solara.Reactive[int] = 1,
use_threads: bool | solara.Reactive[bool] = False,
):
"""Create controls for model execution (step, play, pause, reset).

Expand All @@ -219,37 +236,70 @@ def ModelController(
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.
render_interval: Controls how often the plots are updated during simulation steps.Higher value reduce update frequency.
use_threads: Flag for indicating whether to utilize multi-threading for model execution.
"""
playing = solara.use_reactive(False)
running = solara.use_reactive(True)

if model_parameters is None:
model_parameters = {}
model_parameters = solara.use_reactive(model_parameters)

async def step():
while playing.value and running.value:
await asyncio.sleep(play_interval.value / 1000)
do_step()
visualization_pause_event = solara.use_memo(lambda: threading.Event(), [])

def step():
try:
while running.value and playing.value:
time.sleep(play_interval.value / 1000)
do_step()
if use_threads.value:
visualization_pause_event.set()
except Exception as e:
print(f"Error in step: {e}")
return

def visualization_task():
if use_threads.value:
try:
while playing.value and running.value:
visualization_pause_event.wait()
visualization_pause_event.clear()
force_update()
except Exception as e:
print(f"Error in visualization_task: {e}")

solara.lab.use_task(
step, dependencies=[playing.value, running.value], prefer_threaded=False
step, dependencies=[playing.value, running.value], prefer_threaded=True
)

solara.use_thread(
visualization_task,
dependencies=[playing.value, running.value],
)

@function_logger(__name__)
def do_step():
"""Advance the model by the number of steps specified by the render_interval slider."""
for _ in range(render_interval.value):
model.value.step()

running.value = model.value.running
if playing.value:
for _ in range(render_interval.value):
model.value.step()
running.value = model.value.running
if not playing.value:
break
if not use_threads.value:
force_update()

force_update()
else:
for _ in range(render_interval.value):
model.value.step()
running.value = model.value.running
force_update()

@function_logger(__name__)
def do_reset():
"""Reset the model to its initial state."""
playing.value = False
running.value = True
visualization_pause_event.clear()
_mesa_logger.log(
10,
f"creating new {model.value.__class__} instance with {model_parameters.value}",
Expand Down Expand Up @@ -285,6 +335,7 @@ def SimulatorController(
model_parameters: dict | solara.Reactive[dict] = None,
play_interval: int | solara.Reactive[int] = 100,
render_interval: int | solara.Reactive[int] = 1,
use_threads: bool | solara.Reactive[bool] = False,
):
"""Create controls for model execution (step, play, pause, reset).

Expand All @@ -294,6 +345,7 @@ def SimulatorController(
model_parameters: Reactive parameters for (re-)instantiating a model.
play_interval: Interval for playing the model steps in milliseconds.
render_interval: Controls how often the plots are updated during simulation steps.Higher values reduce update frequency.
use_threads: Flag for indicating whether to utilize multi-threading for model execution.

Notes:
The `step button` increments the step by the value specified in the `render_interval` slider.
Expand All @@ -304,27 +356,66 @@ def SimulatorController(
if model_parameters is None:
model_parameters = {}
model_parameters = solara.use_reactive(model_parameters)

async def step():
while playing.value and running.value:
await asyncio.sleep(play_interval.value / 1000)
do_step()
visualization_pause_event = solara.use_memo(lambda: threading.Event(), [])
pause_step_event = solara.use_memo(lambda: threading.Event(), [])

def step():
try:
while running.value and playing.value:
time.sleep(play_interval.value / 1000)
if use_threads.value:
pause_step_event.wait()
pause_step_event.clear()
do_step()
if use_threads.value:
visualization_pause_event.set()
except Exception as e:
print(f"Error in step: {e}")

def visualization_task():
if use_threads.value:
try:
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
pause_step_event.set()
while playing.value and running.value:
visualization_pause_event.wait()
visualization_pause_event.clear()
force_update()
pause_step_event.set()
except Exception as e:
print(f"Error in visualization_task: {e}")
return

solara.lab.use_task(
step, dependencies=[playing.value, running.value], prefer_threaded=False
)
solara.lab.use_task(visualization_task, dependencies=[playing.value])

def do_step():
"""Advance the model by the number of steps specified by the render_interval slider."""
simulator.run_for(render_interval.value)
running.value = model.value.running
force_update()
if playing.value:
for _ in range(render_interval.value):
simulator.run_for(1)
running.value = model.value.running
if not playing.value:
break
if not use_threads.value:
force_update()

else:
for _ in range(render_interval.value):
simulator.run_for(1)
running.value = model.value.running
force_update()

def do_reset():
"""Reset the model to its initial state."""
playing.value = False
running.value = True
simulator.reset()
visualization_pause_event.clear()
pause_step_event.clear()
model.value = model.value = model.value.__class__(
simulator=simulator, **model_parameters.value
)
Expand Down
Loading