Skip to content

Commit

Permalink
Monitor plot (#66)
Browse files Browse the repository at this point in the history
* Update base_monitor.py

* added plotting viewed compartments
  • Loading branch information
willgebhardt authored Jul 23, 2024
1 parent 1ddd86d commit e240644
Showing 1 changed file with 69 additions and 16 deletions.
85 changes: 69 additions & 16 deletions ngclearn/components/base_monitor.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,10 @@

from ngclearn import Component, Compartment
from ngclearn import numpy as np
from ngcsimlib.utils import add_component_resolver, add_resolver_meta, get_current_path
from ngcsimlib.utils import add_component_resolver, add_resolver_meta, \
get_current_path
from ngcsimlib.logger import warn, critical
import matplotlib.pyplot as plt


class Base_Monitor(Component):
Expand All @@ -21,7 +23,8 @@ class Base_Monitor(Component):
Using custom window length:
myMonitor.watch(myComponent.myCompartment, customWindowLength)
To get values out of the monitor either path to the stored value directly, or pass in a compartment directly. All
To get values out of the monitor either path to the stored value
directly, or pass in a compartment directly. All
paths are the same as their local path variable.
Using a compartment:
Expand All @@ -30,7 +33,8 @@ class Base_Monitor(Component):
Using a path:
myMonitor.get_store(myComponent.myCompartment.path).value
There can only be one monitor in existence at a time due to the way it interacts with resolvers and the compilers
There can only be one monitor in existence at a time due to the way it
interacts with resolvers and the compilers
for ngclearn.
Args:
Expand All @@ -53,10 +57,10 @@ def build_advance(compartments):
"""
critical(
"build_advance() is not defined on this monitor, use either the monitor found in ngclearn.components or "
"build_advance() is not defined on this monitor, use either the "
"monitor found in ngclearn.components or "
"ngclearn.components.lava (If using lava)")


@staticmethod
def build_reset(compartments):
"""
Expand All @@ -66,6 +70,7 @@ def build_reset(compartments):
Returns: The method to reset the stored values.
"""

@staticmethod
def _reset(**kwargs):
return_vals = []
Expand Down Expand Up @@ -95,7 +100,8 @@ def __lshift__(self, other):

def watch(self, compartment, window_length):
"""
Sets the monitor to watch a specific compartment, for a specified window length.
Sets the monitor to watch a specific compartment, for a specified
window length.
Args:
compartment: the compartment object to monitor
Expand Down Expand Up @@ -150,7 +156,7 @@ def halt_all(self):
"""
for compartment in self._sources:
self.halt(compartment)

def _update_resolver(self):
output_compartments = []
compartments = []
Expand All @@ -162,13 +168,18 @@ def _update_resolver(self):
parameters = []

add_component_resolver(self.__class__.__name__, "advance_state",
(self.build_advance(compartments), output_compartments))
(self.build_advance(compartments),
output_compartments))
add_resolver_meta(self.__class__.__name__, "advance_state",
(args, parameters, compartments + [o for o in output_compartments], False))
(args, parameters,
compartments + [o for o in output_compartments],
False))

add_component_resolver(self.__class__.__name__, "reset", (self.build_reset(compartments), output_compartments))
add_component_resolver(self.__class__.__name__, "reset", (
self.build_reset(compartments), output_compartments))
add_resolver_meta(self.__class__.__name__, "reset",
(args, parameters, [o for o in output_compartments], False))
(args, parameters, [o for o in output_compartments],
False))

def _add_path(self, path):
_path = path.split("/")[1:]
Expand Down Expand Up @@ -210,7 +221,8 @@ def save(self, directory, **kwargs):
for key in self.compartments:
n = key.split("/")[-1]
_dict["sources"][key] = self.__dict__[n].value.shape
_dict["stores"][key + "*store"] = self.__dict__[n + "*store"].value.shape
_dict["stores"][key + "*store"] = self.__dict__[
n + "*store"].value.shape

with open(file_name, "w") as f:
json.dump(_dict, f)
Expand All @@ -221,9 +233,9 @@ def load(self, directory, **kwargs):
vals = json.load(f)

for comp_path, shape in vals["stores"].items():

compartment_path = comp_path.split("/")[-1]
new_path = get_current_path() + "/" + "/".join(compartment_path.split("*")[-3:-1])
new_path = get_current_path() + "/" + "/".join(
compartment_path.split("*")[-3:-1])

cs, end = self._add_path(new_path)

Expand All @@ -233,8 +245,6 @@ def load(self, directory, **kwargs):
cs[end] = new_comp
setattr(self, compartment_path, new_comp)



for comp_path, shape in vals['sources'].items():
compartment_path = comp_path.split("/")[-1]
new_comp = Compartment(np.zeros(shape))
Expand All @@ -244,3 +254,46 @@ def load(self, directory, **kwargs):
self.compartments.append(new_comp.path)

self._update_resolver()

def make_plot(self, compartment, ax=None, ylabel=None, xlabel=None, title=None, n=None, plot_func=None):
vals = self.view(compartment)

if n is None:
n = vals.shape[2]
if title is None:
title = compartment.name.split("/")[0] + " " + compartment.display_name

if ylabel is None:
_ylabel = compartment.units
elif ylabel:
_ylabel = ylabel
else:
_ylabel = None

if xlabel is None:
_xlabel = "Time Steps"
elif xlabel:
_xlabel = xlabel
else:
_xlabel = None

if ax is None:
_ax = plt
_ax.title(title)
if _ylabel:
_ax.ylabel(_ylabel)
if _xlabel:
_ax.xlabel(_xlabel)
else:
_ax = ax
_ax.set_title(title)
if _ylabel:
_ax.set_ylabel(_ylabel)
if _xlabel:
_ax.set_xlabel(_xlabel)

if plot_func is None:
for k in range(n):
_ax.plot(vals[:, 0, k])
else:
plot_func(vals, ax=_ax)

0 comments on commit e240644

Please sign in to comment.