Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make two node example graph #16

Merged
merged 27 commits into from
Mar 20, 2025
Merged
Changes from all commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
03a751f
start working on node that store distributions
mscroggs Mar 13, 2025
ddb7d52
refactor graph, add add_node and add_edge methods
mscroggs Mar 14, 2025
e8a4592
get multiple samples at onde
mscroggs Mar 14, 2025
d7c5a60
parametrise test
mscroggs Mar 14, 2025
bd50473
make two distribution example
mscroggs Mar 14, 2025
02695eb
Merge branch 'main' into mscroggs/normal-example
mscroggs Mar 14, 2025
e78a06c
Merge branch 'main' into mscroggs/normal-example
mscroggs Mar 19, 2025
bbe9284
| None
mscroggs Mar 19, 2025
0a9a7b2
remove irrelevant nodes from test
mscroggs Mar 19, 2025
711ca46
add stdev to test
mscroggs Mar 19, 2025
e460f4f
Update src/causalprog/graph/graph.py
mscroggs Mar 19, 2025
ef667db
Update src/causalprog/graph/graph.py
mscroggs Mar 19, 2025
87fa819
make roots_down_to_outcome a method of the graph
mscroggs Mar 19, 2025
df339e7
Merge branch 'mscroggs/normal-example' of github.com:UCL/causalprog i…
mscroggs Mar 19, 2025
4b457fc
Update src/causalprog/graph/graph.py
mscroggs Mar 19, 2025
b222449
default false
mscroggs Mar 19, 2025
ac5dbdc
Merge branch 'mscroggs/normal-example' of github.com:UCL/causalprog i…
mscroggs Mar 19, 2025
fa7bfe7
don't allow temporary None labels
mscroggs Mar 19, 2025
ed93108
simpler test
mscroggs Mar 19, 2025
ed02c31
improve tests, and simplify iteration code
mscroggs Mar 19, 2025
cc2f1ec
reduce number of tests
mscroggs Mar 19, 2025
8c01b7d
number of samples must be int, don't use 0 for mean so that relative …
mscroggs Mar 20, 2025
6797f4e
Update src/causalprog/graph/graph.py
mscroggs Mar 20, 2025
874d100
Update src/causalprog/graph/graph.py
mscroggs Mar 20, 2025
df45fea
Update tests/test_graph.py
mscroggs Mar 20, 2025
dbc481e
Update src/causalprog/_abc/labelled.py
mscroggs Mar 20, 2025
d5cc428
indentation
mscroggs Mar 20, 2025
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -15,7 +15,11 @@ classifiers = [
"Programming Language :: Python :: 3.13",
"Typing :: Typed",
]
dependencies = ["jax", "networkx"]
dependencies = [
"jax",
"networkx",
"numpy",
]
description = "A Python package for causal modelling and inference with stochastic causal programming"
dynamic = ["version"]
keywords = []
@@ -35,6 +39,7 @@ optional-dependencies = {dev = [
"mkdocstrings-python",
], test = [
"distrax",
"numpy",
"numpyro",
"pytest",
"pytest-cov",
2 changes: 1 addition & 1 deletion src/causalprog/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""causalprog package."""

from . import graph
from . import algorithms, distribution, graph, utils
from ._version import __version__
3 changes: 3 additions & 0 deletions src/causalprog/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
"""Algorithms."""

from .expectation import expectation, standard_deviation
41 changes: 41 additions & 0 deletions src/causalprog/algorithms/expectation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
"""Algorithms for estimating the expectation and standard deviation."""

import numpy as np
import numpy.typing as npt

from causalprog.graph import Graph


def sample(
    graph: Graph,
    outcome_node_label: str | None = None,
    samples: int = 1000,
) -> npt.NDArray[float]:
    """
    Draw samples of one node by sampling everything it depends on first.

    ``outcome_node_label`` selects the node whose samples are returned; when
    it is ``None`` the label of ``graph.outcome`` is used instead. ``samples``
    is passed unchanged to the ``sample`` method of every node visited.
    """
    target = graph.outcome.label if outcome_node_label is None else outcome_node_label

    # The graph yields nodes with dependencies first, so each node can read
    # the values of the nodes it depends on from ``drawn`` when it is sampled.
    drawn: dict[str, npt.NDArray[float]] = {}
    for current in graph.roots_down_to_outcome(target):
        drawn[current.label] = current.sample(drawn, samples)
    return drawn[target]


def expectation(
    graph: Graph,
    outcome_node_label: str | None = None,
    samples: int = 1000,
) -> float:
    """
    Estimate the expectation of a node as the mean of its samples.

    All arguments are forwarded unchanged to ``sample``.
    """
    draws = sample(graph, outcome_node_label, samples)
    return draws.mean()


def standard_deviation(
    graph: Graph,
    outcome_node_label: str | None = None,
    samples: int = 1000,
) -> float:
    """
    Estimate the standard deviation of a node from its samples.

    All arguments are forwarded unchanged to ``sample``; the result is
    ``np.std`` of the returned samples.
    """
    draws = sample(graph, outcome_node_label, samples)
    return np.std(draws)
2 changes: 1 addition & 1 deletion src/causalprog/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Creation and storage of graphs."""

from .graph import Graph
from .node import DistributionNode, RootDistributionNode
from .node import DistributionNode, Node
96 changes: 85 additions & 11 deletions src/causalprog/graph/graph.py
Original file line number Diff line number Diff line change
@@ -10,24 +10,98 @@
class Graph(Labelled):
"""A directed acyclic graph that represents a causality tree."""

def __init__(self, graph: nx.Graph, label: str) -> None:
"""Initialise a graph from a NetworkX graph."""
_nodes_by_label: dict[str, Node]

def __init__(self, label: str) -> None:
"""Create an empty graph with the given label."""
super().__init__(label=label)
# Directed NetworkX graph that stores the nodes and the edges between them.
self._graph = nx.DiGraph()
# Lookup table from node label to node; starts empty.
self._nodes_by_label = {}

def get_node(self, label: str) -> Node:
    """
    Get a node from its label.

    Raises ``KeyError`` when no node is stored under ``label``.
    """
    # Look the label up directly rather than testing the truthiness of the
    # stored value, so a node that happens to evaluate as falsy is still found.
    try:
        return self._nodes_by_label[label]
    except KeyError:
        msg = f'Node not found with label "{label}"'
        raise KeyError(msg) from None

def add_node(self, node: Node) -> None:
    """
    Register a node with the graph.

    Raises ``ValueError`` when a node with the same label is already stored.
    """
    new_label = node.label
    if new_label in self._nodes_by_label:
        msg = f"Duplicate node label: {new_label}"
        raise ValueError(msg)
    # Keep the label lookup and the NetworkX graph in step.
    self._nodes_by_label[new_label] = node
    self._graph.add_node(node)

def add_edge(self, first_node: Node | str, second_node: Node | str) -> None:
    """
    Add a directed edge from ``first_node`` to ``second_node``.

    Either end may be given as a node or as the label of a node already in
    the graph. A node that is not yet in the graph is added to it along
    with the edge.

    Raises ``KeyError`` when a label does not match any node, and
    ``ValueError`` when a node shares its label with a different node that
    is already stored in the graph.
    """
    # Resolve labels to the stored nodes; node objects pass through as given.
    endpoints = [
        self.get_node(end) if isinstance(end, str) else end
        for end in (first_node, second_node)
    ]
    for end in endpoints:
        if end.label not in self._nodes_by_label:
            self.add_node(end)
    # A label may only ever refer to one node: reject a node whose label is
    # already taken by a different stored node.
    for end in endpoints:
        if end != self._nodes_by_label[end.label]:
            msg = f"Invalid node: {end}"
            raise ValueError(msg)
    self._graph.add_edge(*endpoints)

self._graph = graph.copy()
self._nodes = list(graph.nodes())
self._depth_first_nodes = list(nx.algorithms.dfs_postorder_nodes(graph))
@property
def predecessors(self) -> dict[Node, Node]:
"""
Get predecessors of every node.

The mapping is whatever ``nx.algorithms.dfs_predecessors`` returns for
the underlying directed graph, computed afresh on every access.

NOTE(review): ``dfs_predecessors`` describes a depth-first search, so a
node with several parents presumably maps to only one of them and search
roots are presumably absent from the keys -- confirm that this is the
intended meaning of "predecessors".
"""
return nx.algorithms.dfs_predecessors(self._graph)

outcomes = [node for node in self._nodes if node.is_outcome]
@property
def successors(self) -> dict[Node, list[Node]]:
"""
Get successors of every node.

The mapping is whatever ``nx.algorithms.dfs_successors`` returns for the
underlying directed graph, computed afresh on every access.

NOTE(review): ``dfs_successors`` describes a depth-first search, so a
child reachable from several parents is presumably listed under only one
of them -- confirm that this is the intended meaning of "successors".
"""
return nx.algorithms.dfs_successors(self._graph)

@property
def outcome(self) -> Node:
    """
    The single node of the graph that is flagged as an outcome.

    Raises ``ValueError`` when the graph holds no outcome node, or more
    than one.
    """
    flagged = [candidate for candidate in self.nodes if candidate.is_outcome]
    if not flagged:
        msg = "Cannot create graph with no outcome nodes"
        raise ValueError(msg)
    if len(flagged) > 1:
        msg = "Cannot yet create graph with multiple outcome nodes"
        raise ValueError(msg)
    return flagged[0]

@property
def nodes(self) -> list[Node]:
    """All nodes held by the underlying graph, as a new list."""
    return [*self._graph.nodes()]

@property
def ordered_nodes(self) -> list[Node]:
    """
    Nodes ordered so that each node appears after its dependencies.

    Raises ``RuntimeError`` when the graph is not a directed acyclic graph.
    """
    if nx.is_directed_acyclic_graph(self._graph):
        return list(nx.topological_sort(self._graph))
    msg = "Graph is not acyclic."
    raise RuntimeError(msg)

def roots_down_to_outcome(
    self,
    outcome_node_label: str,
) -> list[Node]:
    """
    Get ordered list of nodes that outcome depends on.

    The list holds the node labelled ``outcome_node_label`` together with
    all of its ancestors, ordered so that each node appears after its
    dependencies.
    """
    target = self.get_node(outcome_node_label)
    # The target itself is wanted too, not only the nodes upstream of it.
    wanted = {target, *nx.ancestors(self._graph, target)}
    return [candidate for candidate in self.ordered_nodes if candidate in wanted]
127 changes: 60 additions & 67 deletions src/causalprog/graph/node.py
Original file line number Diff line number Diff line change
@@ -2,97 +2,90 @@

from __future__ import annotations

from abc import abstractmethod
from typing import Protocol, runtime_checkable
import typing
from abc import ABC, abstractmethod

from causalprog._abc.labelled import Labelled
import numpy as np

if typing.TYPE_CHECKING:
import numpy.typing as npt

class DistributionFamily:
"""Placeholder class."""
from causalprog._abc.labelled import Labelled


class Distribution:
class Distribution(ABC):
"""Placeholder class."""


@runtime_checkable
class Node(Protocol):
"""An abstract node in a graph."""

@property
@abstractmethod
def label(self) -> str:
"""The label of the node."""

@property
@abstractmethod
def is_root(self) -> bool:
"""Identify if the node is a root."""

@property
@abstractmethod
def is_outcome(self) -> bool:
"""Identify if the node is an outcome."""


class RootDistributionNode(Labelled):
"""A root node containing a distribution family."""
def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
) -> npt.NDArray[float]:
"""Sample."""


class NormalDistribution(Distribution):
    """Normal distribution whose parameters are fixed numbers or sampled values."""

    def __init__(self, mean: str | float = 0.0, std_dev: str | float = 1.0) -> None:
        """
        Store the mean and standard deviation.

        A parameter given as a string is used as a key into the sampled
        dependencies when ``sample`` is called; a number is used as it is.
        """
        self.mean = mean
        self.std_dev = std_dev

    def sample(
        self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
    ) -> npt.NDArray[float]:
        """
        Draw ``samples`` values from the distribution.

        Standard normal draws are scaled by the standard deviation and then
        shifted by the mean, each taken from ``sampled_dependencies`` when
        the stored parameter is a string.
        """
        draws = np.random.normal(0.0, 1.0, samples)  # noqa: NPY002

        if isinstance(self.std_dev, str):
            scale = sampled_dependencies[self.std_dev]
        else:
            scale = self.std_dev
        draws *= scale

        if isinstance(self.mean, str):
            shift = sampled_dependencies[self.mean]
        else:
            shift = self.mean
        draws += shift

        return draws


class Node(Labelled):
"""An abstract node in a graph."""

def __init__(
self,
family: DistributionFamily,
label: str,
*,
is_outcome: bool = False,
) -> None:
"""Initialise the node."""
def __init__(self, label: str, *, is_outcome: bool = False) -> None:
"""Initialise."""
super().__init__(label=label)
self._is_outcome = is_outcome

self._dfamily = family
self._outcome = is_outcome

def __repr__(self) -> str:
"""Representation."""
return f'RootDistributionNode("{self._label}")'

@property
def is_root(self) -> bool:
"""Identify if the node is a root."""
return True
@abstractmethod
def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
) -> float:
"""Sample a value from the node."""

