
Labeled Tensors as first class objects #1421

Closed
aseyboldt opened this issue May 26, 2025 · 0 comments

@aseyboldt
Member

Description

Just a couple of thoughts regarding #1411.

The current implementation uses strings to identify a dimension. But in PyTensor's static graph framework it might make more sense to treat dimensions as first-class objects, so that they can have their own graph-like structure. Maybe they could be a separate subclass of Variable as well? I think that could lead to cleaner user code (no typos in dimension names that silently broadcast into gigantic arrays and end in an out-of-memory error), and it might also make derived dimensions easier to handle and reason about.
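To make that failure mode concrete, here is a small, runnable illustration of the typo pitfall using xarray, which relies on the same kind of string-named dimensions (the variable names and contents are made up for illustration):

import numpy as np
import xarray as xr

# Two effects that are meant to share the "country" dimension.
country_effect = xr.DataArray(np.zeros(3), dims=["country"])
typo_effect = xr.DataArray(np.ones(3), dims=["contry"])  # note the typo

# Instead of raising an error, the mismatched names broadcast against each
# other and silently produce a 3 x 3 array.
result = country_effect + typo_effect
assert result.dims == ("country", "contry")
assert result.shape == (3, 3)

With dimensions as first-class objects, the misspelled name would be a misspelled Python variable, so the same mistake would fail loudly at graph-construction time (e.g. with a NameError) rather than silently blowing up the shape.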

I think we could do something like

from math import prod

from pytensor.graph.basic import Apply, Variable
from pytensor.graph.op import Op
from pytensor.graph.type import Type
from pytensor.tensor.variable import TensorVariable


class DimensionType(Type):
    pass

# I guess all dimensions have the same type?
DimType = DimensionType()

class DimensionVariable(Variable):
    def __init__(self, name=None, length=None):
        pass

    @property
    def length(self) -> TensorVariable:
        ...

class DimOp(Op):
    pass

class DimConstant(DimensionVariable):
    def __init__(self, name, *, length=None):
        pass

# The result of stacking two dims
class Product(DimOp):
    __props__ = ("name",)

    def __init__(self, name=None):
        self.name = name

    def make_node(self, *inputs: DimensionVariable):
        if self.name is not None:
            name = self.name
        else:
            name = f"product[{','.join(input.name for input in inputs)}]"
        output = DimensionVariable(name=name, length=prod(input.length for input in inputs))
        return Apply(self, inputs, [output])

def dim(name, *, length=None):
    return DimConstant(name, length=length)

def stack(variable, *, dims, name=None):
    dim_op = Product(name)
    stacked_dim = dim_op(*dims)
    ...

Final usage could maybe be something like this?

country = pt.dim("country")
treatment = pt.dim("treatment")

# We can talk about the stacked dim directly:
interaction = pt.stacked_dim(country, treatment)
effect = pt.xtensor(dims=(interaction,))
assert effect.unstack(interaction).dims == (country, treatment)
assert pt.stack(effect, dims=[country, treatment]).dims == (interaction,)

# Same for indices and slices...
subset = pt.dim_slice(country, slice("A", "B"))
sub_effect = pt.xtensor(dims=(subset,))
effect = pt.xtensor(dims=(country,))
assert effect.sel(country=slice("A", "B")).dims == (subset,)
effect = effect.at.sel(country=slice("A", "B")).add(sub_effect)
@pymc-devs locked and limited conversation to collaborators on May 27, 2025
@ricardoV94 converted this issue into discussion #1423 on May 27, 2025

