Refactor value transform logic in logprob to avoid dummy Ops and complex rewrites #8100

@ricardoV94

Description

The current implementation of value transforms in pymc/logprob/transform_value.py is overly convoluted. It relies on dummy Ops (TransformedValue, TransformedValueRV) and a complex TransformValuesRewrite that is injected into conditional_logp via extra_rewrites.

This complexity is unnecessary because we can achieve the same result by simple graph manipulation in two stages:

  1. Stage 1: Derive Logp with back-transformed values.
    Instead of using a rewriter to inject transforms during the conditional_logp walk, we can pre-process the value variables. For each Random Variable rv with an associated transform, we can use its back-transformed version val_constrained = transform.backward(val_unconstrained, *rv.owner.inputs) as the value variable for conditional_logp.

    conditional_logp will then:

    • Calculate the log-probability of rv evaluated at val_constrained.
    • Correctly replace any occurrences of rv in other log-probability terms with val_constrained.
  2. Stage 2: Apply Jacobian correction.
    After conditional_logp returns the log-probability terms, we can add the Jacobian correction transform.log_jac_det(val_unconstrained, *rv.owner.inputs) to the term associated with rv. We must ensure that any Random Variables remaining in the Jacobian expression are also replaced by their respective value variables (using replace_rvs_by_values).
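
The two stages can be checked numerically on a scalar case. The sketch below is plain Python, not PyMC code: it assumes an Exponential(1) variable with a log transform, and the helper names (exponential_logp, backward, log_jac_det, transformed_logp, integrate_density) are hypothetical stand-ins for the corresponding pieces of the proposal.

```python
import math


def exponential_logp(x, rate=1.0):
    """Log-density of an Exponential(rate) variable on the constrained space x > 0."""
    return math.log(rate) - rate * x


def backward(y):
    """Map an unconstrained value y to the constrained space (inverse of log)."""
    return math.exp(y)


def log_jac_det(y):
    """log|d backward(y) / dy| = log(exp(y)) = y."""
    return y


def transformed_logp(y, jacobian=True):
    # Stage 1: evaluate the logp at the back-transformed value
    logp = exponential_logp(backward(y))
    # Stage 2: add the Jacobian correction, evaluated at the unconstrained value
    if jacobian:
        logp += log_jac_det(y)
    return logp


def integrate_density(lo=-30.0, hi=10.0, n=200_000):
    """Midpoint-rule integral of exp(transformed_logp) over the unconstrained axis."""
    step = (hi - lo) / n
    return sum(math.exp(transformed_logp(lo + (i + 0.5) * step)) for i in range(n)) * step
```

With the Jacobian term included, the density over the unconstrained value integrates to approximately one, which is the property Stage 2 is meant to preserve. Without it, the expression is the constrained logp evaluated at a reparametrized point and is not a normalized density in the unconstrained variable.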

Benefits

  • Eliminates TransformedValue, TransformedValueRV, and TransformValuesRewrite.
  • Removes the need for a specialized _logprob implementation for TransformedValueRV.
  • Makes the logprob derivation process more transparent and easier to debug.
  • Simplifies the internal IR by removing dummy Ops that "should not be present in the final graph".

Proposed Implementation Sketch

import numpy as np

# Import paths assumed from the current PyMC layout
from pymc.logprob.basic import conditional_logp
from pymc.logprob.utils import replace_rvs_by_values


def transformed_conditional_logp(
    rvs,
    rvs_to_values,
    rvs_to_transforms,
    jacobian=True,
    **kwargs,
):
    # 1. Prepare value variables (constrained for those with transforms)
    logp_rv_values = {}
    for rv, val in rvs_to_values.items():
        transform = rvs_to_transforms.get(rv)
        if transform is not None:
            logp_rv_values[rv] = transform.backward(val, *rv.owner.inputs)
        else:
            logp_rv_values[rv] = val

    # 2. Derive logp terms (a dict keyed by the value variables passed in)
    logp_terms = conditional_logp(logp_rv_values, **kwargs)

    # 3. Apply Jacobian and map back to unconstrained value variables
    final_terms = []
    for rv in rvs:
        val_unconstrained = rvs_to_values[rv]
        val_constrained = logp_rv_values[rv]
        logp = logp_terms[val_constrained]

        transform = rvs_to_transforms.get(rv)
        if transform is not None and jacobian:
            jac = transform.log_jac_det(val_unconstrained, *rv.owner.inputs)
            # Replace any RVs remaining in the Jacobian by their value variables
            [jac] = replace_rvs_by_values(
                [jac], rvs_to_values=rvs_to_values, rvs_to_transforms=rvs_to_transforms
            )

            # Handle potential dimension mismatch (logic from current transformed_value_logprob):
            # sum the trailing logp dimensions that the Jacobian does not have
            if jac.ndim < logp.ndim:
                logp = logp.sum(axis=np.arange(jac.ndim - logp.ndim, 0))

            logp = logp + jac

        final_terms.append(logp)

    return final_terms
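
The dimension-mismatch branch in the sketch above matters when a multivariate transform is applied to a univariate distribution: the logp has one entry per element while the Jacobian has one entry per vector, so the trailing logp dimensions are summed before the Jacobian is added. A minimal plain-Python illustration, assuming an ordered-style transform on two independent standard normals; every function name here is hypothetical and only mirrors the roles in the sketch:

```python
import math


def normal_logp(x):
    """Elementwise standard-normal log-density."""
    return -0.5 * x * x - 0.5 * math.log(2.0 * math.pi)


def ordered_backward(y):
    """Unconstrained y -> increasing x: x[0] = y[0], x[i] = x[i-1] + exp(y[i])."""
    x = [y[0]]
    for yi in y[1:]:
        x.append(x[-1] + math.exp(yi))
    return x


def ordered_log_jac_det(y):
    """The Jacobian of ordered_backward is triangular with diagonal (1, exp(y[1]), ...)."""
    return sum(y[1:])


def reduce_axes(jac_ndim, logp_ndim):
    """Trailing logp axes to sum so that the logp matches the Jacobian's rank."""
    return tuple(range(jac_ndim - logp_ndim, 0))


def transformed_logp(y):
    # Elementwise logp has ndim 1, the Jacobian has ndim 0
    elementwise = [normal_logp(xi) for xi in ordered_backward(y)]
    jac = ordered_log_jac_det(y)
    # reduce_axes(0, 1) is (-1,): the last (only) axis is summed away
    return sum(elementwise) + jac
```

The reduction reflects that the elements are no longer independent once the multivariate transform couples them, so only the joint term can be corrected by the single Jacobian value.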

This refactoring would significantly clean up pymc/logprob/transform_value.py and improve the overall maintainability of the logprob module.