@property
def is_outcome(self) -> bool:
"""Identify if the node is an outcome."""
return self._outcome
return self._is_outcome


class DistributionNode(Labelled):
"""A node containing a distribution family that depends on its parents."""
class DistributionNode(Node):
"""A node containing a distribution."""

def __init__(
self,
family: DistributionFamily,
distribution: Distribution,
label: str,
*,
is_outcome: bool = False,
) -> None:
"""Initialise the node."""
super().__init__(label=label)
"""Initialise."""
self._dist = distribution
super().__init__(label, is_outcome=is_outcome)

self._dfamily = family
self._outcome = is_outcome
def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
) -> float:
"""Sample a value from the node."""
return self._dist.sample(sampled_dependencies, samples)

def __repr__(self) -> str:
"""Representation."""
return f'DistributionNode("{self._label}")'

@property
def is_root(self) -> bool:
"""Identify if the node is a root."""
return False

@property
def is_outcome(self) -> bool:
"""Identify if the node is an outcome."""
return self._outcome
return f'DistributionNode("{self.label}")'
193 changes: 167 additions & 26 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,184 @@
"""Tests for graph module."""

import networkx as nx
import re

import numpy as np
import pytest

import causalprog


def test_label2() -> None:
"""Test nodes."""
family = causalprog.graph.node.DistributionFamily()
node = causalprog.graph.RootDistributionNode(family, "N0")
node2 = causalprog.graph.RootDistributionNode(family, "N1")
node3 = causalprog.graph.RootDistributionNode(family, "Y")
node4 = causalprog.graph.DistributionNode(family, "N4")
def test_label():
d = causalprog.graph.node.NormalDistribution()
node = causalprog.graph.DistributionNode(d, "X")
node2 = causalprog.graph.DistributionNode(d, "Y")
node_copy = node

