Description
Refactor value transform logic in logprob to avoid dummy Ops and complex rewrites
The current implementation of value transforms in `pymc/logprob/transform_value.py` is overly convoluted. It relies on dummy Ops (`TransformedValue`, `TransformedValueRV`) and a complex `TransformValuesRewrite` that is injected into `conditional_logp` via `extra_rewrites`.
This complexity is unnecessary because we can achieve the same result by simple graph manipulation in two stages:
- Stage 1: Derive logp with back-transformed values.
  Instead of using a rewriter to inject transforms during the `conditional_logp` walk, we can pre-process the value variables. For each Random Variable `rv` with an associated `transform`, we can use its back-transformed version `val_constrained = transform.backward(val_unconstrained, *rv.owner.inputs)` as the value variable for `conditional_logp`. `conditional_logp` will then:
  - Calculate the log-probability of `rv` evaluated at `val_constrained`.
  - Correctly replace any occurrences of `rv` in other log-probability terms with `val_constrained`.
- Stage 2: Apply Jacobian correction.
  After `conditional_logp` returns the log-probability terms, we can add the Jacobian correction `transform.log_jac_det(val_unconstrained, *rv.owner.inputs)` to the term associated with `rv`. We must ensure that any Random Variables remaining in the Jacobian expression are also replaced by their respective value variables (using `replace_rvs_by_values`).
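The two stages amount to the usual change-of-variables rule: the log-density on the unconstrained space is the original logp evaluated at `backward(y)` plus `log|d backward(y) / dy|`. The following is a minimal pure-Python sketch of that rule, not PyMC code. It assumes an `Exponential(1)` variable with a log transform (so `backward` is `exp` and the log-Jacobian is `y` itself), and all function names are made up for illustration. It also shows a dependent Normal term receiving the same back-transformed value, mirroring the replacement described in Stage 1.

```python
import math


def exponential_logp(x, rate=1.0):
    """Log-density of Exponential(rate) on the constrained space x > 0."""
    return math.log(rate) - rate * x


def normal_logp(z, mu, sigma=1.0):
    """Log-density of Normal(mu, sigma)."""
    return -0.5 * math.log(2 * math.pi * sigma**2) - 0.5 * ((z - mu) / sigma) ** 2


def backward(y):
    """Inverse of the log transform: unconstrained y -> constrained x = exp(y)."""
    return math.exp(y)


def log_jac_det(y):
    """log |d backward(y) / dy| = log(exp(y)) = y."""
    return y


def transformed_logp(y, jacobian=True):
    # Stage 1: evaluate the original logp at the back-transformed value.
    logp = exponential_logp(backward(y))
    # Stage 2: add the Jacobian correction.
    return logp + log_jac_det(y) if jacobian else logp


def joint_logp(y, z):
    """Joint logp of x ~ Exponential(1), z ~ Normal(x, 1), with x log-transformed."""
    x = backward(y)  # the same constrained value replaces x in every term
    return exponential_logp(x) + log_jac_det(y) + normal_logp(z, mu=x)


def integrate(f, lo, hi, n=20_000):
    """Trapezoidal rule on [lo, hi] with n intervals."""
    h = (hi - lo) / n
    total = 0.5 * (f(lo) + f(hi))
    for i in range(1, n):
        total += f(lo + i * h)
    return total * h


# With the Jacobian term the density over y integrates to (approximately) 1.
mass = integrate(lambda y: math.exp(transformed_logp(y)), -30.0, 10.0)
# Without it, the result is not a normalized density over y.
mass_no_jac = integrate(
    lambda y: math.exp(transformed_logp(y, jacobian=False)), -30.0, 10.0
)
```

Comparing `mass` with `mass_no_jac` is a quick sanity check of why Stage 2 cannot be skipped when sampling on the unconstrained space.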
Benefits
- Eliminates `TransformedValue`, `TransformedValueRV`, and `TransformValuesRewrite`.
- Removes the need for a specialized `_logprob` implementation for `TransformedValueRV`.
- Makes the `logprob` derivation process more transparent and easier to debug.
- Simplifies the internal IR by removing dummy Ops that "should not be present in the final graph".
Proposed Implementation Sketch
```python
import numpy as np

# Import paths as in current PyMC; adjust if the module layout differs.
from pymc.logprob.basic import conditional_logp
from pymc.pytensorf import replace_rvs_by_values


def transformed_conditional_logp(
    rvs,
    rvs_to_values,
    rvs_to_transforms,
    jacobian=True,
    **kwargs,
):
    # 1. Prepare value variables (constrained for those with transforms)
    logp_rv_values = {}
    for rv, val in rvs_to_values.items():
        transform = rvs_to_transforms.get(rv)
        if transform is not None:
            logp_rv_values[rv] = transform.backward(val, *rv.owner.inputs)
        else:
            logp_rv_values[rv] = val

    # 2. Derive logp terms
    logp_terms = conditional_logp(logp_rv_values, **kwargs)

    # 3. Apply Jacobian and map back to unconstrained value variables
    final_terms = []
    for rv in rvs:
        val_unconstrained = rvs_to_values[rv]
        val_constrained = logp_rv_values[rv]
        logp = logp_terms[val_constrained]
        transform = rvs_to_transforms.get(rv)
        if transform is not None and jacobian:
            jac = transform.log_jac_det(val_unconstrained, *rv.owner.inputs)
            # Replace RVs in Jacobian
            [jac] = replace_rvs_by_values(
                [jac],
                rvs_to_values=rvs_to_values,
                rvs_to_transforms=rvs_to_transforms,
            )
            # Handle potential dimension mismatch
            # (logic from current transformed_value_logprob)
            if jac.ndim < logp.ndim:
                logp = logp.sum(axis=np.arange(jac.ndim - logp.ndim, 0))
            logp = logp + jac
        final_terms.append(logp)
    return final_terms
```

This refactoring would significantly clean up `pymc/logprob/transform_value.py` and improve the overall maintainability of the `logprob` module.
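The dimension handling in step 3 of the sketch may be easier to follow in isolation. When a transform changes dimensionality, the Jacobian term can have fewer dimensions than the elementwise logp, so the trailing axes of the logp are summed before the two are added. Below is a small illustration using plain nested lists in place of tensors. The helper names and the numbers are hypothetical, and only the 2-D vs 1-D case is handled.

```python
def trailing_axes(logp_ndim, jac_ndim):
    """Negative axis indices to sum so that logp drops to jac's number of dims.

    Mirrors np.arange(jac.ndim - logp.ndim, 0) from the sketch.
    """
    return tuple(range(jac_ndim - logp_ndim, 0))


def sum_last_axis(rows):
    """Sum a 2-D nested list over its last axis (axis=-1)."""
    return [sum(row) for row in rows]


# Elementwise logp for 2 draws of a length-3 vector: 2 dimensions.
logp = [[-1.0, -2.0, -3.0], [-0.5, -0.5, -1.0]]
# One Jacobian term per draw: 1 dimension.
jac = [0.25, -0.75]

axes = trailing_axes(logp_ndim=2, jac_ndim=1)
assert axes == (-1,)  # only the last axis needs to be reduced here

logp_reduced = sum_last_axis(logp)
total = [lp + j for lp, j in zip(logp_reduced, jac)]
```

When both terms already have the same number of dimensions, `trailing_axes` returns an empty tuple and no reduction is needed, which matches the `if jac.ndim < logp.ndim` guard in the sketch.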