diff --git a/causing/model.py b/causing/model.py
index e0a6189..89e512c 100644
--- a/causing/model.py
+++ b/causing/model.py
@@ -47,7 +47,6 @@ def __post_init__(self):
             self.graph.add_node(var)
         self.trans_graph = networkx.transitive_closure(self.graph, reflexive=True)
 
-    @np.errstate(all="raise")
     def compute(
         self,
         xdat: np.array,
@@ -59,7 +58,7 @@ def compute(
         fixed_to_yind: int = None,
         fixed_vals: list = None,
         # override default parameter values
-        parameters: dict[str, float] = {},
+        parameters: dict[str, float] | None = None,
     ) -> np.array:
         """Compute y values for given x values
 
@@ -69,16 +68,19 @@ def compute(
         assert xdat.ndim == 2, f"xdat must be m*tau (is {xdat.ndim}-dimensional)"
         assert xdat.shape[0] == self.mdim, f"xdat must be m*tau (is {xdat.shape})"
         tau = xdat.shape[1]
+        if parameters is None:
+            parameters = {}
         parameters = self.parameters | parameters
-        yhat = np.array([[float("nan")] * tau] * len(self.yvars))
-        for i, eq in enumerate(self._model_lam):
-            if fixed_yind == i:
-                yhat[i, :] = fixed_yval
-            else:
-                eq_inputs = np.array(
-                    [[*xval, *yval] for xval, yval in zip(xdat.T, yhat.T)]
-                )
+        with np.errstate(all="raise"):
+            yhat = np.array([[float("nan")] * tau] * len(self.yvars))
+            for i, eq in enumerate(self._model_lam):
+                if fixed_yind == i:
+                    yhat[i, :] = fixed_yval
+                else:
+                    eq_inputs = np.array(
+                        [[*xval, *yval] for xval, yval in zip(xdat.T, yhat.T)]
+                    )
+                if fixed_to_yind == i:
+                    eq_inputs[:, fixed_from_ind] = fixed_vals