assert node.label == node_copy.label
assert node.label == node_copy.label == "X"
assert node.label != node2.label
assert node.label != node3.label
assert node.label != node4.label
assert node2.label == "Y"

assert isinstance(node, causalprog.graph.node.Node)
assert isinstance(node2, causalprog.graph.node.Node)
assert isinstance(node3, causalprog.graph.node.Node)
assert isinstance(node4, causalprog.graph.node.Node)


def test_simple_graph() -> None:
"""Test a simple graph."""
family = causalprog.graph.node.DistributionFamily()
n_x = causalprog.graph.RootDistributionNode(family, "N_X")
n_m = causalprog.graph.RootDistributionNode(family, "N_M")
u_y = causalprog.graph.RootDistributionNode(family, "U_Y")
x = causalprog.graph.DistributionNode(family, "X")
m = causalprog.graph.DistributionNode(family, "M")
y = causalprog.graph.DistributionNode(family, "Y", is_outcome=True)
def test_duplicate_label():
    """A second node with an already used label is rejected by add_node."""
    dist = causalprog.graph.node.NormalDistribution()
    graph = causalprog.graph.Graph("G0")

    graph.add_node(causalprog.graph.DistributionNode(dist, "X"))

    expected = re.escape("Duplicate node label: X")
    with pytest.raises(ValueError, match=expected):
        graph.add_node(causalprog.graph.DistributionNode(dist, "X"))


