From e24064454397cc711937cc938b11184b0cb9ddc0 Mon Sep 17 00:00:00 2001 From: Will Gebhardt Date: Tue, 23 Jul 2024 15:04:01 -0400 Subject: [PATCH] Monitor plot (#66) * Update base_monitor.py * added plotting viewed compartments --- ngclearn/components/base_monitor.py | 85 +++++++++++++++++++++++------ 1 file changed, 69 insertions(+), 16 deletions(-) diff --git a/ngclearn/components/base_monitor.py b/ngclearn/components/base_monitor.py index 2f371d420..6efb367cc 100644 --- a/ngclearn/components/base_monitor.py +++ b/ngclearn/components/base_monitor.py @@ -2,8 +2,10 @@ from ngclearn import Component, Compartment from ngclearn import numpy as np -from ngcsimlib.utils import add_component_resolver, add_resolver_meta, get_current_path +from ngcsimlib.utils import add_component_resolver, add_resolver_meta, \ + get_current_path from ngcsimlib.logger import warn, critical +import matplotlib.pyplot as plt class Base_Monitor(Component): @@ -21,7 +23,8 @@ class Base_Monitor(Component): Using custom window length: myMonitor.watch(myComponent.myCompartment, customWindowLength) - To get values out of the monitor either path to the stored value directly, or pass in a compartment directly. All + To get values out of the monitor either path to the stored value + directly, or pass in a compartment directly. All paths are the same as their local path variable. Using a compartment: @@ -30,7 +33,8 @@ class Base_Monitor(Component): Using a path: myMonitor.get_store(myComponent.myCompartment.path).value - There can only be one monitor in existence at a time due to the way it interacts with resolvers and the compilers + There can only be one monitor in existence at a time due to the way it + interacts with resolvers and the compilers for ngclearn. Args: @@ -53,10 +57,10 @@ def build_advance(compartments): """ critical( - "build_advance() is not defined on this monitor, use either the monitor found in ngclearn.components or " + "build_advance() is not defined on this monitor, use either the " + "monitor found in ngclearn.components or " "ngclearn.components.lava (If using lava)") - @staticmethod def build_reset(compartments): """ @@ -66,6 +70,7 @@ def build_reset(compartments): Returns: The method to reset the stored values. """ + @staticmethod def _reset(**kwargs): return_vals = [] @@ -95,7 +100,8 @@ def __lshift__(self, other): def watch(self, compartment, window_length): """ - Sets the monitor to watch a specific compartment, for a specified window length. + Sets the monitor to watch a specific compartment, for a specified + window length. Args: compartment: the compartment object to monitor @@ -150,7 +156,7 @@ def halt_all(self): """ for compartment in self._sources: self.halt(compartment) - + def _update_resolver(self): output_compartments = [] compartments = [] @@ -162,13 +168,18 @@ def _update_resolver(self): parameters = [] add_component_resolver(self.__class__.__name__, "advance_state", - (self.build_advance(compartments), output_compartments)) + (self.build_advance(compartments), + output_compartments)) add_resolver_meta(self.__class__.__name__, "advance_state", - (args, parameters, compartments + [o for o in output_compartments], False)) + (args, parameters, + compartments + [o for o in output_compartments], + False)) - add_component_resolver(self.__class__.__name__, "reset", (self.build_reset(compartments), output_compartments)) + add_component_resolver(self.__class__.__name__, "reset", ( + self.build_reset(compartments), output_compartments)) add_resolver_meta(self.__class__.__name__, "reset", - (args, parameters, [o for o in output_compartments], False)) + (args, parameters, [o for o in output_compartments], + False)) def _add_path(self, path): _path = path.split("/")[1:] @@ -210,7 +221,8 @@ def save(self, directory, **kwargs): for key in self.compartments: n = key.split("/")[-1] _dict["sources"][key] = self.__dict__[n].value.shape - _dict["stores"][key + "*store"] = self.__dict__[n + "*store"].value.shape + _dict["stores"][key + "*store"] = self.__dict__[ + n + "*store"].value.shape with open(file_name, "w") as f: json.dump(_dict, f) @@ -221,9 +233,9 @@ def load(self, directory, **kwargs): vals = json.load(f) for comp_path, shape in vals["stores"].items(): - compartment_path = comp_path.split("/")[-1] - new_path = get_current_path() + "/" + "/".join(compartment_path.split("*")[-3:-1]) + new_path = get_current_path() + "/" + "/".join( + compartment_path.split("*")[-3:-1]) cs, end = self._add_path(new_path) @@ -233,8 +245,6 @@ def load(self, directory, **kwargs): cs[end] = new_comp setattr(self, compartment_path, new_comp) - - for comp_path, shape in vals['sources'].items(): compartment_path = comp_path.split("/")[-1] new_comp = Compartment(np.zeros(shape)) @@ -244,3 +254,46 @@ def load(self, directory, **kwargs): self.compartments.append(new_comp.path) self._update_resolver() + + def make_plot(self, compartment, ax=None, ylabel=None, xlabel=None, title=None, n=None, plot_func=None): + vals = self.view(compartment) + + if n is None: + n = vals.shape[2] + if title is None: + title = compartment.name.split("/")[0] + " " + compartment.display_name + + if ylabel is None: + _ylabel = compartment.units + elif ylabel: + _ylabel = ylabel + else: + _ylabel = None + + if xlabel is None: + _xlabel = "Time Steps" + elif xlabel: + _xlabel = xlabel + else: + _xlabel = None + + if ax is None: + _ax = plt + _ax.title(title) + if _ylabel: + _ax.ylabel(_ylabel) + if _xlabel: + _ax.xlabel(_xlabel) + else: + _ax = ax + _ax.set_title(title) + if _ylabel: + _ax.set_ylabel(_ylabel) + if _xlabel: + _ax.set_xlabel(_xlabel) + + if plot_func is None: + for k in range(n): + _ax.plot(vals[:, 0, k]) + else: + plot_func(vals, ax=_ax)