Add support for tensor abbreviations and better addressing #104

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open · wants to merge 3 commits into master
155 changes: 77 additions & 78 deletions namedtensor/core.py
@@ -1,4 +1,5 @@
 from .schema import _Schema
+from collections import OrderedDict
 import operator
 import functools

@@ -13,13 +14,14 @@ def assert_match(*tensors):
     for t in tensors:
         shape = t.vshape
         for i, k in t._schema.enum_all():
+            name = k.name
             v = shape[i]
             if v == 1:
                 continue
-            if k in sizes:
-                failure = failure or sizes[k] != v
+            if name in sizes:
+                failure = failure or sizes[name] != v
             else:
-                sizes[k] = v
+                sizes[name] = v
     assert not failure, "Overlapping dim names must match: " + " ".join(
         [str(t.shape) for t in tensors]
     )
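The hunk above keys the size check on `k.name` rather than on `k` itself, so two schemas that spell the same dimension with different axis objects still compare by name. A standalone sketch of the rule being enforced, using plain dicts instead of the `_Schema` machinery (the names and sizes here are made up):

```python
def check_match(*named_shapes):
    """Every dim name shared across shapes must map to one size; size-1 broadcasts."""
    sizes = {}
    for shape in named_shapes:          # e.g. {"batch": 32, "h": 64}
        for name, v in shape.items():
            if v == 1:
                continue                # size-1 dims are exempt (broadcasting)
            assert sizes.setdefault(name, v) == v, (
                "Overlapping dim names must match: %s" % name
            )

check_match({"batch": 32, "h": 64}, {"h": 64, "w": 1})  # passes
# check_match({"h": 64}, {"h": 128})                    # would raise
```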
@@ -39,9 +41,9 @@ def __init__(self, tensor, names, mask=0):
         self._tensor = tensor
         self._schema = _Schema.build(names, mask)
         if self._tensor.dim() > 0:
-            assert len(self._tensor.shape) == len(self._schema._names), (
+            assert len(self._tensor.shape) == len(self._schema.axes), (
                 "Tensor has %d dim, but %d names"
-                % (len(self._tensor.shape), len(self._schema._names))
+                % (len(self._tensor.shape), len(self._schema.axes))
             )
         else:
             assert len(names) == 0, str(tensor)
@@ -54,7 +56,7 @@ def __deepcopy__(self, memo):
     @property
     def dims(self):
         "Return the dim names for the tensor"
-        return tuple(self._schema._names)
+        return tuple(self._schema.axes)

     @property
     def vshape(self):
@@ -64,18 +66,15 @@ def vshape(self):
     @property
     def shape(self):
         "The ordered dict of available dimensions."
-        return self._schema.ordered_dict(self._tensor.size())
+        #return tuple(zip(self.dims, self._tensor.shape))
+        return OrderedDict(zip(self._schema._names, self._tensor.shape))

     def __repr__(self):
-        return "NamedTensor(\n\t%s,\n\t%s)" % (
-            self._tensor,
-            self._schema._names,
-        )
+        return "NamedTensor(\n\t{},\n\t{})".format(self._tensor, self.dims)

     def size(self, dim):
         "Return the raw shape of the tensor"
-        i = self._schema.get(dim)
-        return self._tensor.size(i)
+        return self._tensor.size(self._schema.get(dim))

     def assert_size(self, **kwargs):
         "Return the raw shape of the tensor"
@@ -93,9 +92,11 @@ def values(self):
         return self._tensor

     def _new(self, tensor, drop=None, add=None, updates={}, mask=None):
+        # raise RuntimeError("Err drop=%s" % drop +
+        #                    " add=%s" % add + "updates=%s" % updates )
         return self.__class__(
             tensor,
-            self._schema.drop(drop).update(updates)._names
+            self._schema.drop(drop).update(updates).axes
             + (() if not add else add),
             self._schema._masked if mask is None else mask,
         )
@@ -111,9 +112,7 @@ def mask_to(self, name):

     def stack(self, dims, name):
         "Stack any number of existing dimensions into a single new dimension."
-        for dim in dims:
-            self._schema.get(dim)
-        return self._merge(dims, name)
+        return self._stack(dims, name)

     def split(self, dim, names, **dim_sizes):
         "Split an existing dimension into new dimensions."
@@ -125,103 +124,103 @@ def rename(self, dim, name):

     def transpose(self, *dims):
         "Return a new DataArray object with transposed dimensions."
-        for dim in dims:
-            self._schema.get(dim)
-        to_dims = (
-            tuple((d for d in self._schema._names if d not in dims)) + dims
-        )
+        to_dims = tuple(d for d in self.dims if d not in dims) + dims
         indices = [self._schema.get(d) for d in to_dims]
         tensor = self._tensor.permute(*indices)
         return self.__class__(tensor, to_dims)
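`transpose` keeps its semantics but drops the redundant pre-validation pass: `self._schema.get(d)` inside the comprehension already raises on unknown names. The ordering rule, continuing the hypothetical example:

```python
y = NamedTensor(torch.randn(2, 3, 4), ("a", "b", "c"))
z = y.transpose("c", "a")  # unlisted dims keep their order, so dims become ('b', 'c', 'a')
```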

-    # Todo: fix arg names
-    def _merge(self, names, dim):
-        s = []
-        ex = []
+    def _stack(self, names, dim):
+        trans = []
+        new_schema = []
         first = True
         view = []
-        for d in self._schema._names:
+        for d in self.dims:
             if d not in names:
-                s.append(d)
-                ex.append(d)
-                view.append(self.shape[d])
+                trans.append(d)
+                new_schema.append(d)
+                view.append(self.size(d))
             elif first:
-                s += names
-                view.append(prod([self.shape[d2] for d2 in names]))
-                ex.append(dim)
+                trans += names
+                view.append(prod([self.size(d2) for d2 in names]))
+                new_schema.append(dim)
                 first = False
-        tensor = self.transpose(*s)._tensor.contiguous().view(*view)
-        return self.__class__(tensor, ex)
+        tensor = self.transpose(*trans)._tensor.contiguous().view(*view)
+        return self.__class__(tensor, new_schema)
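`_stack` is the standard permute-then-view trick: bring the dims being fused next to each other, then collapse them with a single `view`. The same mechanics in plain PyTorch, with hypothetical shapes:

```python
import torch

x = torch.randn(4, 3, 5)  # read as ("batch", "h", "w")
# Fuse ("batch", "w"): permute them adjacent, make contiguous, then one view.
flat = x.permute(0, 2, 1).contiguous().view(4 * 5, 3)  # ("batch*w", "h")
```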

     def _split(self, dim, names, size_dict):
-        query = []
-        ex = []
+        new_schema = []
         view = []
+        dim_num = self._schema.get(dim)
         for i, d in self._schema.enum_all():
-            if d != dim:
-                query.append(d)
-                ex.append(d)
-                view.append(self.shape[d])
+            if i != dim_num:
+                new_schema.append(d)
+                view.append(self.size(d))
             else:
-                query += names
                 for d2 in names:
-                    view.append(size_dict.get(d2, -1))
-                ex += names
-        return self.__class__(self._tensor.view(*view), ex)
+                    d2 = d2.split(":")
+                    view.append(
+                        size_dict.get(d2[-1], size_dict.get(d2[0], -1))
+                    )
+                new_schema += names
+        return self.__class__(self._tensor.view(*view), new_schema)
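The new `d2.split(":")` line is where the abbreviation support from the PR title lands: a target dim can be written as "h:height", and its size is looked up under the long name first, then the short one, falling back to -1 so `view` infers it. The lookup rule in isolation:

```python
def lookup_size(name, size_dict):
    parts = name.split(":")  # "h:height" -> ["h", "height"]; "w" -> ["w"]
    return size_dict.get(parts[-1], size_dict.get(parts[0], -1))

assert lookup_size("h:height", {"height": 3}) == 3  # long form wins
assert lookup_size("h:height", {"h": 3}) == 3       # short form works too
assert lookup_size("w", {}) == -1                   # -1 lets .view() infer it
```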

     def __len__(self):
         return len(self._tensor)

     def _promote(self, dims):
-        "Move dims to the front of the line"
-        term = [
-            d for d in self._schema._names if d not in dims
-        ] + dims.split()[1:]
+        """ Move dims to the front of the line """
+        raise RuntimeError("Err %s" % dims)
+
+        term = [d for d in self.dims if d not in dims] + dims.split()[1:]
+
         return self.transpose(*term)

     def _force_order(self, names):
         """ Forces self to take order in names, adds 1-size dims if needed """
-        ex = []
+        if isinstance(names, _Schema):
+            names = names.axes
+        new_schema = []
         view = []
         trans = []
         for d in names:
-            if d not in self._schema._names:
-                ex.append(d)
+            if d not in self.dims:
+                new_schema.append(d)
                 view.append(1)
             else:
-                ex.append(d)
-                view.append(self.shape[d])
+                new_schema.append(d)
+                view.append(self.size(d))
                 trans.append(d)
         return self.__class__(
-            self.transpose(*trans)._tensor.contiguous().view(*view), ex
+            self.transpose(*trans)._tensor.contiguous().view(*view), new_schema
         )
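`_force_order` inserts a size-1 axis for any requested dim the tensor lacks, which is what makes named broadcasting line up. The equivalent move in plain PyTorch, with made-up dims:

```python
import torch

x = torch.randn(4, 3)  # read as ("batch", "h")
# Force the order ("batch", "w", "h"): "w" is missing, so give it size 1;
# the result then broadcasts against any tensor with a real "w" axis.
forced = x.view(4, 1, 3)
```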

-    def _broadcast_order(self, other_names):
+    def _broadcast_order(self, other):
         """ Outputs a shared order (list) that works for self and other """
-        order = []
-        for d in other_names:
-            if d not in self._schema._names:
-                order.append(d)
-        for d in self._schema._names:
-            order.append(d)
-        return order
-
-    def _mask_broadcast_order(self, main_names):
+        if isinstance(other, list):
+            return self._schema.merge(_Schema(other))
+        return self._schema.merge(other._schema)
+        # order = []
+        # for d in other_names:
+        #     if d not in self.dims:
+        #         order.append(d)
+        #     for d in self.dims:
+        #         order.append(d)
+        # return order
+
+    def _mask_broadcast_order(self, other):
         """
         If broadcasting possible from self (mask) to main, outputs a shared order.
         Otherwise errors and prints dimensions that exist in mask but not in main.
         """
-
-        to_be_broadcasted = set(self._schema._names)
-        broadcasted_to = set(main_names)
-
-        diff = to_be_broadcasted.difference(broadcasted_to)
-        diff_string = ", ".join(diff)
-
-        assert len(diff) == 0, (
-            "Attemped to broadcast mask but unable to broadcast dimensions %s"
-            % diff_string
-        )
-
-        return main_names
+        if isinstance(other, list):
+            return self._schema.merge(_Schema(other))
+        return self._schema.merge(other._schema)
+        #to_be_broadcasted = set(self.dims)
+        #broadcasted_to = set(main_names)
+        #diff = to_be_broadcasted.difference(broadcasted_to)
+        #diff_string = ", ".join(diff)
+        #assert len(diff) == 0, (
+        #    "Attemped to broadcast mask but unable to broadcast dimensions %s"
+        #    % diff_string
+        #)
+        #return main_names
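Both broadcast helpers now delegate to `_schema.merge`, whose definition is not part of this diff; the commented-out bodies preserve the old rule. A sketch of that old ordering rule, which `merge` presumably has to reproduce:

```python
def broadcast_order(self_dims, other_dims):
    # Other's extra dims go first, then all of self's dims, in order.
    return [d for d in other_dims if d not in self_dims] + list(self_dims)

assert broadcast_order(("batch", "h"), ("w", "h")) == ["w", "batch", "h"]
```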