Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 32 additions & 16 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,15 +13,16 @@
# limitations under the License.
"""PyMC-ArviZ conversion code."""

from __future__ import annotations

import logging
import warnings

from collections.abc import Iterable, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Optional,
Union,
TypeAlias,
cast,
)

Expand All @@ -38,13 +39,16 @@

import pymc

from pymc.model import Model, modelcontext
from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import get_default_varnames
from pymc.model import modelcontext
from pymc.util import StrongCoords

if TYPE_CHECKING:
from pymc.backends.base import MultiTrace
from pymc.model import Model

from pymc.progress_bar import CustomProgress, default_progress_theme
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import get_default_varnames

# NOTE(review): this name has THREE leading underscores, so it is not the special
# `__all__` attribute consulted by `from module import *` — confirm this is intended.
___all__ = [""]

Expand All @@ -56,6 +60,7 @@

# random variable object ...
Var = Any
# Mapping from a string key to a sequence of strings; used as the second element of
# the return annotation of `coords_and_dims_for_inferencedata`.
# NOTE(review): presumably variable name -> dimension names — confirm against callers.
DimsDict: TypeAlias = Mapping[str, Sequence[str]]


def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **kwargs):
Expand Down Expand Up @@ -85,7 +90,7 @@ def dict_to_dataset_drop_incompatible_coords(vars_dict, *args, dims, coords, **k
return dict_to_dataset(vars_dict, *args, dims=dims, coords=safe_coords, **kwargs)


def find_observations(model: "Model") -> dict[str, Var]:
def find_observations(model: Model) -> dict[str, Var]:
"""If there are observations available, return them as a dictionary."""
observations = {}
for obs in model.observed_RVs:
Expand All @@ -102,7 +107,7 @@ def find_observations(model: "Model") -> dict[str, Var]:
return observations


def find_constants(model: "Model") -> dict[str, Var]:
def find_constants(model: Model) -> dict[str, Var]:
"""If there are constants available, return them as a dictionary."""
model_vars = model.basic_RVs + model.deterministics + model.potentials
value_vars = set(model.rvs_to_values.values())
Expand All @@ -123,7 +128,9 @@ def find_constants(model: "Model") -> dict[str, Var]:
return constant_data


def coords_and_dims_for_inferencedata(model: Model) -> tuple[dict[str, Any], dict[str, Any]]:
def coords_and_dims_for_inferencedata(
model: Model,
) -> tuple[StrongCoords, DimsDict]:
"""Parse PyMC model coords and dims format to one accepted by InferenceData."""
coords = {
cname: np.array(cvals) if isinstance(cvals, tuple) else cvals
Expand Down Expand Up @@ -265,7 +272,7 @@ def __init__(

self.observations = find_observations(self.model)

def split_trace(self) -> tuple[Union[None, "MultiTrace"], Union[None, "MultiTrace"]]:
def split_trace(self) -> tuple[None | MultiTrace, None | MultiTrace]:
"""Split MultiTrace object into posterior and warmup.

Returns
Expand Down Expand Up @@ -491,7 +498,7 @@ def to_inference_data(self):


def to_inference_data(
trace: Optional["MultiTrace"] = None,
trace: MultiTrace | None = None,
*,
prior: Mapping[str, Any] | None = None,
posterior_predictive: Mapping[str, Any] | None = None,
Expand All @@ -500,7 +507,7 @@ def to_inference_data(
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
model: Optional["Model"] = None,
model: Model | None = None,
save_warmup: bool | None = None,
include_transformed: bool = False,
) -> InferenceData:
Expand Down Expand Up @@ -568,8 +575,8 @@ def to_inference_data(
### perhaps we should have an inplace argument?
def predictions_to_inference_data(
predictions,
posterior_trace: Optional["MultiTrace"] = None,
model: Optional["Model"] = None,
posterior_trace: MultiTrace | None = None,
model: Model | None = None,
coords: CoordSpec | None = None,
dims: DimSpec | None = None,
sample_dims: list | None = None,
Expand Down Expand Up @@ -705,11 +712,20 @@ def apply_function_over_dataset(
)
)

trimmed_dims = {}
for var_name, var_dims in dims.items():
arr = out_trace.get(var_name)
if arr is None:
continue
# Remove sample dims
ndims_without_sample = arr.ndim - len(sample_dims)
trimmed_dims[var_name] = list(var_dims[:ndims_without_sample])

return dict_to_dataset(
out_trace,
library=pymc,
dims=dims,
dims=trimmed_dims,
coords=coords,
default_dims=list(sample_dims),
skip_event_dims=True,
skip_event_dims=False,
)
3 changes: 1 addition & 2 deletions pymc/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@
convert_size,
find_size,
rv_size_is_none,
shape_from_dims,
)
from pymc.logprob.abstract import MeasurableOp, _icdf, _logccdf, _logcdf, _logprob
from pymc.logprob.basic import logp
Expand Down Expand Up @@ -533,7 +532,7 @@ def __new__(
# finally, observed, to determine the shape of the variable.
if kwargs.get("size") is None and kwargs.get("shape") is None:
if dims is not None:
kwargs["shape"] = shape_from_dims(dims, model)
kwargs["shape"] = model.symbolic_shape_from_dims(dims)
elif observed is not None:
kwargs["shape"] = tuple(observed.shape)

Expand Down
61 changes: 17 additions & 44 deletions pymc/distributions/shape_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,14 @@

"""Common shape operations to broadcast samples from probability distributions for stochastic nodes in PyMC."""

from __future__ import annotations

import warnings

from collections.abc import Sequence
from functools import singledispatch
from types import EllipsisType
from typing import Any, TypeAlias, cast
from typing import Any, cast

import numpy as np

Expand All @@ -33,18 +35,25 @@
from pytensor.tensor.type_other import NoneTypeT
from pytensor.tensor.variable import TensorVariable

from pymc.model import modelcontext
from pymc.pytensorf import convert_observed_data
from pymc.exceptions import ShapeError
from pymc.pytensorf import PotentialShapeType, convert_observed_data
from pymc.util import (
Dims,
DimsWithEllipsis,
Shape,
Size,
StrongDims,
StrongDimsWithEllipsis,
StrongShape,
StrongSize,
)

__all__ = [
"change_dist_size",
"rv_size_is_none",
"to_tuple",
]

from pymc.exceptions import ShapeError
from pymc.pytensorf import PotentialShapeType


def to_tuple(shape):
"""Convert ints, arrays, and Nones to tuples.
Expand Down Expand Up @@ -85,19 +94,6 @@ def _check_shape_type(shape):
return tuple(out)


# User-provided values can be specified lazily: besides a full sequence, these
# aliases also admit a bare scalar (a single `int`/`TensorVariable` for shapes
# and sizes, a single `str` for dims), and dims entries may be `None`.
Shape: TypeAlias = int | TensorVariable | Sequence[int | Variable]
Dims: TypeAlias = str | Sequence[str | None]
DimsWithEllipsis: TypeAlias = str | EllipsisType | Sequence[str | None | EllipsisType]
Size: TypeAlias = int | TensorVariable | Sequence[int | Variable]

# After conversion to vectors: the "strong" counterparts of the aliases above.
# The bare `int` / `str` / `None` alternatives no longer appear in these types.
StrongShape: TypeAlias = TensorVariable | tuple[int | Variable, ...]
StrongDims: TypeAlias = Sequence[str]
StrongDimsWithEllipsis: TypeAlias = Sequence[str | EllipsisType]
StrongSize: TypeAlias = TensorVariable | tuple[int | Variable, ...]


def convert_dims(dims: Dims | None) -> StrongDims | None:
"""Process a user-provided dims variable into None or a valid dims tuple."""
if dims is None:
Expand Down Expand Up @@ -164,31 +160,6 @@ def convert_size(size: Size) -> StrongSize | None:
)


def shape_from_dims(dims: StrongDims, model) -> StrongShape:
    """Determine the shape implied by a `dims` tuple.

    Parameters
    ----------
    dims : sequence of str
        Dimension names. Every entry must be a key of ``model.dim_lengths``.
    model : pm.Model
        The current model on stack. Only its ``dim_lengths`` mapping is read.

    Returns
    -------
    shape : tuple
        ``model.dim_lengths[name]`` for each name in `dims`, in the same order.

    Raises
    ------
    KeyError
        If any entry of `dims` is not a key of ``model.dim_lengths``.
    """
    dim_lengths = model.dim_lengths

    # Dims must be known already
    unknowndim_dims = set(dims) - set(dim_lengths)
    if unknowndim_dims:
        raise KeyError(
            f"Dimensions {unknowndim_dims} are unknown to the model and cannot be used to specify a `shape`."
        )

    return tuple(dim_lengths[dname] for dname in dims)


def find_size(
shape: StrongShape | None,
size: StrongSize | None,
Expand Down Expand Up @@ -403,6 +374,8 @@ def get_support_shape(
assert isinstance(dims, tuple)
if len(dims) < ndim_supp:
raise ValueError(f"Number of dims is too small for ndim_supp of {ndim_supp}")
from pymc.model.core import modelcontext

model = modelcontext(None)
inferred_support_shape = [
model.dim_lengths[dims[i]] - support_shape_offset[i] for i in range(-ndim_supp, 0)
Expand Down
16 changes: 10 additions & 6 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,9 @@
)
from pymc.util import (
UNSET,
Coords,
CoordValue,
StrongCoords,
WithMemoization,
_UnsetType,
get_transformed_name,
Expand Down Expand Up @@ -453,7 +456,7 @@ def _validate_name(name):
def __init__(
self,
name="",
coords=None,
coords: Coords | None = None,
check_bounds=True,
*,
model: _UnsetType | None | Model = UNSET,
Expand Down Expand Up @@ -488,7 +491,7 @@ def __init__(
self.deterministics = treelist()
self.potentials = treelist()
self.data_vars = treelist()
self._coords = {}
self._coords: StrongCoords = {}
self._dim_lengths = {}
self.add_coords(coords)

Expand Down Expand Up @@ -907,7 +910,7 @@ def unobserved_RVs(self):
return self.free_RVs + self.deterministics

@property
def coords(self) -> dict[str, tuple | None]:
def coords(self) -> StrongCoords:
"""Coordinate values for model dimensions."""
return self._coords

Expand All @@ -919,7 +922,7 @@ def dim_lengths(self) -> dict[str, TensorVariable]:
"""
return self._dim_lengths

def shape_from_dims(self, dims):
def symbolic_shape_from_dims(self, dims):
shape = []
if len(set(dims)) != len(dims):
raise ValueError("Can not contain the same dimension name twice.")
Expand All @@ -931,13 +934,14 @@ def shape_from_dims(self, dims):
"argument of the model or through a pm.Data "
"variable."
)
shape.extend(np.shape(self.coords[dim]))
length = self.dim_lengths[dim]
shape.append(length)
return tuple(shape)

def add_coord(
self,
name: str,
values: Sequence | np.ndarray | None = None,
values: CoordValue = None,
*,
length: int | Variable | None = None,
):
Expand Down
8 changes: 7 additions & 1 deletion pymc/printing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,12 @@
# limitations under the License.


from __future__ import annotations

import re

from functools import partial
from typing import TYPE_CHECKING

from pytensor.compile import SharedVariable
from pytensor.graph.basic import Constant
Expand All @@ -26,7 +29,8 @@
from pytensor.tensor.random.type import RandomType
from pytensor.tensor.type_other import NoneTypeT

from pymc.model import Model
if TYPE_CHECKING:
from pymc.model import Model

__all__ = [
"str_for_dist",
Expand Down Expand Up @@ -302,6 +306,8 @@ def _default_repr_pretty(obj: TensorVariable | Model, p, cycle):
# register our custom pretty printer in ipython shells
import IPython

from pymc.model.core import Model

IPython.lib.pretty.for_type(TensorVariable, _default_repr_pretty)
IPython.lib.pretty.for_type(Model, _default_repr_pretty)
except (ModuleNotFoundError, AttributeError):
Expand Down
2 changes: 1 addition & 1 deletion pymc/step_methods/state.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ class DataClassState:
def equal_dataclass_values(v1, v2):
if v1.__class__ != v2.__class__:
return False
if isinstance(v1, (list, tuple)): # noqa: UP038
if isinstance(v1, list | tuple):
return len(v1) == len(v2) and all(
equal_dataclass_values(v1i, v2i) for v1i, v2i in zip(v1, v2, strict=True)
)
Expand Down
Loading