@pytest.mark.parametrize(
    "use_labels",
    [pytest.param(True, id="Via labels"), pytest.param(False, id="Via variables")],
)
def test_build_graph(*, use_labels: bool) -> None:
    """An edge can be added by passing either labels or the nodes themselves."""
    root_label = "root"
    outcome_label = "outcome_label"
    dist = causalprog.graph.node.NormalDistribution()

    root_node = causalprog.graph.DistributionNode(dist, root_label)
    outcome_node = causalprog.graph.DistributionNode(
        dist, outcome_label, is_outcome=True
    )

    graph = causalprog.graph.Graph("G0")
    for member in (root_node, outcome_node):
        graph.add_node(member)

    if use_labels:
        endpoints = (root_label, outcome_label)
    else:
        endpoints = (root_node, outcome_node)
    graph.add_edge(*endpoints)

    assert graph.roots_down_to_outcome(outcome_label) == [root_node, outcome_node]


def test_roots_down_to_outcome() -> None:
    """Only a node and its ancestors are returned, with dependencies first."""
    dist = causalprog.graph.node.NormalDistribution()
    graph = causalprog.graph.Graph("G0")

    by_label = {
        name: causalprog.graph.DistributionNode(dist, name) for name in "UVWXYZ"
    }
    for member in by_label.values():
        graph.add_node(member)

    # Each two-letter string is one edge, written as "<from><to>".
    for tail, head in ["VW", "VX", "VY", "XZ", "YZ", "UZ"]:
        graph.add_edge(tail, head)

    u, v, w, x, y, z = (by_label[name] for name in "UVWXYZ")

    assert graph.roots_down_to_outcome("V") == [v]
    assert graph.roots_down_to_outcome("W") == [v, w]

    chain = graph.roots_down_to_outcome("Z")
    assert len(chain) == 5  # noqa: PLR2004
    pos = chain.index
    assert pos(v) < min(pos(x), pos(y)) < max(pos(x), pos(y)) < pos(z)
    assert pos(u) < pos(z)


