diff --git a/neurolib/control/optimal_control/oc_ww/__init__.py b/neurolib/control/optimal_control/oc_ww/__init__.py new file mode 100644 index 00000000..22b96d13 --- /dev/null +++ b/neurolib/control/optimal_control/oc_ww/__init__.py @@ -0,0 +1 @@ +from .oc_ww import OcWw diff --git a/neurolib/control/optimal_control/oc_ww/oc_ww.py b/neurolib/control/optimal_control/oc_ww/oc_ww.py new file mode 100644 index 00000000..90931119 --- /dev/null +++ b/neurolib/control/optimal_control/oc_ww/oc_ww.py @@ -0,0 +1,173 @@ +import numba + +from neurolib.control.optimal_control.oc import OC +from neurolib.models.ww.timeIntegration import ( + compute_hx, + compute_hx_min1, + compute_hx_nw, + Duh, + Dxdoth, +) + + +class OcWw(OC): + """Class for optimal control specific to neurolib's implementation of the two-population Wong-Wang model + ("WWmodel"). + + :param model: Instance of Wong-Wang model (can describe a single Wong-Wang node or a network of coupled + Wong-Wang nodes. + :type model: neurolib.models.ww.model.WWModel + """ + + def __init__( + self, + model, + target, + weights=None, + print_array=[], + cost_interval=(None, None), + cost_matrix=None, + control_matrix=None, + M=1, + M_validation=0, + validate_per_step=False, + ): + super().__init__( + model, + target, + weights=weights, + print_array=print_array, + cost_interval=cost_interval, + cost_matrix=cost_matrix, + control_matrix=control_matrix, + M=M, + M_validation=M_validation, + validate_per_step=validate_per_step, + ) + + assert self.model.name == "wongwang" + + def compute_dxdoth(self): + """Derivative of systems dynamics wrt. change of systems variables.""" + return Dxdoth(self.N, self.dim_vars) + + def get_model_params(self): + """Model params as an ordered tuple. + + :rtype: tuple + """ + return ( + self.model.params.a_exc, + self.model.params.b_exc, + self.model.params.d_exc, + self.model.params.tau_exc, + self.model.params.gamma_exc, + self.model.params.w_exc, + self.model.params.exc_current_baseline, + self.model.params.a_inh, + self.model.params.b_inh, + self.model.params.d_inh, + self.model.params.tau_inh, + self.model.params.w_inh, + self.model.params.inh_current_baseline, + self.model.params.J_NMDA, + self.model.params.J_I, + self.model.params.w_ee, + ) + + def Duh(self): + """Jacobian of systems dynamics wrt. external control input. + + :return: N x 4 x 4 x T Jacobians. + :rtype: np.ndarray + """ + + xs = self.get_xs() + xsd = self.get_xs_delay() + + return Duh( + self.model_params, + self.N, + self.dim_in, + self.dim_vars, + self.T, + self.control[:, self.state_vars_dict["r_exc"], :], + self.control[:, self.state_vars_dict["r_inh"], :], + xs[:, self.state_vars_dict["se"], :], + xs[:, self.state_vars_dict["si"], :], + self.model.params.K_gl, + self.model.params.Cmat, + self.Dmat_ndt, + xsd[:, self.state_vars_dict["se"], :], + self.state_vars_dict, + ) + + def compute_hx_list(self): + """List of Jacobians without and with time delays (e.g. in the ALN model) and list of respective time step delays as integers (0 for undelayed) + + :return: List of Jacobian matrices, list of time step delays + : rtype: List of np.ndarray, List of integers + """ + hx = self.compute_hx() + hx_min1 = self.compute_hx_min1() + return numba.typed.List([hx, hx_min1]), numba.typed.List([0, -1]) + + def compute_hx(self): + """Jacobians of WwModel wrt. all variables. + + :return: N x T x 6 x 6 Jacobians. + :rtype: np.ndarray + """ + return compute_hx( + self.model_params, + self.model.params.K_gl, + self.model.Cmat, + self.Dmat_ndt, + self.N, + self.dim_vars, + self.T, + self.get_xs(), + self.get_xs_delay(), + self.control, + self.state_vars_dict, + ) + + def compute_hx_min1(self): + """Jacobians of WWModel dse/dre and dsi/dri. + Dependency is in same time step, so shift by -1 in time is required for OC computation. + + :return: N x T x 6 x 6 Jacobians. + :rtype: np.ndarray + """ + return compute_hx_min1( + self.model_params, + self.N, + self.dim_vars, + self.T, + self.get_xs(), + self.state_vars_dict, + ) + + def compute_hx_nw(self): + """Jacobians for each time step for the network coupling. + + :return: N x N x T x (4x4) array + :rtype: np.ndarray + """ + + xs = self.get_xs() + + return compute_hx_nw( + self.model_params, + self.model.params.K_gl, + self.model.Cmat, + self.Dmat_ndt, + self.N, + self.dim_vars, + self.T, + xs[:, self.state_vars_dict["se"], :], + xs[:, self.state_vars_dict["se"], :], + self.get_xs_delay()[:, self.state_vars_dict["se"], :], + self.control[:, self.state_vars_dict["r_exc"], :], + self.state_vars_dict, + ) diff --git a/neurolib/models/ww/timeIntegration.py b/neurolib/models/ww/timeIntegration.py index a484de67..faf593c5 100644 --- a/neurolib/models/ww/timeIntegration.py +++ b/neurolib/models/ww/timeIntegration.py @@ -290,3 +290,478 @@ def r(I, a, b, d): ) # mV/ms return t, r_exc, r_inh, ses, sis, exc_ou, inh_ou + + +@numba.njit +def logistic(x, a, b, d): + """Logistic function evaluated at point 'x'. + + :type x: float + :param a: Parameter of logistic function. + :type a: float + :param b: Parameter of logistic function. + :type b: float + :param d: Parameter of logistic function. + :type d: float + + :rtype: float + """ + return (a * x - b) / (1.0 - np.exp(-d * (a * x - b))) + + +@numba.njit +def logistic_der(x, a, b, d): + """Derivative of logistic function, evaluated at point 'x'. + + :type x: float + :param a: Parameter of logistic function. + :type a: float + :param b: Parameter of logistic function. + :type b: float + :param d: Parameter of logistic function. + :type d: float + + :rtype: float + """ + exp = np.exp(-d * (a * x - b)) + return (a * (1.0 - exp) - (a * x - b) * d * a * exp) / (1.0 - exp) ** 2 + + +@numba.njit +def jacobian_ww( + model_params, + nw_se, + re, + se, + si, + ue, + ui, + V, + sv, +): + """Jacobian of the WW dynamical system. + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param nw_se: N x T input of network into each node's 'exc' + :type nw_se: np.ndarray + :param re: Value of the r_exc-variable at specific time. + :type re: float + :param se: Value of the se-variable at specific time. + :type se: float + :param si: Value of the si-variable at specific time. + :type si: float + :param ue: Value of control input to into 'exc' at specific time. + :type ue: float + :param ui: Value of control input to into 'ihn' at specific time. + :type ui: float + :param V: Number of system variables. + :type V: int + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :return: 4 x 4 Jacobian matrix. + :rtype: np.ndarray + """ + ( + a_exc, + b_exc, + d_exc, + tau_exc, + gamma_exc, + w_exc, + exc_current_baseline, + a_inh, + b_inh, + d_inh, + tau_inh, + w_inh, + inh_current_baseline, + J_NMDA, + J_I, + w_ee, + ) = model_params + + jacobian = np.zeros((V, V)) + IE = w_exc * (exc_current_baseline + ue) + w_ee * J_NMDA * se - J_I * si + J_NMDA * nw_se + jacobian[sv["r_exc"], sv["se"]] = -logistic_der(IE, a_exc, b_exc, d_exc) * w_ee * J_NMDA + jacobian[sv["r_exc"], sv["si"]] = logistic_der(IE, a_exc, b_exc, d_exc) * J_I + II = w_inh * (inh_current_baseline + ui) + J_NMDA * se - si + jacobian[sv["r_inh"], sv["se"]] = -logistic_der(II, a_inh, b_inh, d_inh) * J_NMDA + jacobian[sv["r_inh"], sv["si"]] = logistic_der(II, a_inh, b_inh, d_inh) + + # jacobian[sv["se"], sv["r_exc"]] = -(1.0 - se) * gamma_exc + jacobian[sv["se"], sv["se"]] = 1.0 / tau_exc + gamma_exc * re + + # jacobian[sv["si"], sv["r_inh"]] = -1.0 + jacobian[sv["si"], sv["si"]] = 1.0 / tau_inh + return jacobian + + +@numba.njit +def compute_hx( + wc_model_params, + K_gl, + cmat, + dmat_ndt, + N, + V, + T, + dyn_vars, + dyn_vars_delay, + control, + sv, +): + """Jacobians of WWModel wrt. the all variables for each time step. + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param K_gl: Model parameter of global coupling strength. + :type K_gl: float + :param cmat: Model parameter, connectivity matrix. + :type cmat: ndarray + :param dmat_ndt: N x N delay matrix in multiples of dt. + :type dmat_ndt: np.ndarray + :param N: Number of nodes in the network. + :type N: int + :param V: Number of system variables. + :type V: int + :param T: Length of simulation (time dimension). + :type T: int + :param dyn_vars: N x V x T array containing all values of 'exc' and 'inh'. + :type dyn_vars: np.ndarray + :param dyn_vars_delay: + :type dyn_vars_delay: np.ndarray + :param control: N x 2 x T control inputs to 'exc' and 'inh'. + :type control: np.ndarray + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :return: N x T x 4 x 4 Jacobians. + :rtype: np.ndarray + """ + hx = np.zeros((N, T, V, V)) + nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, dyn_vars_delay[:, sv["se"], :]) + + for n in range(N): + for t in range(T): + re = dyn_vars[n, sv["r_exc"], t] + se = dyn_vars[n, sv["se"], t] + si = dyn_vars[n, sv["si"], t] + ue = control[n, sv["r_exc"], t] + ui = control[n, sv["r_inh"], t] + hx[n, t, :, :] = jacobian_ww( + wc_model_params, + nw_e[n, t], + re, + se, + si, + ue, + ui, + V, + sv, + ) + return hx + + +@numba.njit +def jacobian_ww_min1( + model_params, + se, + V, + sv, +): + """Jacobian of the WW dynamical system. + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param se: Value of the se-variable at specific time. + :type se: float + :param V: Number of system variables. + :type V: int + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :return: 4 x 4 Jacobian matrix. + :rtype: np.ndarray + """ + ( + a_exc, + b_exc, + d_exc, + tau_exc, + gamma_exc, + w_exc, + exc_current_baseline, + a_inh, + b_inh, + d_inh, + tau_inh, + w_inh, + inh_current_baseline, + J_NMDA, + J_I, + w_ee, + ) = model_params + + jacobian = np.zeros((V, V)) + + jacobian[sv["se"], sv["r_exc"]] = -(1.0 - se) * gamma_exc + jacobian[sv["si"], sv["r_inh"]] = -1.0 + return jacobian + + +@numba.njit +def compute_hx_min1( + wc_model_params, + N, + V, + T, + dyn_vars, + sv, +): + """Jacobians of WWModel wrt. the all variables for each time step. + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param N: Number of nodes in the network. + :type N: int + :param V: Number of system variables. + :type V: int + :param T: Length of simulation (time dimension). + :type T: int + :param dyn_vars: N x V x T array containing all values of 'exc' and 'inh'. + :type dyn_vars: np.ndarray + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :return: N x T x 4 x 4 Jacobians. + :rtype: np.ndarray + """ + hx = np.zeros((N, T, V, V)) + + for n in range(N): + for t in range(T): + se = dyn_vars[n, sv["se"], t] + hx[n, t, :, :] = jacobian_ww_min1( + wc_model_params, + se, + V, + sv, + ) + return hx + + +@numba.njit +def compute_nw_input(N, T, K_gl, cmat, dmat_ndt, se): + """Compute input by other nodes of network into each node's 'exc' population at every timestep. + + :param N: Number of nodes in the network. + :type N: int + :param T: Length of simulation (time dimension). + :type T: int + :param K_gl: Model parameter of global coupling strength. + :type K_gl: float + :param cmat: Model parameter, connectivity matrix. + :type cmat: ndarray + :param dmat_ndt: N x N delay matrix in multiples of dt. + :type dmat_ndt: np.ndarray + :param se: N x T array containing values of 'exc' of all nodes through time. + :type se: np.ndarray + :return: N x T network inputs. + :rytpe: np.ndarray + """ + nw_input = np.zeros((N, T)) + + for t in range(1, T): + for n in range(N): + for l in range(N): + nw_input[n, t] += K_gl * cmat[n, l] * (se[l, t - dmat_ndt[n, l] - 1]) + return nw_input + + +@numba.njit +def compute_hx_nw( + model_params, + K_gl, + cmat, + dmat_ndt, + N, + V, + T, + se, + si, + se_delay, + ue, + sv, +): + """Jacobians for network connectivity in all time steps. + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param K_gl: Model parameter of global coupling strength. + :type K_gl: float + :param cmat: Model parameter, connectivity matrix. + :type cmat: ndarray + :param dmat_ndt: N x N delay matrix in multiples of dt. + :type dmat_ndt: np.ndarray + :param N: Number of nodes in the network. + :type N: int + :param V: Number of system variables. + :type V: int + :param T: Length of simulation (time dimension). + :type T: int + :param se: Array of the se-variable. + :type se: np.ndarray + :param si: Array of the se-variable. + :type si: np.ndarray + :param se_delay: Value of delayed se-variable. + :type se_delay: np.ndarray + :param ue: N x T array of the total input received by 'exc' population in every node at any time. + :type ue: np.ndarray + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :return: Jacobians for network connectivity in all time steps. + :rtype: np.ndarray of shape N x N x T x 4 x 4 + """ + ( + a_exc, + b_exc, + d_exc, + tau_exc, + gamma_exc, + w_exc, + exc_current_baseline, + a_inh, + b_inh, + d_inh, + tau_inh, + w_inh, + inh_current_baseline, + J_NMDA, + J_I, + w_ee, + ) = model_params + hx_nw = np.zeros((N, N, T, V, V)) + + nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, se_delay) + IE = w_exc * (exc_current_baseline + ue) + w_ee * J_NMDA * se - J_I * si + J_NMDA * nw_e + + for n1 in range(N): + for n2 in range(N): + for t in range(T - 1): + hx_nw[n1, n2, t, sv["r_exc"], sv["se"]] = ( + logistic_der(IE[n1, t], a_exc, b_exc, d_exc) * J_NMDA * K_gl * cmat[n1, n2] + ) + + return -hx_nw + + +@numba.njit +def Duh( + model_params, + N, + V_in, + V_vars, + T, + ue, + ui, + se, + si, + K_gl, + cmat, + dmat_ndt, + se_delay, + sv, +): + """Jacobian of systems dynamics wrt. external inputs (control signals). + + :param model_params: Tuple of parameters in the WC Model in order + :type model_params: tuple of float + :param N: Number of nodes in the network. + :type N: int + :param V_in: Number of input variables. + :type V_in: int + :param V_vars: Number of system variables. + :type V_vars: int + :param T: Length of simulation (time dimension). + :type T: int + :param nw_e: N x T input of network into each node's 'exc' + :type nw_e: np.ndarray + :param ue: N x T array of the total input received by 'exc' population in every node at any time. + :type ue: np.ndarray + :param ui: N x T array of the total input received by 'inh' population in every node at any time. + :type ui: np.ndarray + :param se: Value of the se-variable for each node and timepoint + :type se: np.ndarray + :param si: Value of the si-variable for each node and timepoint + :type si: np.ndarray + :param K_gl: global coupling strength + :type K_gl float + :param cmat: coupling matrix + :type cmat: np.ndarray + :param dmat_ndt: delay index matrix + :type dmat_ndt: np.ndarray + :param se_delay: N x T array containing values of 'exc' of all nodes through time. + :type se_delay: np.ndarray + :param sv: dictionary of state vars and respective indices + :type sv: dict + + :rtype: np.ndarray of shape N x V x V x T + """ + + ( + a_exc, + b_exc, + d_exc, + tau_exc, + gamma_exc, + w_exc, + exc_current_baseline, + a_inh, + b_inh, + d_inh, + tau_inh, + w_inh, + inh_current_baseline, + J_NMDA, + J_I, + w_ee, + ) = model_params + + nw_e = compute_nw_input(N, T, K_gl, cmat, dmat_ndt, se_delay) + + duh = np.zeros((N, V_vars, V_in, T)) + for t in range(T): + for n in range(N): + IE = ( + w_exc * (exc_current_baseline + ue[n, t]) + + w_ee * J_NMDA * se[n, t] + - J_I * si[n, t] + + J_NMDA * nw_e[n, t] + ) + duh[n, sv["r_exc"], sv["r_exc"], t] = -logistic_der(IE, a_exc, b_exc, d_exc) * w_exc + II = w_inh * (inh_current_baseline + ui[n, t]) + J_NMDA * se[n, t] - si[n, t] + duh[n, sv["r_inh"], sv["r_inh"], t] = -logistic_der(II, a_inh, b_inh, d_inh) * w_inh + return duh + + +@numba.njit +def Dxdoth(N, V): + """Derivative of system dynamics wrt x dot + + :param N: Number of nodes in the network. + :type N: int + :param V: Number of system variables. + :type V: int + + :return: N x V x V matrix. + :rtype: np.ndarray + """ + dxdoth = np.zeros((N, V, V)) + for n in range(N): + for v in range(2, V): + dxdoth[n, v, v] = 1 + + return dxdoth diff --git a/tests/control/optimal_control/test_oc_ww.py b/tests/control/optimal_control/test_oc_ww.py new file mode 100644 index 00000000..1005e508 --- /dev/null +++ b/tests/control/optimal_control/test_oc_ww.py @@ -0,0 +1,213 @@ +import unittest +import numpy as np + +from neurolib.models.ww import WWModel +from neurolib.control.optimal_control import oc_ww + +import test_oc_utils as test_oc_utils + +p = test_oc_utils.params + + +class TestWW(unittest.TestCase): + """ + Test ww in neurolib/optimal_control/ + """ + + # tests if the control from OC computation coincides with a random input used for target forward-simulation + # single-node case + def test_1n(self): + print("Test OC in single-node system") + model = WWModel() + test_oc_utils.setinitzero_1n(model) + model.params["duration"] = p.TEST_DURATION_6 + # decrease time scale of sigmoidal function + # model.params["d_exc"] = 1.0 + # model.params["d_inh"] = 1.0 + + for input_channel in [0, 1]: + for measure_channel in range(4): + print("input_channel, measure_channel = ", input_channel, measure_channel) + + cost_mat = np.zeros((model.params.N, len(model.output_vars))) + control_mat = np.zeros((model.params.N, len(model.state_vars))) + control_mat[0, input_channel] = 1.0 # only allow inputs to input_channel + cost_mat[0, measure_channel] = 1.0 # only measure other channel + + test_oc_utils.set_input(model, p.ZERO_INPUT_1N_6) + model.params[model.input_vars[input_channel]] = p.TEST_INPUT_1N_6 + model.run() + target = test_oc_utils.gettarget_1n_ww(model) + + test_oc_utils.set_input(model, p.ZERO_INPUT_1N_6) + + model_controlled = oc_ww.OcWw(model, target) + model_controlled.maximum_control_strength = 2.0 + + model_controlled.control = np.concatenate( + [ + control_mat[0, 0] * p.INIT_INPUT_1N_6[:, np.newaxis, :], + control_mat[0, 1] * p.INIT_INPUT_1N_6[:, np.newaxis, :], + ], + axis=1, + ) + + model_controlled.update_input() + + control_coincide = False + + for i in range(p.LOOPS): + model_controlled.optimize(p.ITERATIONS) + + c_diff = np.abs(model_controlled.control[0, input_channel, :] - p.TEST_INPUT_1N_6[0, :]) + + if np.amax(c_diff) < p.LIMIT_DIFF: + control_coincide = True + break + + if model_controlled.zero_step_encountered: + break + + self.assertTrue(control_coincide) + + def test_2n(self): + print("Test OC in 2-node network") + ### communication between E and I is validated in test_onenode_oc. Test only E-E communication + ### Because of symmetry, test only inputs to 0 node, precision measuement in 1 node + + dmat = np.array([[0.0, 0.0], [0.0, 0.0]]) # no delay + cmat = np.array([[0.0, 1.0], [1.0, 0.0]]) + + model = WWModel(Cmat=cmat, Dmat=dmat) + test_oc_utils.setinitzero_2n(model) + model.params.duration = p.TEST_DURATION_10 + + # decrease time scale of sigmoidal function + # model.params["d_exc"] = 1.0 + # model.params["d_inh"] = 1.0 + + cost_mat = np.zeros((model.params.N, len(model.output_vars))) + control_mat = np.zeros((model.params.N, len(model.state_vars))) + control_mat[0, 0] = 1.0 + cost_mat[1, 0] = 1.0 + + model.params["exc_current"] = p.TEST_INPUT_2N_10 + model.params["inh_current"] = p.ZERO_INPUT_2N_10 + model.run() + + target = test_oc_utils.gettarget_2n_ww(model) + model.params["exc_current"] = p.ZERO_INPUT_2N_10 + + model_controlled = oc_ww.OcWw( + model, + target, + control_matrix=control_mat, + cost_matrix=cost_mat, + ) + model_controlled.maximum_control_strength = 2.0 + + model_controlled.control = np.concatenate( + [ + p.INIT_INPUT_2N_10[:, np.newaxis, :], + p.ZERO_INPUT_2N_10[:, np.newaxis, :], + ], + axis=1, + ) + model_controlled.update_input() + + control_coincide = False + + for i in range(p.LOOPS): + model_controlled.optimize(p.ITERATIONS) + c_diff = np.abs(model_controlled.control[0, 0, :] - p.TEST_INPUT_2N_10[0, :]) + if np.amax(c_diff) < p.LIMIT_DIFF: + control_coincide = True + break + + if model_controlled.zero_step_encountered: + break + + self.assertTrue(control_coincide) + + # tests if the control from OC computation coincides with a random input used for target forward-simulation + # delayed network case + def test_2n_delay(self): + print("Test OC in delayed 2-node network") + + cmat = np.array([[0.0, 0.0], [1.0, 0.0]]) + dmat = np.array([[0.0, 0.0], [p.TEST_DELAY, 0.0]]) + + model = WWModel(Cmat=cmat, Dmat=dmat) + test_oc_utils.setinitzero_2n(model) + model.params.duration = p.TEST_DURATION_8 + model.params.signalV = 1.0 + + cost_mat = np.zeros((model.params.N, len(model.output_vars))) + control_mat = np.zeros((model.params.N, len(model.state_vars))) + control_mat[0, 0] = 1.0 + cost_mat[1, 0] = 1.0 + + model.params["exc_current"] = p.TEST_INPUT_2N_8 + model.params["inh_current"] = p.ZERO_INPUT_2N_8 + + model.run() + + target = test_oc_utils.gettarget_2n_ww(model) + model.params["exc_current"] = p.ZERO_INPUT_2N_8 + + model_controlled = oc_ww.OcWw( + model, + target, + control_matrix=control_mat, + cost_matrix=cost_mat, + ) + model_controlled.maximum_control_strength = 2.0 + + model_controlled.control = np.concatenate( + [p.INIT_INPUT_2N_8[:, np.newaxis, :], p.ZERO_INPUT_2N_8[:, np.newaxis, :]], + axis=1, + ) + model_controlled.update_input() + + control_coincide = False + + for i in range(p.LOOPS): + model_controlled.optimize(p.ITERATIONS) + + # last entries of adjoint_state[0,0,:] are zero + self.assertTrue(np.amax(np.abs(model_controlled.adjoint_state[0, 0, -model.getMaxDelay() :])) == 0.0) + + c_diff_max = np.amax(np.abs(model_controlled.control[0, 0, :] - p.TEST_INPUT_2N_8[0, :])) + if c_diff_max < p.LIMIT_DIFF: + control_coincide = True + break + + if model_controlled.zero_step_encountered: + break + + self.assertTrue(control_coincide) + + # Arbitrary network and control setting, get_xs() returns correct array shape (despite initial values array longer than 1) + def test_get_xs(self): + print("Test state shape agrees with target shape") + + cmat = np.array([[0.0, 1.0], [1.0, 0.0]]) + dmat = np.array([[0.0, 0.0], [0.0, 0.0]]) # no delay + model = WWModel(Cmat=cmat, Dmat=dmat) + model.params.duration = p.TEST_DURATION_6 + test_oc_utils.set_input(model, p.TEST_INPUT_2N_6) + + target = np.ones((2, len(model.output_vars), p.TEST_INPUT_2N_6.shape[1])) + + model_controlled = oc_ww.OcWw( + model, + target, + ) + + model_controlled.optimize(1) + xs = model_controlled.get_xs() + self.assertTrue(xs.shape == target.shape) + + +if __name__ == "__main__": + unittest.main()