From 691ec4cd30f5dddfee01eac0fe0a89e9d0430a8a Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 1 May 2025 09:57:30 +0100 Subject: [PATCH 01/15] Write constraints setter as inspired by sigma setter --- src/causalprog/causal_problem.py | 65 ++++++++++++++++++++++++++++++++ 1 file changed, 65 insertions(+) diff --git a/src/causalprog/causal_problem.py b/src/causalprog/causal_problem.py index ec2dc63..29cf417 100644 --- a/src/causalprog/causal_problem.py +++ b/src/causalprog/causal_problem.py @@ -108,6 +108,11 @@ def __init__( ) self._sigma_mapping = {} + self._constraints = raises( + NotImplementedError(f"Constraints not set for {self.label}") + ) + self._constraints_mapping = {} + def _parameter_vector_to_dict( self, parameter_vector: jax.Array ) -> dict[str, float]: @@ -197,6 +202,66 @@ def set_causal_estimand( for arg in args_not_used: self._sigma_mapping[arg] = self.graph.get_node(arg) + def set_constraints( + self, + constraints: CausalEstimand, + rvs_to_nodes: dict[str, str] | None = None, + graph_argument: str | None = None, + ) -> None: + """ + Set the Constraints for this problem. + + `constraints` should be a callable object that defines the Data Constraints of + interest, in terms of the random variables of to the problem. The + random variables are in turn represented by `Node`s, with this association being + recorded in the `rv_to_nodes` dictionary. + + The `constraints` method of the instance will be usable once this method + completes. + + Args: + constraints (Constraints): Callable object that evaluates the constraints + of interest for this `CausalProblem`, in terms of the random variables, + which are the arguments to this callable. ``constraints``s with + additional arguments are not currently supported. + rvs_to_nodes (dict[str, str]): Mapping of random variable (argument) names + of `sigma` to the labels of the corresponding `Node`s representing the + random variables. 
Argument names that match their corresponding `Node` + label can be omitted. + graph_argument (str): Argument to `sigma` that should be replaced with + `self.graph`. This argument is only temporary, as we are currently + limited to the syntax `expectation(Graph, Node)` rather than just + `expectation(Node)`. It will be removed in the future when methods like + `expectation` can be called solely on `Node` objects. + + """ + self._constraints = constraints + self._constraints_mapping = {} + + if rvs_to_nodes is None: + rvs_to_nodes = {} + constraints_args = signature(constraints).parameters + + for rv_name, node_label in rvs_to_nodes.items(): + if rv_name not in constraints_args: + msg = f"{rv_name} is not a parameter to causal estimand provided." + raise ValueError(msg) + self._constraints_mapping[rv_name] = self.graph.get_node(node_label) + + # Any unaccounted-for RV arguments to sigma are assumed to match + # the label of the corresponding node. + args_not_used = set(constraints_args.keys()) - set( + self._constraints_mapping.keys() + ) + + ## Temporary hack to ensure that we can use expectation(graph, X) syntax. + if graph_argument: + self._constraints_mapping[graph_argument] = self.graph + args_not_used -= {graph_argument} + ## END HACK + for arg in args_not_used: + self._constraints_mapping[arg] = self.graph.get_node(arg) + def causal_estimand(self, p: jax.Array) -> float: """ Evaluate the Causal Estimand at parameter vector `p`. 
From c689141f1aeebe455e176fdf6c3a02cfd3b96ddd Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 1 May 2025 10:09:40 +0100 Subject: [PATCH 02/15] Refactor common code for setting up CE and Constraints --- src/causalprog/causal_problem.py | 119 ++++++++++++++++--------------- 1 file changed, 61 insertions(+), 58 deletions(-) diff --git a/src/causalprog/causal_problem.py b/src/causalprog/causal_problem.py index 29cf417..57baf71 100644 --- a/src/causalprog/causal_problem.py +++ b/src/causalprog/causal_problem.py @@ -2,7 +2,7 @@ from collections.abc import Callable from inspect import signature -from typing import TypeAlias +from typing import Literal, TypeAlias import jax import jax.numpy as jnp @@ -103,15 +103,13 @@ def __init__( self._graph = graph - self._sigma = raises( - NotImplementedError(f"Causal estimand not set for {self.label}.") + # Callables cannot be evaluated until they are explicitly set + self.set_causal_estimand( + raises(NotImplementedError(f"Causal estimand not set for {self.label}.")) ) - self._sigma_mapping = {} - - self._constraints = raises( - NotImplementedError(f"Constraints not set for {self.label}") + self.set_constraints( + raises(NotImplementedError(f"Constraints not set for {self.label}.")) ) - self._constraints_mapping = {} def _parameter_vector_to_dict( self, parameter_vector: jax.Array @@ -126,6 +124,52 @@ def _parameter_vector_to_dict( pn = self.graph.parameter_nodes return {pn[i].label: value for i, value in enumerate(parameter_vector)} + def _set_callable( + self, + which: Literal["sigma", "constraints"], + *, + fn: CausalEstimand | Constraints, + rvs_to_nodes: dict[str, str] | None = None, + graph_argument: str | None = None, + ) -> None: + """ + Abstracted method for setting the Causal Estimand and/or Constraints functions. + + The functionality for setting up these two methods of an instance are identical, + save for the attributes which need to be updated. 
As such, we can refactor the + common functionality into a single, hidden, method and provide a friendlier + access point for users to employ. + """ + fn_attr = f"_{which}" + map_attr = f"_{which}_mapping" + debug_name = "constraints" if which == "constraints" else "causal estimand" + + setattr(self, fn_attr, fn) + setattr(self, map_attr, {}) + + if rvs_to_nodes is None: + rvs_to_nodes = {} + fn_args = signature(fn).parameters + + for rv_name, node_label in rvs_to_nodes.items(): + if rv_name not in fn_args: + msg = f"{rv_name} is not an argument to provided {debug_name} function." + raise ValueError(msg) + getattr(self, map_attr)[rv_name] = self.graph.get_node(node_label) + + # Any unaccounted-for RV arguments to sigma are assumed to match + # the label of the corresponding node. + args_not_used = set(fn_args) - set(getattr(self, map_attr)) + + ## Temporary hack to ensure that we can use expectation(graph, X) syntax. + if graph_argument: + getattr(self, map_attr)[graph_argument] = self.graph + args_not_used -= {graph_argument} + ## END HACK + + for arg in args_not_used: + getattr(self, map_attr)[arg] = self.graph.get_node(arg) + def _set_parameters_via_vector(self, parameter_vector: jax.Array | None) -> None: """ Shorthand to set parameter node values from a parameter vector. @@ -177,30 +221,9 @@ def set_causal_estimand( `expectation` can be called solely on `Node` objects. """ - self._sigma = sigma - self._sigma_mapping = {} - - if rvs_to_nodes is None: - rvs_to_nodes = {} - sigma_args = signature(sigma).parameters - - for rv_name, node_label in rvs_to_nodes.items(): - if rv_name not in sigma_args: - msg = f"{rv_name} is not a parameter to causal estimand provided." - raise ValueError(msg) - self._sigma_mapping[rv_name] = self.graph.get_node(node_label) - - # Any unaccounted-for RV arguments to sigma are assumed to match - # the label of the corresponding node. 
- args_not_used = set(sigma_args.keys()) - set(self._sigma_mapping.keys()) - - ## Temporary hack to ensure that we can use expectation(graph, X) syntax. - if graph_argument: - self._sigma_mapping[graph_argument] = self.graph - args_not_used -= {graph_argument} - ## END HACK - for arg in args_not_used: - self._sigma_mapping[arg] = self.graph.get_node(arg) + self._set_callable( + "sigma", fn=sigma, rvs_to_nodes=rvs_to_nodes, graph_argument=graph_argument + ) def set_constraints( self, @@ -211,7 +234,7 @@ def set_constraints( """ Set the Constraints for this problem. - `constraints` should be a callable object that defines the Data Constraints of + ``constraints`` should be a callable object that defines the Data Constraints of interest, in terms of the random variables of to the problem. The random variables are in turn represented by `Node`s, with this association being recorded in the `rv_to_nodes` dictionary. @@ -235,33 +258,13 @@ def set_constraints( `expectation` can be called solely on `Node` objects. """ - self._constraints = constraints - self._constraints_mapping = {} - - if rvs_to_nodes is None: - rvs_to_nodes = {} - constraints_args = signature(constraints).parameters - - for rv_name, node_label in rvs_to_nodes.items(): - if rv_name not in constraints_args: - msg = f"{rv_name} is not a parameter to causal estimand provided." - raise ValueError(msg) - self._constraints_mapping[rv_name] = self.graph.get_node(node_label) - - # Any unaccounted-for RV arguments to sigma are assumed to match - # the label of the corresponding node. - args_not_used = set(constraints_args.keys()) - set( - self._constraints_mapping.keys() + self._set_callable( + "constraints", + fn=constraints, + rvs_to_nodes=rvs_to_nodes, + graph_argument=graph_argument, ) - ## Temporary hack to ensure that we can use expectation(graph, X) syntax. 
- if graph_argument: - self._constraints_mapping[graph_argument] = self.graph - args_not_used -= {graph_argument} - ## END HACK - for arg in args_not_used: - self._constraints_mapping[arg] = self.graph.get_node(arg) - def causal_estimand(self, p: jax.Array) -> float: """ Evaluate the Causal Estimand at parameter vector `p`. From 5d77f5bf6531e203d60c6aa309afb1d8d4403854 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 1 May 2025 10:38:58 +0100 Subject: [PATCH 03/15] Extend tests to cover both CE and constraints setters/evaluations --- src/causalprog/causal_problem.py | 13 ++ tests/test_causal_problem/test_callables.py | 187 ++++++++++++++++++++ tests/test_causal_problem/test_sigma.py | 131 -------------- 3 files changed, 200 insertions(+), 131 deletions(-) create mode 100644 tests/test_causal_problem/test_callables.py delete mode 100644 tests/test_causal_problem/test_sigma.py diff --git a/src/causalprog/causal_problem.py b/src/causalprog/causal_problem.py index 57baf71..3f6dc78 100644 --- a/src/causalprog/causal_problem.py +++ b/src/causalprog/causal_problem.py @@ -277,3 +277,16 @@ def causal_estimand(self, p: jax.Array) -> float: self._set_parameters_via_vector(p) # Call stored function with transformed arguments. return self._sigma(**self._sigma_mapping) + + def constraints(self, p: jax.Array) -> jax.Array: + """ + Evaluate the Constraints at parameter vector `p`. + + Args: + p (jax.Array): Vector of parameter values to evaluate at. + + """ + # Set parameter nodes to their new values. + self._set_parameters_via_vector(p) + # Call stored function with transformed arguments. 
+ return self._constraints(**self._constraints_mapping) diff --git a/tests/test_causal_problem/test_callables.py b/tests/test_causal_problem/test_callables.py new file mode 100644 index 0000000..12c4440 --- /dev/null +++ b/tests/test_causal_problem/test_callables.py @@ -0,0 +1,187 @@ +import re +from collections.abc import Callable +from typing import Literal + +import jax +import jax.numpy as jnp +import pytest + +from causalprog.algorithms import expectation, standard_deviation +from causalprog.causal_problem import CausalProblem +from causalprog.graph import Graph, Node + + +@pytest.fixture +def n_samples_for_estimands() -> int: + return 1000 + + +@pytest.fixture +def expectation_fixture( + n_samples_for_estimands: int, rng_key: jax.Array +) -> Callable[[Graph, Node], float]: + return lambda g, x: expectation( + g, x.label, samples=n_samples_for_estimands, rng_key=rng_key + ) + + +@pytest.fixture +def std_fixture( + n_samples_for_estimands: int, rng_key: jax.Array +) -> Callable[[Graph, Node], float]: + return ( + lambda g, x: standard_deviation( + g, x.label, samples=n_samples_for_estimands, rng_key=rng_key + ) + ** 2 + ) + + +@pytest.fixture +def vector_fixture( + n_samples_for_estimands: int, rng_key: jax.Array +) -> Callable[[Graph, Node, Node], jax.Array]: + """vector_fixture(g, x1, x2) = [mean of x1, std of x2].""" + return lambda g, x1, x2: jnp.array( + [ + expectation(g, x1.label, samples=n_samples_for_estimands, rng_key=rng_key), + standard_deviation( + g, x2.label, samples=n_samples_for_estimands, rng_key=rng_key + ) + ** 2, + ] + ) + + +@pytest.fixture(params=["causal_estimand", "constraints"]) +def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constraints"]: + """For tests applicable to both the causal_estimand and constraints methods.""" + return request.param + + +@pytest.mark.parametrize( + ("initial_param_values", "args_to_setter", "expected", "atol"), + [ + pytest.param( + {"mu_x": 1.0, "nu_y": 1.0}, + { + "fn": 
"expectation_fixture", + "rvs_to_nodes": {"x": "mu_x"}, + "graph_argument": "g", + }, + 1.0, + 1.0e-12, + id="mu_x", + ), + pytest.param( + {"mu_x": 1.0, "nu_y": 1.0}, + { + "fn": "expectation_fixture", + "rvs_to_nodes": {"x": "nu_y"}, + "graph_argument": "g", + }, + 1.0, + 1.0e-12, + id="nu_y", + ), + pytest.param( + {"mu_x": 0.0, "nu_y": 1.0}, + { + "fn": "expectation_fixture", + "rvs_to_nodes": {}, + "graph_argument": "g", + }, + 0.0, + # Empirical calculation with 1000 samples with fixture RNG key + # should give 1.8808e-2 as the empirical expectation. + 2.0e-2, + id="E[x], infer association", + ), + pytest.param( + {"mu_x": 0.0, "nu_y": 1.0}, + { + "fn": "std_fixture", + "rvs_to_nodes": {"x": "y"}, + "graph_argument": "g", + }, + # x has fixed std 1, and nu_y will be set to 1. + 1.0**2 + 1.0**2, + # Empirical calculation with 1000 samples with fixture RNG key + # should give 1.8506 as the empirical std of y. + 2.0e-1, + id="Var[y]", + ), + pytest.param( + {"mu_x": 0.0, "nu_y": 1.0}, + { + "fn": "vector_fixture", + "rvs_to_nodes": {"x1": "x", "x2": "y"}, + "graph_argument": "g", + }, + # As per the previous test cases + jnp.array([0.0, 1.0**2 + 1.0**2]), + # As per the above cases, both components should be within + # 2.0e-1 of the analytical value. + jnp.array([2.0e-2, 2.0e-1]), + id="E[x], Var[y]", + ), + ], +) +def test_sigma_interactions( + graph: Graph, + which: Literal["causal_estimand", "constraints"], + initial_param_values: dict[str, float], + args_to_setter: dict[str, Callable[..., float] | dict[str, str] | str], + expected: float | jax.Array, + atol: float, + request: pytest.FixtureRequest, +) -> None: + """ + Test the set_{causal_estimand, constraints} and .{casual_estimand, constraints} + evaluation method. + + Test works by: + - Set the parameter values using the initial_param_values. + - Call the setter method using the given arguments. 
+ - Evaluate the method that should have been set at the current parameter_vector, + which should evaluate the corresponding function at the current values of the + parameter vector, which will be the initial values just set. + - Check the result (lies within a given tolerance). + + In theory, there is no difference between the causal estimand and constraints when + it comes to this test - the constraints may be vector-valued but there is nothing + preventing the ``causal_estimand`` (programmatically) from being vector-valued + either. + """ + # Parametrised fixtures edit-in-place objects + args_to_setter = dict(args_to_setter) + if isinstance(args_to_setter["fn"], str): + args_to_setter["fn"] = request.getfixturevalue(args_to_setter["fn"]) + + if which == "constraints": + args_to_setter["constraints"] = args_to_setter.pop("fn") + else: + args_to_setter["sigma"] = args_to_setter.pop("fn") + + expected = jnp.array(expected, ndmin=1) + + # Test properly begins. + cp = CausalProblem(graph) + + method = getattr(cp, which) + setter_method = getattr(cp, f"set_{which}") + + # Before setting the causal estimand, it should throw an error if called. + with pytest.raises( + NotImplementedError, + match=re.escape( + f"{which.replace('_', ' ').capitalize()} not set for CausalProblem." 
+ ), + ): + method(cp.parameter_vector) + + cp.set_parameter_values(**initial_param_values) + setter_method(**args_to_setter) + result = jnp.array(method(cp.parameter_vector), ndmin=1) + + assert jnp.allclose(result, expected, atol=atol) diff --git a/tests/test_causal_problem/test_sigma.py b/tests/test_causal_problem/test_sigma.py deleted file mode 100644 index 1c4e632..0000000 --- a/tests/test_causal_problem/test_sigma.py +++ /dev/null @@ -1,131 +0,0 @@ -import re -from collections.abc import Callable - -import jax -import pytest - -from causalprog.algorithms import expectation, standard_deviation -from causalprog.causal_problem import CausalProblem -from causalprog.graph import Graph, Node - - -@pytest.fixture -def n_samples_for_estimands() -> int: - return 1000 - - -@pytest.fixture -def expectation_fixture( - n_samples_for_estimands: int, rng_key: jax.Array -) -> Callable[[Graph, Node], float]: - return lambda g, x: expectation( - g, x.label, samples=n_samples_for_estimands, rng_key=rng_key - ) - - -@pytest.fixture -def std_fixture( - n_samples_for_estimands: int, rng_key: jax.Array -) -> Callable[[Graph, Node], float]: - return ( - lambda g, x: standard_deviation( - g, x.label, samples=n_samples_for_estimands, rng_key=rng_key - ) - ** 2 - ) - - -@pytest.mark.parametrize( - ("initial_param_values", "args_to_setter", "expected", "atol"), - [ - pytest.param( - {"mu_x": 1.0, "nu_y": 1.0}, - { - "sigma": "expectation_fixture", - "rvs_to_nodes": {"x": "mu_x"}, - "graph_argument": "g", - }, - 1.0, - 1.0e-12, - id="Return mu_x", - ), - pytest.param( - {"mu_x": 1.0, "nu_y": 1.0}, - { - "sigma": "expectation_fixture", - "rvs_to_nodes": {"x": "nu_y"}, - "graph_argument": "g", - }, - 1.0, - 1.0e-12, - id="Return nu_y", - ), - pytest.param( - {"mu_x": 0.0, "nu_y": 1.0}, - { - "sigma": "expectation_fixture", - "rvs_to_nodes": {}, - "graph_argument": "g", - }, - 0.0, - # Empirical calculation with 1000 samples with fixture RNG key - # should give 1.8808 as the empirical 
expectation. - 2.0e-2, - id="Return E[x], infer association", - ), - pytest.param( - {"mu_x": 0.0, "nu_y": 1.0}, - { - "sigma": "std_fixture", - "rvs_to_nodes": {"x": "y"}, - "graph_argument": "g", - }, - # x has fixed std 1, and nu_y will be set to 1. - 1.0**2 + 1.0**2, - # Empirical calculation with 1000 samples with fixture RNG key - # should give 1.8506 as the empirical std of y. - 2.0e-1, - id="Return Var[y]", - ), - ], -) -def test_sigma_interactions( - graph: Graph, - initial_param_values: dict[str, float], - args_to_setter: dict[str, Callable[..., float] | dict[str, str] | str], - expected: dict[str, float], - atol: float, - request: pytest.FixtureRequest, -) -> None: - """ - Test the set_causal_estimand and casual_estimand evaluation method. - - Test works by: - - Set the parameter values using the initial_param_values. - - Set the causal_estimand using the setter and given arguments. - - Call .causal_estimand(parameter_vector), which should evaluate the causal estimand - at the current values of the parameter vector, which will be the initial values - just set. - - Check the result (lies within a given tolerance). - """ - if isinstance(args_to_setter["sigma"], str): - args_to_setter["sigma"] = request.getfixturevalue(args_to_setter["sigma"]) - - cp = CausalProblem(graph) - - # Before setting the causal estimand, it should throw an error if called. 
- with pytest.raises( - NotImplementedError, - match=re.escape("Causal estimand not set for CausalProblem."), - ): - cp.causal_estimand(cp.parameter_vector) - - cp.set_parameter_values(**initial_param_values) - cp.set_causal_estimand(**args_to_setter) - - result = cp.causal_estimand(cp.parameter_vector) - - assert result == pytest.approx( - expected, - abs=atol, - ) From f361eb3e5c133363541c1d88cf7514dda938db81 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 1 May 2025 10:46:52 +0100 Subject: [PATCH 04/15] Refactor evaluation of CE & constraints --- src/causalprog/causal_problem.py | 28 ++++++++++++++++++++-------- 1 file changed, 20 insertions(+), 8 deletions(-) diff --git a/src/causalprog/causal_problem.py b/src/causalprog/causal_problem.py index 3f6dc78..b1204ab 100644 --- a/src/causalprog/causal_problem.py +++ b/src/causalprog/causal_problem.py @@ -124,6 +124,24 @@ def _parameter_vector_to_dict( pn = self.graph.parameter_nodes return {pn[i].label: value for i, value in enumerate(parameter_vector)} + def _eval_callable( + self, which: Literal["sigma", "constraints"], at: jax.Array + ) -> jax.Array: + """ + Evaluate a callable method of this instance. + + This is an abstraction method for when the causal estimand or constraints + functions need to be evaluated. In each case, the process is the same: + + - Update the values of the parameter nodes. + - Call the underlying function composed with its mapping of RVs to Nodes. + + The method is abstracted here so that any changes to the process are reflected + in both methods automatically. + """ + self._set_parameters_via_vector(at) + return getattr(self, f"_{which}")(**getattr(self, f"_{which}_mapping")) + def _set_callable( self, which: Literal["sigma", "constraints"], @@ -273,10 +291,7 @@ def causal_estimand(self, p: jax.Array) -> float: p (jax.Array): Vector of parameter values to evaluate at. """ - # Set parameter nodes to their new values. 
- self._set_parameters_via_vector(p) - # Call stored function with transformed arguments. - return self._sigma(**self._sigma_mapping) + return self._eval_callable("sigma", p) def constraints(self, p: jax.Array) -> jax.Array: """ @@ -286,7 +301,4 @@ def constraints(self, p: jax.Array) -> jax.Array: p (jax.Array): Vector of parameter values to evaluate at. """ - # Set parameter nodes to their new values. - self._set_parameters_via_vector(p) - # Call stored function with transformed arguments. - return self._constraints(**self._constraints_mapping) + return self._eval_callable("constraints", p) From 958cb0293a7ed429997d0eb8048f05a28789f680 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 1 May 2025 10:47:22 +0100 Subject: [PATCH 05/15] Rename test to be general --- tests/test_causal_problem/test_callables.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_causal_problem/test_callables.py b/tests/test_causal_problem/test_callables.py index 12c4440..7aecf29 100644 --- a/tests/test_causal_problem/test_callables.py +++ b/tests/test_causal_problem/test_callables.py @@ -127,7 +127,7 @@ def which(request: pytest.FixtureRequest) -> Literal["causal_estimand", "constra ), ], ) -def test_sigma_interactions( +def test_callables( graph: Graph, which: Literal["causal_estimand", "constraints"], initial_param_values: dict[str, float], From 3c0cf2780590a3778262e5c04e20aa88fb342978 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 1 May 2025 11:25:45 +0100 Subject: [PATCH 06/15] Skeleton integration test, doesn't quite work though --- .../test_two_normal_example.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/test_integration/test_two_normal_example.py diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py new file mode 100644 index 0000000..a3c7c1f --- /dev/null +++ 
b/tests/test_integration/test_two_normal_example.py @@ -0,0 +1,78 @@ +"""""" + +import jax +import jax.numpy as jnp +from jax.scipy.optimize import minimize as jax_minimize + +from causalprog.algorithms import expectation +from causalprog.causal_problem import CausalProblem +from causalprog.distribution.normal import NormalFamily +from causalprog.graph import DistributionNode, Graph, Node, ParameterNode + + +def test_two_normal_example( + rng_key: jax.Array, + n_samples: int = 1000, + nu_x: float = 1.0, + nu_y: float = 1.0, + data: float = 2.0, + eps: float = 1.0, + initial_guess: dict[str, float] = {"mu_x": 1.1}, +) -> None: + """""" + data = jnp.array(data, ndmin=1) + + mu_x = ParameterNode("mu_x") + x = DistributionNode( + NormalFamily(), + label="x", + parameters={"mean": "mu_x"}, + constant_parameters={"cov": nu_x**2}, + ) + y = DistributionNode( + NormalFamily(), + label="y", + parameters={"mean": "x"}, + constant_parameters={"cov": nu_y**2}, + ) + + graph = Graph(label="G") + graph.add_edge(mu_x, x) + graph.add_edge(x, y) + + def sigma(g: Graph, rv: Node): + return expectation(g, rv.label, samples=n_samples, rng_key=rng_key) + + def constraints(g: Graph, rv: Node): + return expectation(g, rv.label, samples=n_samples, rng_key=rng_key) + + cp = CausalProblem(graph=graph, label="CP") + cp.set_causal_estimand(sigma, rvs_to_nodes={"rv": "y"}, graph_argument="g") + cp.set_constraints(constraints, rvs_to_nodes={"rv": "x"}, graph_argument="g") + + min_value = jax_minimize( + cp.causal_estimand, cp.parameter_vector, options={"maxiter": 5}, method="BFGS" + ) + + # # scipy doesn't like jax arrays, so we also have to be even more inefficient here. + # # However, SCIPY does do constrained optimisation which jax currently doesn't? + # # ``jaxopt`` is recommended by the jax devs, worth a look? + + # fn = lambda p: np.array(cp.causal_estimand(p)) + + # # This would ideally be done within the CP class, via a method. + # # But for now, we do it explicitly. 
+ # data_constraint = NonlinearConstraint( + # lambda x: jnp.abs(cp.constraints(x) - data).__array__(), lb=-jnp.inf, ub=eps + # ) + # # Initial guess would also need to be set a-prori or an argument to hypothetical + # # method + # cp.set_parameter_values(**initial_guess) + + # # Should be able to minimise now? + # min_result = minimize( + # fn, + # cp.parameter_vector.__array__(), + # # constraints=(data_constraint,), + # options={"disp": True}, + # ) From 311489667f2604c823a91e09b95208466090f27f Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 2 May 2025 10:35:56 +0100 Subject: [PATCH 07/15] Purge currently-failing IT to get working feature in --- .../test_two_normal_example.py | 78 ------------------- 1 file changed, 78 deletions(-) delete mode 100644 tests/test_integration/test_two_normal_example.py diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py deleted file mode 100644 index a3c7c1f..0000000 --- a/tests/test_integration/test_two_normal_example.py +++ /dev/null @@ -1,78 +0,0 @@ -"""""" - -import jax -import jax.numpy as jnp -from jax.scipy.optimize import minimize as jax_minimize - -from causalprog.algorithms import expectation -from causalprog.causal_problem import CausalProblem -from causalprog.distribution.normal import NormalFamily -from causalprog.graph import DistributionNode, Graph, Node, ParameterNode - - -def test_two_normal_example( - rng_key: jax.Array, - n_samples: int = 1000, - nu_x: float = 1.0, - nu_y: float = 1.0, - data: float = 2.0, - eps: float = 1.0, - initial_guess: dict[str, float] = {"mu_x": 1.1}, -) -> None: - """""" - data = jnp.array(data, ndmin=1) - - mu_x = ParameterNode("mu_x") - x = DistributionNode( - NormalFamily(), - label="x", - parameters={"mean": "mu_x"}, - constant_parameters={"cov": nu_x**2}, - ) - y = DistributionNode( - NormalFamily(), - label="y", - parameters={"mean": "x"}, - constant_parameters={"cov": nu_y**2}, - ) - - 
graph = Graph(label="G") - graph.add_edge(mu_x, x) - graph.add_edge(x, y) - - def sigma(g: Graph, rv: Node): - return expectation(g, rv.label, samples=n_samples, rng_key=rng_key) - - def constraints(g: Graph, rv: Node): - return expectation(g, rv.label, samples=n_samples, rng_key=rng_key) - - cp = CausalProblem(graph=graph, label="CP") - cp.set_causal_estimand(sigma, rvs_to_nodes={"rv": "y"}, graph_argument="g") - cp.set_constraints(constraints, rvs_to_nodes={"rv": "x"}, graph_argument="g") - - min_value = jax_minimize( - cp.causal_estimand, cp.parameter_vector, options={"maxiter": 5}, method="BFGS" - ) - - # # scipy doesn't like jax arrays, so we also have to be even more inefficient here. - # # However, SCIPY does do constrained optimisation which jax currently doesn't? - # # ``jaxopt`` is recommended by the jax devs, worth a look? - - # fn = lambda p: np.array(cp.causal_estimand(p)) - - # # This would ideally be done within the CP class, via a method. - # # But for now, we do it explicitly. - # data_constraint = NonlinearConstraint( - # lambda x: jnp.abs(cp.constraints(x) - data).__array__(), lb=-jnp.inf, ub=eps - # ) - # # Initial guess would also need to be set a-prori or an argument to hypothetical - # # method - # cp.set_parameter_values(**initial_guess) - - # # Should be able to minimise now? - # min_result = minimize( - # fn, - # cp.parameter_vector.__array__(), - # # constraints=(data_constraint,), - # options={"disp": True}, - # ) From 090712e5ff1d20072f064b6079ccb0bec8ba307e Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 2 May 2025 10:41:36 +0100 Subject: [PATCH 08/15] Revert "Purge currently-failing IT to get working feature in" Reinstate the in-development integration test. This reverts commit 311489667f2604c823a91e09b95208466090f27f. 
--- .../test_two_normal_example.py | 78 +++++++++++++++++++ 1 file changed, 78 insertions(+) create mode 100644 tests/test_integration/test_two_normal_example.py diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py new file mode 100644 index 0000000..a3c7c1f --- /dev/null +++ b/tests/test_integration/test_two_normal_example.py @@ -0,0 +1,78 @@ +"""""" + +import jax +import jax.numpy as jnp +from jax.scipy.optimize import minimize as jax_minimize + +from causalprog.algorithms import expectation +from causalprog.causal_problem import CausalProblem +from causalprog.distribution.normal import NormalFamily +from causalprog.graph import DistributionNode, Graph, Node, ParameterNode + + +def test_two_normal_example( + rng_key: jax.Array, + n_samples: int = 1000, + nu_x: float = 1.0, + nu_y: float = 1.0, + data: float = 2.0, + eps: float = 1.0, + initial_guess: dict[str, float] = {"mu_x": 1.1}, +) -> None: + """""" + data = jnp.array(data, ndmin=1) + + mu_x = ParameterNode("mu_x") + x = DistributionNode( + NormalFamily(), + label="x", + parameters={"mean": "mu_x"}, + constant_parameters={"cov": nu_x**2}, + ) + y = DistributionNode( + NormalFamily(), + label="y", + parameters={"mean": "x"}, + constant_parameters={"cov": nu_y**2}, + ) + + graph = Graph(label="G") + graph.add_edge(mu_x, x) + graph.add_edge(x, y) + + def sigma(g: Graph, rv: Node): + return expectation(g, rv.label, samples=n_samples, rng_key=rng_key) + + def constraints(g: Graph, rv: Node): + return expectation(g, rv.label, samples=n_samples, rng_key=rng_key) + + cp = CausalProblem(graph=graph, label="CP") + cp.set_causal_estimand(sigma, rvs_to_nodes={"rv": "y"}, graph_argument="g") + cp.set_constraints(constraints, rvs_to_nodes={"rv": "x"}, graph_argument="g") + + min_value = jax_minimize( + cp.causal_estimand, cp.parameter_vector, options={"maxiter": 5}, method="BFGS" + ) + + # # scipy doesn't like jax arrays, so we also have to be even more 
inefficient here. + # # However, SCIPY does do constrained optimisation which jax currently doesn't? + # # ``jaxopt`` is recommended by the jax devs, worth a look? + + # fn = lambda p: np.array(cp.causal_estimand(p)) + + # # This would ideally be done within the CP class, via a method. + # # But for now, we do it explicitly. + # data_constraint = NonlinearConstraint( + # lambda x: jnp.abs(cp.constraints(x) - data).__array__(), lb=-jnp.inf, ub=eps + # ) + # # Initial guess would also need to be set a-prori or an argument to hypothetical + # # method + # cp.set_parameter_values(**initial_guess) + + # # Should be able to minimise now? + # min_result = minimize( + # fn, + # cp.parameter_vector.__array__(), + # # constraints=(data_constraint,), + # options={"disp": True}, + # ) From d15b0073af764ea6fb96da941bf915cf0594b1d3 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Fri, 2 May 2025 14:37:39 +0100 Subject: [PATCH 09/15] Test now runs, and genuinely fails! --- .../test_two_normal_example.py | 108 +++++++++++------- 1 file changed, 67 insertions(+), 41 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index a3c7c1f..5e714ca 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -2,77 +2,103 @@ import jax import jax.numpy as jnp -from jax.scipy.optimize import minimize as jax_minimize +import numpy as np +from scipy.optimize import NonlinearConstraint, minimize from causalprog.algorithms import expectation from causalprog.causal_problem import CausalProblem from causalprog.distribution.normal import NormalFamily -from causalprog.graph import DistributionNode, Graph, Node, ParameterNode +from causalprog.graph import DistributionNode, Graph, ParameterNode def test_two_normal_example( rng_key: jax.Array, - n_samples: int = 1000, + n_samples: int = 100, nu_x: float = 1.0, nu_y: float = 1.0, - data: float = 
2.0, - eps: float = 1.0, - initial_guess: dict[str, float] = {"mu_x": 1.1}, + epsilon: float = 1.0, + data: tuple[float, ...] = (2.0,), + x0: tuple[float, ...] = (1.25,), ) -> None: """""" - data = jnp.array(data, ndmin=1) + data = np.array(data, ndmin=1) + x0 = np.array(x0, ndmin=1) + true_analytic_value = data - epsilon - mu_x = ParameterNode("mu_x") + mu = ParameterNode("mu") x = DistributionNode( NormalFamily(), label="x", - parameters={"mean": "mu_x"}, + parameters={"mean": "mu"}, constant_parameters={"cov": nu_x**2}, + is_outcome=True, ) y = DistributionNode( NormalFamily(), label="y", parameters={"mean": "x"}, constant_parameters={"cov": nu_y**2}, + is_outcome=True, ) graph = Graph(label="G") - graph.add_edge(mu_x, x) + graph.add_edge(mu, x) graph.add_edge(x, y) - def sigma(g: Graph, rv: Node): - return expectation(g, rv.label, samples=n_samples, rng_key=rng_key) + def expectation_with_n_samples(): + def _inner(g: Graph, rv: DistributionNode) -> float: + return expectation(g, rv.label, samples=n_samples, rng_key=rng_key) - def constraints(g: Graph, rv: Node): - return expectation(g, rv.label, samples=n_samples, rng_key=rng_key) + return _inner - cp = CausalProblem(graph=graph, label="CP") - cp.set_causal_estimand(sigma, rvs_to_nodes={"rv": "y"}, graph_argument="g") - cp.set_constraints(constraints, rvs_to_nodes={"rv": "x"}, graph_argument="g") + # Solve everything analytically first. That is, use the analytic formula for + # the CE, and for the Constraints, and solve the resulting problem. + def analytic_ce(p): + return p - min_value = jax_minimize( - cp.causal_estimand, cp.parameter_vector, options={"maxiter": 5}, method="BFGS" + def analytic_con(p): + return np.abs(p - data) + + analytic_constraint = NonlinearConstraint(analytic_con, lb=-np.inf, ub=epsilon) + analytic_result = minimize(analytic_ce, x0, constraints=[analytic_constraint]) + assert np.isclose(analytic_result.x, true_analytic_value) + + # Setup the CausalProblem instance. 
+ + cp = CausalProblem(graph, label="CP") + cp.set_causal_estimand( + expectation_with_n_samples(), + rvs_to_nodes={"rv": "y"}, + graph_argument="g", + ) + cp.set_constraints( + expectation_with_n_samples(), + rvs_to_nodes={"rv": "x"}, + graph_argument="g", + ) + + def ce(p): + return cp.causal_estimand(p) + + def con(p): + return np.abs(cp.constraints(p) - data) + + # Prior to solving, check that evaluating the CE and constraints bears some + # resemblance to their analytic counterparts. + range_check = np.linspace(0.0, 5.0, num=50) + for value in range_check: + v = np.atleast_1d(value) + assert np.isclose(ce(v), analytic_ce(v), atol=5 / n_samples) + assert np.isclose(con(v), analytic_con(v), atol=5 / n_samples) + + # Alright, now try solving the actual problem + nlc = NonlinearConstraint(con, lb=-np.inf, ub=epsilon) + result = minimize( + ce, + x0, + constraints=[nlc], + options={"disp": True}, + jac=lambda p: np.atleast_1d(1.0), ) - # # scipy doesn't like jax arrays, so we also have to be even more inefficient here. - # # However, SCIPY does do constrained optimisation which jax currently doesn't? - # # ``jaxopt`` is recommended by the jax devs, worth a look? - - # fn = lambda p: np.array(cp.causal_estimand(p)) - - # # This would ideally be done within the CP class, via a method. - # # But for now, we do it explicitly. - # data_constraint = NonlinearConstraint( - # lambda x: jnp.abs(cp.constraints(x) - data).__array__(), lb=-jnp.inf, ub=eps - # ) - # # Initial guess would also need to be set a-prori or an argument to hypothetical - # # method - # cp.set_parameter_values(**initial_guess) - - # # Should be able to minimise now? 
- # min_result = minimize( - # fn, - # cp.parameter_vector.__array__(), - # # constraints=(data_constraint,), - # options={"disp": True}, - # ) + assert np.isclose(result.x, analytic_result.x) From 2e2f8e6dc99c003589820b1f6cb456ecce1e4a1b Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Tue, 6 May 2025 09:12:34 +0100 Subject: [PATCH 10/15] Remove unused import --- tests/test_integration/test_two_normal_example.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 5e714ca..8c016c7 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -1,7 +1,6 @@ """""" import jax -import jax.numpy as jnp import numpy as np from scipy.optimize import NonlinearConstraint, minimize @@ -23,7 +22,7 @@ def test_two_normal_example( """""" data = np.array(data, ndmin=1) x0 = np.array(x0, ndmin=1) - true_analytic_value = data - epsilon + true_analytic_value = np.array(data) - epsilon mu = ParameterNode("mu") x = DistributionNode( @@ -98,7 +97,7 @@ def con(p): x0, constraints=[nlc], options={"disp": True}, - jac=lambda p: np.atleast_1d(1.0), + jac=lambda *p: np.atleast_1d(1.0), ) assert np.isclose(result.x, analytic_result.x) From 6990016a270af5cdc066d9bf3f1017ec3a392e0b Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Wed, 7 May 2025 08:04:25 +0100 Subject: [PATCH 11/15] Confirm that stochasticity is behind convergence fails --- .../test_two_normal_example.py | 79 ++++++++++++++++--- 1 file changed, 66 insertions(+), 13 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 8c016c7..5fd4fdc 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -1,6 +1,10 @@ """""" +from collections.abc import Callable 
+from pathlib import Path + import jax +import matplotlib.pyplot as plt import numpy as np from scipy.optimize import NonlinearConstraint, minimize @@ -10,14 +14,16 @@ from causalprog.graph import DistributionNode, Graph, ParameterNode -def test_two_normal_example( +def test_two_normal_example( # noqa: PLR0915 rng_key: jax.Array, - n_samples: int = 100, + n_samples: int = 1000, nu_x: float = 1.0, nu_y: float = 1.0, epsilon: float = 1.0, data: tuple[float, ...] = (2.0,), - x0: tuple[float, ...] = (1.25,), + x0: tuple[float, ...] = (1.0,), + *, + plotting: bool = False, ) -> None: """""" data = np.array(data, ndmin=1) @@ -44,7 +50,7 @@ def test_two_normal_example( graph.add_edge(mu, x) graph.add_edge(x, y) - def expectation_with_n_samples(): + def expectation_with_n_samples() -> Callable[[Graph, DistributionNode], float]: def _inner(g: Graph, rv: DistributionNode) -> float: return expectation(g, rv.label, samples=n_samples, rng_key=rng_key) @@ -59,7 +65,9 @@ def analytic_con(p): return np.abs(p - data) analytic_constraint = NonlinearConstraint(analytic_con, lb=-np.inf, ub=epsilon) - analytic_result = minimize(analytic_ce, x0, constraints=[analytic_constraint]) + analytic_result = minimize( + analytic_ce, x0, constraints=[analytic_constraint], options={"disp": True} + ) assert np.isclose(analytic_result.x, true_analytic_value) # Setup the CausalProblem instance. @@ -68,11 +76,13 @@ def analytic_con(p): cp.set_causal_estimand( expectation_with_n_samples(), rvs_to_nodes={"rv": "y"}, + # rvs_to_nodes={"rv": "mu"}, graph_argument="g", ) cp.set_constraints( expectation_with_n_samples(), rvs_to_nodes={"rv": "x"}, + # rvs_to_nodes={"rv": "mu"}, graph_argument="g", ) @@ -82,14 +92,6 @@ def ce(p): def con(p): return np.abs(cp.constraints(p) - data) - # Prior to solving, check that evaluating the CE and constraints bears some - # resemblance to their analytic counterparts. 
- range_check = np.linspace(0.0, 5.0, num=50) - for value in range_check: - v = np.atleast_1d(value) - assert np.isclose(ce(v), analytic_ce(v), atol=5 / n_samples) - assert np.isclose(con(v), analytic_con(v), atol=5 / n_samples) - # Alright, now try solving the actual problem nlc = NonlinearConstraint(con, lb=-np.inf, ub=epsilon) result = minimize( @@ -100,4 +102,55 @@ def con(p): jac=lambda *p: np.atleast_1d(1.0), ) + if plotting: + # Debug part of the test to check what the functions we are evaluating look like + param_values = np.linspace(0.0, 3.0, num=500, endpoint=True) + f_evals = np.zeros_like(param_values) + c_evals = np.zeros_like(param_values) + for i, val in enumerate(param_values): + f_evals[i] = ce(np.atleast_1d(val)) + c_evals[i] = con(np.atleast_1d(val))[0] + + f_diff: np.ndarray = f_evals - analytic_ce(param_values) + c_diff: np.ndarray = c_evals - analytic_con(param_values) + + n_rows = 2 + n_cols = 2 + fig, ax = plt.subplots(n_rows, n_cols) + for i in range(n_rows): + for j in range(n_cols): + ax[i, j].set_xlabel(r"$\mu$") + + ax[0, 0].plot(param_values, f_evals, color="blue", label="E[Y]") + ax[0, 0].set_ylabel("Function") + ax[0, 0].plot(result.x, ce(result.x), color="red", marker="o") + ax[0, 0].plot(analytic_result.x, ce(analytic_result.x), color="red", marker="x") + + ax[0, 1].plot(param_values, c_evals) + ax[0, 1].set_ylabel("Constraint") + ax[0, 1].plot(result.x, con(result.x), color="red", marker="o") + ax[0, 1].plot( + param_values, + np.ones_like(param_values) * epsilon, + color="red", + linestyle="dashed", + ) + ax[0, 1].plot( + analytic_result.x, con(analytic_result.x), color="red", marker="x" + ) + + ax[1, 0].plot(param_values, f_diff) + ax[1, 0].set_ylabel("Function difference") + + ax[1, 1].plot(param_values, c_diff) + ax[1, 1].set_ylabel("Constraint difference") + + fig.tight_layout() + + save_loc = (Path(__file__).parent / ".." / ".." 
/ ".vscode").resolve() + fig.savefig(save_loc / "_two_normal_plot.png") + + print("Min / max function diff:", f_diff.min(), f_diff.max()) # noqa: T201 + print("Min / max constraint diff:", c_diff.min(), c_diff.max()) # noqa: T201 + assert np.isclose(result.x, analytic_result.x) From a19aa09714440c3b45c36e35d756292e2c529d47 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 8 May 2025 09:36:18 +0100 Subject: [PATCH 12/15] Hack sample trick --- .../test_two_normal_example.py | 46 +++++++++++++++++-- 1 file changed, 41 insertions(+), 5 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 5fd4fdc..29efe5c 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -4,6 +4,7 @@ from pathlib import Path import jax +import jax.numpy as jnp import matplotlib.pyplot as plt import numpy as np from scipy.optimize import NonlinearConstraint, minimize @@ -14,6 +15,40 @@ from causalprog.graph import DistributionNode, Graph, ParameterNode +class TestingNormalNode(DistributionNode): + def __init__( + self, + distribution, + label, + *, + parameters=None, + constant_parameters=None, + is_outcome=False, + ): + super().__init__( + distribution, + label, + parameters=parameters, + constant_parameters=constant_parameters, + is_outcome=is_outcome, + ) + + def sample(self, sampled_dependencies, samples, rng_key): + """Sample a value from the node.""" + new_key = jax.random.split(rng_key, 1)[0] + params = dict( + **{ + param_name: sampled_dependencies[param_dependency] + for param_name, param_dependency in self._parameters.items() + }, + **self._constant_parameters, + ) + mean = params["mean"] + std = jnp.sqrt(params["cov"]) + s = mean + std * jax.random.normal(new_key, shape=(samples,), dtype=float) + return s + + def test_two_normal_example( # noqa: PLR0915 rng_key: jax.Array, n_samples: int = 1000, @@ -21,9 +56,9 @@ 
def test_two_normal_example( # noqa: PLR0915 nu_y: float = 1.0, epsilon: float = 1.0, data: tuple[float, ...] = (2.0,), - x0: tuple[float, ...] = (1.0,), + x0: tuple[float, ...] = (1.1,), *, - plotting: bool = False, + plotting: bool = True, ) -> None: """""" data = np.array(data, ndmin=1) @@ -31,14 +66,14 @@ def test_two_normal_example( # noqa: PLR0915 true_analytic_value = np.array(data) - epsilon mu = ParameterNode("mu") - x = DistributionNode( + x = TestingNormalNode( NormalFamily(), label="x", parameters={"mean": "mu"}, constant_parameters={"cov": nu_x**2}, is_outcome=True, ) - y = DistributionNode( + y = TestingNormalNode( NormalFamily(), label="y", parameters={"mean": "x"}, @@ -93,7 +128,8 @@ def con(p): return np.abs(cp.constraints(p) - data) # Alright, now try solving the actual problem - nlc = NonlinearConstraint(con, lb=-np.inf, ub=epsilon) + jac = lambda x: -1 if x < 2.0 else 1 if x > 2.0 else 0.0 + nlc = NonlinearConstraint(con, lb=-np.inf, ub=epsilon, jac=jac) result = minimize( ce, x0, From b9e647f9c54b0453fd86cc4df95fa97483ec15ff Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 8 May 2025 11:04:31 +0100 Subject: [PATCH 13/15] Set tolerance level for numerical solver --- .../test_two_normal_example.py | 130 ++++-------------- 1 file changed, 29 insertions(+), 101 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 29efe5c..065e1d1 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -1,12 +1,10 @@ """""" from collections.abc import Callable -from pathlib import Path import jax -import jax.numpy as jnp -import matplotlib.pyplot as plt import numpy as np +import pytest from scipy.optimize import NonlinearConstraint, minimize from causalprog.algorithms import expectation @@ -15,65 +13,38 @@ from causalprog.graph import DistributionNode, Graph, ParameterNode -class 
TestingNormalNode(DistributionNode): - def __init__( - self, - distribution, - label, - *, - parameters=None, - constant_parameters=None, - is_outcome=False, - ): - super().__init__( - distribution, - label, - parameters=parameters, - constant_parameters=constant_parameters, - is_outcome=is_outcome, - ) - - def sample(self, sampled_dependencies, samples, rng_key): - """Sample a value from the node.""" - new_key = jax.random.split(rng_key, 1)[0] - params = dict( - **{ - param_name: sampled_dependencies[param_dependency] - for param_name, param_dependency in self._parameters.items() - }, - **self._constant_parameters, - ) - mean = params["mean"] - std = jnp.sqrt(params["cov"]) - s = mean + std * jax.random.normal(new_key, shape=(samples,), dtype=float) - return s - - -def test_two_normal_example( # noqa: PLR0915 +@pytest.mark.parametrize( + ("n_samples"), + [ + pytest.param(1e3, id="1e3 samples"), + pytest.param(1e6, id="1e6 samples"), + pytest.param(1e8, id="1e8 samples"), + ], +) +def test_two_normal_example( + n_samples: int, rng_key: jax.Array, - n_samples: int = 1000, nu_x: float = 1.0, nu_y: float = 1.0, epsilon: float = 1.0, data: tuple[float, ...] = (2.0,), x0: tuple[float, ...] = (1.1,), - *, - plotting: bool = True, ) -> None: """""" + n_samples = int(n_samples) data = np.array(data, ndmin=1) x0 = np.array(x0, ndmin=1) true_analytic_value = np.array(data) - epsilon mu = ParameterNode("mu") - x = TestingNormalNode( + x = DistributionNode( NormalFamily(), label="x", parameters={"mean": "mu"}, constant_parameters={"cov": nu_x**2}, is_outcome=True, ) - y = TestingNormalNode( + y = DistributionNode( NormalFamily(), label="y", parameters={"mean": "x"}, @@ -93,6 +64,8 @@ def _inner(g: Graph, rv: DistributionNode) -> float: # Solve everything analytically first. That is, use the analytic formula for # the CE, and for the Constraints, and solve the resulting problem. 
+ # This will flag if we have setup our problem incorrectly, or changed something + # which affects the problem further down the line. def analytic_ce(p): return p @@ -111,82 +84,37 @@ def analytic_con(p): cp.set_causal_estimand( expectation_with_n_samples(), rvs_to_nodes={"rv": "y"}, - # rvs_to_nodes={"rv": "mu"}, graph_argument="g", ) cp.set_constraints( expectation_with_n_samples(), rvs_to_nodes={"rv": "x"}, - # rvs_to_nodes={"rv": "mu"}, graph_argument="g", ) def ce(p): return cp.causal_estimand(p) + def ce_jacobian(*p): + return 1.0 + def con(p): return np.abs(cp.constraints(p) - data) - # Alright, now try solving the actual problem - jac = lambda x: -1 if x < 2.0 else 1 if x > 2.0 else 0.0 - nlc = NonlinearConstraint(con, lb=-np.inf, ub=epsilon, jac=jac) + def con_jacobian(p): + return -1.0 * (p < data) + (p > data) + + nlc = NonlinearConstraint(con, lb=-np.inf, ub=epsilon, jac=con_jacobian) result = minimize( ce, x0, constraints=[nlc], options={"disp": True}, - jac=lambda *p: np.atleast_1d(1.0), + jac=ce_jacobian, ) - if plotting: - # Debug part of the test to check what the functions we are evaluating look like - param_values = np.linspace(0.0, 3.0, num=500, endpoint=True) - f_evals = np.zeros_like(param_values) - c_evals = np.zeros_like(param_values) - for i, val in enumerate(param_values): - f_evals[i] = ce(np.atleast_1d(val)) - c_evals[i] = con(np.atleast_1d(val))[0] - - f_diff: np.ndarray = f_evals - analytic_ce(param_values) - c_diff: np.ndarray = c_evals - analytic_con(param_values) - - n_rows = 2 - n_cols = 2 - fig, ax = plt.subplots(n_rows, n_cols) - for i in range(n_rows): - for j in range(n_cols): - ax[i, j].set_xlabel(r"$\mu$") - - ax[0, 0].plot(param_values, f_evals, color="blue", label="E[Y]") - ax[0, 0].set_ylabel("Function") - ax[0, 0].plot(result.x, ce(result.x), color="red", marker="o") - ax[0, 0].plot(analytic_result.x, ce(analytic_result.x), color="red", marker="x") - - ax[0, 1].plot(param_values, c_evals) - ax[0, 
1].set_ylabel("Constraint") - ax[0, 1].plot(result.x, con(result.x), color="red", marker="o") - ax[0, 1].plot( - param_values, - np.ones_like(param_values) * epsilon, - color="red", - linestyle="dashed", - ) - ax[0, 1].plot( - analytic_result.x, con(analytic_result.x), color="red", marker="x" - ) - - ax[1, 0].plot(param_values, f_diff) - ax[1, 0].set_ylabel("Function difference") - - ax[1, 1].plot(param_values, c_diff) - ax[1, 1].set_ylabel("Constraint difference") - - fig.tight_layout() - - save_loc = (Path(__file__).parent / ".." / ".." / ".vscode").resolve() - fig.savefig(save_loc / "_two_normal_plot.png") - - print("Min / max function diff:", f_diff.min(), f_diff.max()) # noqa: T201 - print("Min / max constraint diff:", c_diff.min(), c_diff.max()) # noqa: T201 - - assert np.isclose(result.x, analytic_result.x) + # When providing both Jacobians, error seems to scale with + # inverse square-root of number of samples. + # Use np.floor to provide more leeway in solution. + atol = 10 ** (-np.floor(np.log10(n_samples) / 2.0)) + assert np.isclose(result.x, analytic_result.x, atol=atol) From bf45b8f2d59de536be486eab543792edb861946d Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 8 May 2025 11:41:16 +0100 Subject: [PATCH 14/15] Solve both max and min --- .../test_two_normal_example.py | 119 +++++++++++++----- 1 file changed, 89 insertions(+), 30 deletions(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index 065e1d1..be09029 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -1,6 +1,34 @@ -"""""" +r"""Integration test for the two-normal example, . + +In this test, we have the following setup: + +- A single parameter $\mu$. +- A random variable $X$ distributed as $X \sim \mathcal{N}(\mu, \nu_X^2)$ for a fixed + $\nu_X$. 
+- A random variable $Y$ distributed as $Y \vert X \sim \mathcal{N}(X, \nu_Y^2)$ for a + fixed $\nu_Y$. +- The causal estimand of interest $\sigma = \mathbb{E}[Y]$. +- The constraints function is $\vert \mathbb{E}[X] - \phi \vert$ for an empirical + observation of the expectation of $X$, $\phi$. + +For a fixed tolerance $\epsilon > 0$, we are thus looking to solve the +following optimisation problem (1): + +$$ \text{min / max}_\mu \mathbb{E}[Y], $$ +$$ \text{subject to } \vert \mathbb{E}[X] - \phi \vert \leq \epsilon. $$ + +It should be noted that analytically, $\mathbb{E}[X] = \mathbb{E}[Y] = \mu$. +This means that we are effectively solving (2): + +$$ \text{min / max}_\mu \mu, $$ +$$ \text{subject to } \vert \mu - \phi \vert \leq \epsilon. $$ + +which we can immediately spot has solutions $\mu = \phi \pm \epsilon$ (larger value in +the case of maximisation). +""" from collections.abc import Callable +from typing import Literal import jax import numpy as np @@ -14,15 +42,19 @@ @pytest.mark.parametrize( - ("n_samples"), + ("n_samples", "max_or_min"), [ - pytest.param(1e3, id="1e3 samples"), - pytest.param(1e6, id="1e6 samples"), - pytest.param(1e8, id="1e8 samples"), + pytest.param(1e3, "min", id="[Min] 1e3 samples"), + pytest.param(1e6, "min", id="[Min] 1e6 samples"), + pytest.param(1e8, "min", id="[Min] 1e8 samples"), + pytest.param(1e3, "max", id="[Max] 1e3 samples"), + pytest.param(1e6, "max", id="[Max] 1e6 samples"), + pytest.param(1e8, "max", id="[Max] 1e8 samples"), ], ) def test_two_normal_example( n_samples: int, + max_or_min: Literal["max", "min"], rng_key: jax.Array, nu_x: float = 1.0, nu_y: float = 1.0, @@ -30,11 +62,38 @@ def test_two_normal_example( data: tuple[float, ...] = (2.0,), x0: tuple[float, ...] = (1.1,), ) -> None: - """""" + r"""Integration test for the two normal example. + + 0) Record the analytic answer we expect, `true_analytic_value`.
+ 1) Compute the result of solving (2), `analytic_result` via the optimiser, to ensure + that we have setup and understood the problem (and analytical answer) correctly. + This also makes us robust against behaviour changes in the `causal_estimand` and + `constraints` methods of the `CausalProblem` class. + 2) Check that `analytic_result` is close to `true_analytic_value`. + 3) Compute the result of (1) via the optimiser, `result`. + 4) Check that `result` is close to both `true_analytic_value` and `analytic_result` + (see below for tolerances used). + + Empirical experiments suggest that the absolute difference between the `result` and + `analytic_result` scales in proportion to the inverse square root of the number of + samples used; + + $$ \mathrm{atol} \propto \mathrm{samples}^{-0.5}, $$ + + so to be generous, we use `atol = 10 ** (-np.floor(np.log10(n_samples) / 2.))`. + + Finally, it should be noted that in order to obtain a good answer (in any case), we + need to provide the Jacobians of the causal estimand and constraints functions to + the solver. Without these, the results are poor (if the optimiser converges at all). + """ + minimise_options = {"disp": False, "maxiter": 20} + # Maximisation is minimisation of the negation of the objective function. + prefactor = 1.0 if max_or_min == "min" else -1.0 + n_samples = int(n_samples) data = np.array(data, ndmin=1) x0 = np.array(x0, ndmin=1) - true_analytic_value = np.array(data) - epsilon + true_analytic_value = np.array(data) - prefactor * epsilon mu = ParameterNode("mu") x = DistributionNode( @@ -62,22 +121,6 @@ def _inner(g: Graph, rv: DistributionNode) -> float: return _inner - # Solve everything analytically first. That is, use the analytic formula for - # the CE, and for the Constraints, and solve the resulting problem. - # This will flag if we have setup our problem incorrectly, or changed something - # which affects the problem further down the line.
- def analytic_ce(p): - return p - - def analytic_con(p): - return np.abs(p - data) - - analytic_constraint = NonlinearConstraint(analytic_con, lb=-np.inf, ub=epsilon) - analytic_result = minimize( - analytic_ce, x0, constraints=[analytic_constraint], options={"disp": True} - ) - assert np.isclose(analytic_result.x, true_analytic_value) - # Setup the CausalProblem instance. cp = CausalProblem(graph, label="CP") @@ -93,10 +136,11 @@ def analytic_con(p): ) def ce(p): - return cp.causal_estimand(p) + return prefactor * cp.causal_estimand(p) def ce_jacobian(*p): - return 1.0 + # Gradient is prefactor * 1.0 since we're effectively minimising y = x + return prefactor def con(p): return np.abs(cp.constraints(p) - data) @@ -104,17 +148,32 @@ def con(p): def con_jacobian(p): return -1.0 * (p < data) + (p > data) + # 1) Analytic solve + def analytic_ce(p): + return prefactor * p + + def analytic_con(p): + return np.abs(p - data) + + analytic_constraint = NonlinearConstraint(analytic_con, lb=-np.inf, ub=epsilon) + analytic_result = minimize( + analytic_ce, x0, constraints=[analytic_constraint], options=minimise_options + ) + + # 2) Check analytic solve. + assert np.isclose(analytic_result.x, true_analytic_value) + + # 3) Solve (1) via the CausalProblem class methods. nlc = NonlinearConstraint(con, lb=-np.inf, ub=epsilon, jac=con_jacobian) result = minimize( ce, x0, constraints=[nlc], - options={"disp": True}, + options=minimise_options, jac=ce_jacobian, ) - # When providing both Jacobians, error seems to scale with - # inverse square-root of number of samples. - # Use np.floor to provide more leeway in solution. - atol = 10 ** (-np.floor(np.log10(n_samples) / 2.0)) + # 4) Check proximity to correct solution. 
+ atol = 10.0 ** (-np.floor(np.log10(n_samples) / 2.0)) assert np.isclose(result.x, analytic_result.x, atol=atol) + assert np.isclose(result.x, true_analytic_value, atol=atol) From ab4a9aed4c8adb11eae9ed2a7bf6fa77925f7f85 Mon Sep 17 00:00:00 2001 From: willGraham01 <1willgraham@gmail.com> Date: Thu, 8 May 2025 11:46:25 +0100 Subject: [PATCH 15/15] Disambiguify comment --- tests/test_integration/test_two_normal_example.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/test_integration/test_two_normal_example.py b/tests/test_integration/test_two_normal_example.py index be09029..c0ec529 100644 --- a/tests/test_integration/test_two_normal_example.py +++ b/tests/test_integration/test_two_normal_example.py @@ -139,7 +139,7 @@ def ce(p): return prefactor * cp.causal_estimand(p) def ce_jacobian(*p): - # Gradient is prefactor * 1.0 since we're effectively minimising y = x + # Gradient is prefactor * 1.0 since we're effectively minimising f(p) = p return prefactor def con(p):