def test_cycle() -> None:
    """Ordering the nodes of a graph that contains a cycle raises RuntimeError."""
    d = causalprog.graph.node.NormalDistribution()

    node0 = causalprog.graph.DistributionNode(d, "X")
    node1 = causalprog.graph.DistributionNode(d, "Y")
    node2 = causalprog.graph.DistributionNode(d, "Z")

    # add_edge registers nodes it has not seen, so no add_node calls are needed.
    graph = causalprog.graph.Graph("G0")
    graph.add_edge(node0, node1)
    graph.add_edge(node1, node2)
    graph.add_edge(node2, node0)

    # ``match`` is a regular expression: escape the message so that the final
    # "." is matched literally, as test_duplicate_label does for its message.
    with pytest.raises(RuntimeError, match=re.escape("Graph is not acyclic.")):
        graph.roots_down_to_outcome("X")


@pytest.mark.parametrize(
    ("mean", "stdev", "samples", "rtol"),
    [
        pytest.param(1.0, 1.0, 10, 1, id="std normal, 10 samples"),
        pytest.param(2.0, 0.8, 1000, 1e-1, id="non-standard normal, 1000 samples"),
        pytest.param(1.0, 1.0, 100000, 1e-2, id="std normal, 10^5 samples"),
        pytest.param(1.0, 1.0, 10000000, 1e-3, id="std normal, 10^7 samples"),
    ],
)
def test_single_normal_node(samples, rtol, mean, stdev):
    """
    Check the estimated mean and standard deviation of a single normal node.

    The tolerance is loosened as the number of samples drops, since the
    estimates are random and no seed is set here.
    """
    normal = causalprog.graph.node.NormalDistribution(mean, stdev)
    node = causalprog.graph.DistributionNode(normal, "X", is_outcome=True)

    graph = causalprog.graph.Graph("G0")
    graph.add_node(node)

    assert np.isclose(
        causalprog.algorithms.expectation(graph, samples=samples), mean, rtol=rtol
    )
    assert np.isclose(
        causalprog.algorithms.standard_deviation(graph, samples=samples),
        stdev,
        rtol=rtol,
    )


nx_graph = nx.Graph()
nx_graph.add_edges_from([[n_x, x], [n_m, m], [u_y, y], [x, m], [m, y]])
@pytest.mark.parametrize(
    ("mean", "stdev", "stdev2", "samples", "rtol"),
    [
        pytest.param(
            1.0,
            1.0,
            0.8,
            100,
            1,
            id="N(mean=N(mean=1, stdev=1), stdev=0.8), 100 samples",
        ),
        pytest.param(
            3.0,
            0.5,
            1.0,
            10000,
            1e-1,
            id="N(mean=N(mean=3, stdev=0.5), stdev=1), 10^4 samples",
        ),
        pytest.param(
            2.0,
            0.7,
            0.8,
            1000000,
            1e-2,
            id="N(mean=N(mean=2, stdev=0.7), stdev=0.8), 10^6 samples",
        ),
    ],
)
def test_two_node_graph(samples, rtol, mean, stdev, stdev2):
    """
    Check a normal node whose mean is itself drawn from a normal node.

    The outcome X has mean UX and standard deviation ``stdev2``, where UX is
    normal with the given ``mean`` and ``stdev``. X is then expected to have
    mean ``mean`` and variance ``stdev**2 + stdev2**2``.
    """
    normal = causalprog.graph.node.NormalDistribution(mean, stdev)
    # The string "UX" makes the mean of X come from the samples of node UX.
    normal2 = causalprog.graph.node.NormalDistribution("UX", stdev2)

    graph = causalprog.graph.Graph("G0")
    graph.add_node(causalprog.graph.DistributionNode(normal, "UX"))
    graph.add_node(causalprog.graph.DistributionNode(normal2, "X", is_outcome=True))
    graph.add_edge("UX", "X")

    assert graph.label == "G0"
    assert np.isclose(
        causalprog.algorithms.expectation(graph, samples=samples), mean, rtol=rtol
    )
    assert np.isclose(
        causalprog.algorithms.standard_deviation(graph, samples=samples),
        np.sqrt(stdev**2 + stdev2**2),
        rtol=rtol,
    )