From d9dc15c8b45899ed0b1a4c2c10c268f1e80b7b29 Mon Sep 17 00:00:00 2001 From: EndingCredits Date: Sun, 30 Jun 2019 12:18:55 +0100 Subject: [PATCH 1/3] added tensor abbreviations --- namedtensor/core.py | 31 +++------ namedtensor/schema.py | 146 +++++++++++++++++++++++++++++++----------- 2 files changed, 118 insertions(+), 59 deletions(-) diff --git a/namedtensor/core.py b/namedtensor/core.py index dbfa793..3816f18 100644 --- a/namedtensor/core.py +++ b/namedtensor/core.py @@ -10,19 +10,13 @@ def prod(factors): def assert_match(*tensors): sizes = {} failure = False - for t in tensors: - shape = t.vshape - for i, k in t._schema.enum_all(): - v = shape[i] - if v == 1: - continue - if k in sizes: - failure = failure or sizes[k] != v - else: - sizes[k] = v - assert not failure, "Overlapping dim names must match: " + " ".join( - [str(t.shape) for t in tensors] - ) + axes = [] + for tensor in tensors: + axes = axes + list(tensor._schema._axes) + for ax in axes: + assert not ax.conflict(*axes), "Overlapping dim names must match: " + " ".join( + [str(t.shape) for t in tensors] + ) class NamedTensorBase: @@ -37,14 +31,7 @@ class NamedTensorBase: def __init__(self, tensor, names, mask=0): self._tensor = tensor - self._schema = _Schema.build(names, mask) - if self._tensor.dim() > 0: - assert len(self._tensor.shape) == len(self._schema._names), ( - "Tensor has %d dim, but %d names" - % (len(self._tensor.shape), len(self._schema._names)) - ) - else: - assert len(names) == 0, str(tensor) + self._schema = _Schema.build(names, self._tensor.shape, mask) def __deepcopy__(self, memo): new_ntensor = self._new(self._tensor.__deepcopy__(memo)) @@ -64,7 +51,7 @@ def vshape(self): @property def shape(self): "The ordered dict of available dimensions." - return self._schema.ordered_dict(self._tensor.size()) + return self._schema.ordered_dict() def __repr__(self): return "NamedTensor(\n\t%s,\n\t%s)" % ( diff --git a/namedtensor/schema.py b/namedtensor/schema.py index 2a00f0d..e64890f 100644 --- a/namedtensor/schema.py +++ b/namedtensor/schema.py @@ -1,68 +1,140 @@ from collections import OrderedDict from .utils import make_tuple +class _Axis(object): + "A dimension" + def __init__(self, name, size=None): + self._set_name(name) + self.size = size + + def conflict(self, *axes): + sizes = [] + if self.size != 1 and not self.size is None: + sizes = [self.size, ] + for axis in axes: + if self.name == axis.name and self.abbr != axis.abbr: + return True + if self.abbr == axis.abbr and self.name != axis.name: + return True + if self.abbr == axis.abbr and self.name == axis.name: + if axis.size != 1 and not axis.size is None: + sizes.append(axis.size) + return not all(x == sizes[0] for x in sizes) + + def _set_name(self, name): + names = name.split(':') + if len(names) == 1: + self.name = names[0] + self.abbr = self.name[0] + elif len(names) == 2: + self.name = names[1] + self.abbr = names[0] + if not len(self.abbr) == 1: + raise RuntimeError("Error setting axis name {}\n".format(name) + + "Abbreviations must be a single character") + else: + raise RuntimeError("Error setting axis name {}\n".format(name) + + "Valid names are of the form 'name' or 'n:name'") + + def __eq__(self, other): + if isinstance(other, _Axis): + return self.name == other.name or self.abbr == other.abbr + elif isinstance(other, str): + return str(self) == other or self.name==other or self.abbr==other + + def __str__(self): + return self.abbr + ":" + self.name + + def __repr__(self): + return str(self) + + + class _Schema: "Dimension names and order" - def 
__init__(self, names, mask=0): - self._names = make_tuple(names) + def __init__(self, names, sizes=None, mask=0): + self._masked = mask + names= make_tuple(names) + if sizes is not None and len(sizes) != len(names): + raise RuntimeError("Error setting schema shape, " + + "'{}' does not match {}".format(sizes, names)) - s = set() - for n in self._names: - assert n is not None - assert n.isalnum(), "dim name %s must be alphanumeric" % n - assert n not in s, ( - "Tensor must have unique dims, dim '%s' is non-unique" % n - ) - s.add(n) + axes = [] + for i, name in enumerate(names): + if not isinstance(name, _Axis): + name = _Axis(name, sizes[i]) + elif sizes is not None and name.size != sizes[i]: + raise RuntimeError("Error setting schema dimension, " + + "dimension '{}' has size {} but attempting to set with {}".\ + format(name, name.size, sizes[i])) + if name in axes: + raise RuntimeError("Tensor must have unique dims, " + + "dim '{}' is not unique, dims={}".format(name, names) + + "(Note: dimension names and dimension abbreviations" + + "must both be unique)") + axes.append(name) + self._axes = tuple(axes) - self._masked = mask - self._axes = OrderedDict(((d, i) for i, d in enumerate(self._names))) + @property + def _names(self): + return tuple(str(axis) for axis in self._axes) + + @property + def _abbrs(self): + return tuple(axis.abbr for axis in self._axes) def _to_einops(self): - return " ".join(self._names) + return " ".join(self._abbrs) - def ordered_dict(self, size): - return OrderedDict(((d, size[i]) for i, d in self.enum_masked())) + def ordered_dict(self): + return OrderedDict((str(a), a.size) for a in self._axes) @staticmethod - def build(names, mask): + def build(names, sizes, mask=0): if isinstance(names, _Schema): - return _Schema(names._names, mask) - names = make_tuple(names) - return _Schema(names, mask) + return _Schema(names._axes, sizes, mask) + return _Schema(names, sizes, mask) def get(self, name): + for i, n in self.enum_all(): + if name == n: + if i < self._masked: + raise RuntimeError("Dimension {} is masked".format(name,)) + return i if name not in self._axes: raise RuntimeError( - "Dimension %s does not exist. Available dimensions are %s" - % (name, self._names) + "Dimension {} does not exist. 
Available dimensions are {}".\ + format(name, self._names) ) - i = self._axes[name] - if i < self._masked: - raise RuntimeError("Dimension %s is masked" % (name,)) - return i - + else: + raise RuntimeError( # Not sure how we'd get here + "Something unexpected occured while searching for {} in {}".\ + format(name, self._names) + ) + + def drop(self, names): - names = make_tuple(names) + names = [_Axis(name) for name in make_tuple(names)] return _Schema( - [n for n in self._names if n not in names], self._masked + [n for n in self._axes if n not in names], mask=self._masked ) def update(self, update): if not update: return self - fail = True - for n in self._names: - if n in update: - fail = False - if fail: - raise RuntimeError("Tried to update unknown dim %s" % update) - return _Schema([update.get(n, n) for n in self._names], self._masked) + raise RuntimeError("Update err %s" % update) + # fail = True + # for n in self._names: + # if n in update: + # fail = False + # if fail: + # raise RuntimeError("Tried to update unknown dim %s" % update) + # return _Schema([update.get(n, n) for n in self._names], self._masked) def enum_masked(self): - return enumerate(self._names[self._masked :], self._masked) + return enumerate(self._axes[self._masked :], self._masked) def enum_all(self): - return enumerate(self._names) + return enumerate(self._axes) From db9e73ed4a61ca9bc5429db07c4102a044033a7f Mon Sep 17 00:00:00 2001 From: EndingCredits Date: Tue, 2 Jul 2019 22:49:10 +0100 Subject: [PATCH 2/3] revised abbreviations --- namedtensor/core.py | 118 +++++++++++++++++------------------ namedtensor/schema.py | 102 +++++++++++++++++------------- namedtensor/torch_base.py | 33 +++++----- namedtensor/torch_helpers.py | 4 +- 4 files changed, 134 insertions(+), 123 deletions(-) diff --git a/namedtensor/core.py b/namedtensor/core.py index 3816f18..76d3ea4 100644 --- a/namedtensor/core.py +++ b/namedtensor/core.py @@ -12,7 +12,7 @@ def assert_match(*tensors): failure = False axes = [] for tensor in tensors: - axes = axes + list(tensor._schema._axes) + axes = axes + list(tensor.dims) for ax in axes: assert not ax.conflict(*axes), "Overlapping dim names must match: " + " ".join( [str(t.shape) for t in tensors] @@ -31,7 +31,7 @@ class NamedTensorBase: def __init__(self, tensor, names, mask=0): self._tensor = tensor - self._schema = _Schema.build(names, self._tensor.shape, mask) + self._schema = _Schema.build(names, mask) def __deepcopy__(self, memo): new_ntensor = self._new(self._tensor.__deepcopy__(memo)) @@ -41,7 +41,7 @@ def __deepcopy__(self, memo): @property def dims(self): "Return the dim names for the tensor" - return tuple(self._schema._names) + return tuple(self._schema.axes) @property def vshape(self): @@ -51,18 +51,17 @@ def vshape(self): @property def shape(self): "The ordered dict of available dimensions." 
- return self._schema.ordered_dict() + return tuple(zip(self.dims, self._tensor.shape)) def __repr__(self): - return "NamedTensor(\n\t%s,\n\t%s)" % ( + return "NamedTensor(\n\t{},\n\t{})".format( self._tensor, - self._schema._names, + self.dims, ) def size(self, dim): "Return the raw shape of the tensor" - i = self._schema.get(dim) - return self._tensor.size(i) + return self._tensor.size(self._schema.get(dim)) def assert_size(self, **kwargs): "Return the raw shape of the tensor" @@ -80,9 +79,11 @@ def values(self): return self._tensor def _new(self, tensor, drop=None, add=None, updates={}, mask=None): + #raise RuntimeError("Err drop=%s" % drop + + # " add=%s" % add + "updates=%s" % updates ) return self.__class__( tensor, - self._schema.drop(drop).update(updates)._names + self._schema.drop(drop).update(updates).axes + (() if not add else add), self._schema._masked if mask is None else mask, ) @@ -98,9 +99,7 @@ def mask_to(self, name): def stack(self, dims, name): "Stack any number of existing dimensions into a single new dimension." - for dim in dims: - self._schema.get(dim) - return self._merge(dims, name) + return self._stack(dims, name) def split(self, dim, names, **dim_sizes): "Split an of existing dimension into new dimensions." @@ -112,95 +111,92 @@ def rename(self, dim, name): def transpose(self, *dims): "Return a new DataArray object with transposed dimensions." - for dim in dims: - self._schema.get(dim) - to_dims = ( - tuple((d for d in self._schema._names if d not in dims)) + dims - ) - indices = [self._schema.get(d) for d in to_dims] + to_dims = ( tuple(d for d in self.dims if d not in dims) + dims ) + indices = [ self._schema.get(d) for d in to_dims ] tensor = self._tensor.permute(*indices) return self.__class__(tensor, to_dims) - # Todo: fix arg names - def _merge(self, names, dim): - s = [] - ex = [] + def _stack(self, names, dim): + trans = [] + new_schema = [] first = True view = [] - for d in self._schema._names: + for d in self.dims: if d not in names: - s.append(d) - ex.append(d) - view.append(self.shape[d]) + trans.append(d) + new_schema.append(d) + view.append(self.size(d)) elif first: - s += names - view.append(prod([self.shape[d2] for d2 in names])) - ex.append(dim) + trans += names + view.append(prod([self.size(d2) for d2 in names])) + new_schema.append(dim) first = False - tensor = self.transpose(*s)._tensor.contiguous().view(*view) - return self.__class__(tensor, ex) + tensor = self.transpose(*trans)._tensor.contiguous().view(*view) + return self.__class__(tensor, new_schema) - def _split(self, dim, names, size_dict): - query = [] - ex = [] + def _splitdim(self, dim, names, size_dict): + new_schema = [] view = [] for i, d in self._schema.enum_all(): if d != dim: - query.append(d) - ex.append(d) - view.append(self.shape[d]) + new_schema.append(d) + view.append(self.size(d)) else: - query += names for d2 in names: - view.append(size_dict.get(d2, -1)) - ex += names - return self.__class__(self._tensor.view(*view), ex) + d2 = d2.split(":") + view.append(size_dict.get(d2[-1], size_dict.get(d2[0],-1))) + new_schema += names + return self.__class__(self._tensor.view(*view), new_schema) + def __len__(self): return len(self._tensor) def _promote(self, dims): - "Move dims to the front of the line" + """ Move dims to the front of the line """ + raise RuntimeError("Err %s" % dims) + term = [ - d for d in self._schema._names if d not in dims + d for d in self.dims if d not in dims ] + dims.split()[1:] return self.transpose(*term) - def _force_order(self, names): + def 
_force_order(self, schema): """ Forces self to take order in names, adds 1-size dims if needed """ - ex = [] + new_schema = [] view = [] trans = [] - for d in names: - if d not in self._schema._names: - ex.append(d) + for d in schema.axes: + if d not in self.dims: + new_schema.append(d) view.append(1) else: - ex.append(d) - view.append(self.shape[d]) + new_schema.append(d) + view.append(self.size(d)) trans.append(d) return self.__class__( - self.transpose(*trans)._tensor.contiguous().view(*view), ex + self.transpose(*trans)._tensor.contiguous().view(*view), new_schema ) - def _broadcast_order(self, other_names): + def _broadcast_order(self, other): """ Outputs a shared order (list) that works for self and other """ - order = [] - for d in other_names: - if d not in self._schema._names: - order.append(d) - for d in self._schema._names: - order.append(d) - return order + return self._schema.merge(other._schema) + #order = [] + #for d in other_names: + # if d not in self.dims: + # order.append(d) + #for d in self.dims: + # order.append(d) + #return order def _mask_broadcast_order(self, main_names): """ If broadcasting possible from self (mask) to main, outputs a shared order. Otherwise errors and prints dimensions that exist in mask but not in main. """ - - to_be_broadcasted = set(self._schema._names) + raise RuntimeError("Err %s" % main_names) + to_be_broadcasted = set(self.dims) broadcasted_to = set(main_names) diff = to_be_broadcasted.difference(broadcasted_to) diff --git a/namedtensor/schema.py b/namedtensor/schema.py index e64890f..d3d0c3d 100644 --- a/namedtensor/schema.py +++ b/namedtensor/schema.py @@ -3,23 +3,16 @@ class _Axis(object): "A dimension" - def __init__(self, name, size=None): + def __init__(self, name): self._set_name(name) - self.size = size def conflict(self, *axes): - sizes = [] - if self.size != 1 and not self.size is None: - sizes = [self.size, ] for axis in axes: if self.name == axis.name and self.abbr != axis.abbr: - return True - if self.abbr == axis.abbr and self.name != axis.name: - return True - if self.abbr == axis.abbr and self.name == axis.name: - if axis.size != 1 and not axis.size is None: - sizes.append(axis.size) - return not all(x == sizes[0] for x in sizes) + return axis + #if self.abbr == axis.abbr and self.name != axis.name: + # return axis + return False def _set_name(self, name): names = name.split(':') @@ -38,15 +31,19 @@ def _set_name(self, name): def __eq__(self, other): if isinstance(other, _Axis): - return self.name == other.name or self.abbr == other.abbr - elif isinstance(other, str): - return str(self) == other or self.name==other or self.abbr==other + return self.name == other.name + elif isinstance(other, str): + return other == str(self) or \ + other == self.name or other == self.abbr def __str__(self): - return self.abbr + ":" + self.name + if self.abbr == self.name[0]: + return self.name + else: + return self.abbr + ":" + self.name def __repr__(self): - return str(self) + return self.abbr + ":" + self.name @@ -54,56 +51,58 @@ def __repr__(self): class _Schema: "Dimension names and order" - def __init__(self, names, sizes=None, mask=0): + def __init__(self, names, mask=0): self._masked = mask - names= make_tuple(names) - if sizes is not None and len(sizes) != len(names): - raise RuntimeError("Error setting schema shape, " + - "'{}' does not match {}".format(sizes, names)) + self._build_axes(names) + def _build_axes(self, names): + #print(names) axes = [] - for i, name in enumerate(names): + for name in names: if not isinstance(name, 
_Axis): - name = _Axis(name, sizes[i]) - elif sizes is not None and name.size != sizes[i]: - raise RuntimeError("Error setting schema dimension, " + - "dimension '{}' has size {} but attempting to set with {}".\ - format(name, name.size, sizes[i])) + name = _Axis(name) if name in axes: raise RuntimeError("Tensor must have unique dims, " + "dim '{}' is not unique, dims={}".format(name, names) + "(Note: dimension names and dimension abbreviations" + "must both be unique)") axes.append(name) - self._axes = tuple(axes) + self.axes = tuple(axes) @property def _names(self): - return tuple(str(axis) for axis in self._axes) + return tuple(str(axis) for axis in self.axes) @property def _abbrs(self): - return tuple(axis.abbr for axis in self._axes) + return tuple(axis.abbr for axis in self.axes) def _to_einops(self): return " ".join(self._abbrs) - def ordered_dict(self): - return OrderedDict((str(a), a.size) for a in self._axes) - @staticmethod - def build(names, sizes, mask=0): + def build(names, mask=0): if isinstance(names, _Schema): - return _Schema(names._axes, sizes, mask) - return _Schema(names, sizes, mask) + return _Schema(names.axes, mask) + return _Schema(names, mask) def get(self, name): + dim = None for i, n in self.enum_all(): if name == n: if i < self._masked: raise RuntimeError("Dimension {} is masked".format(name,)) - return i - if name not in self._axes: + if dim is None: + dim = i + else: + raise RuntimeError( + "Ambiguity in axis name, '{}'' matches '{}', ".\ + format(name, self.axes[dim]) + + "and also '{}'".format(self.axes[i]) + ) + if dim is not None: + return dim + elif name not in self.axes: raise RuntimeError( "Dimension {} does not exist. Available dimensions are {}".\ format(name, self._names) @@ -116,10 +115,11 @@ def get(self, name): def drop(self, names): - names = [_Axis(name) for name in make_tuple(names)] + #names = [_Axis(name) for name in make_tuple(names)] + new_axes = [ n for n in self.axes if n not in make_tuple(names)] return _Schema( - [n for n in self._axes if n not in names], mask=self._masked - ) + [ str(a) for a in new_axes], + mask=self._masked ) def update(self, update): if not update: @@ -134,7 +134,21 @@ def update(self, update): # return _Schema([update.get(n, n) for n in self._names], self._masked) def enum_masked(self): - return enumerate(self._axes[self._masked :], self._masked) + return enumerate(self.axes[self._masked :], self._masked) def enum_all(self): - return enumerate(self._axes) + return enumerate(self.axes) + + def merge(self, other): + axes = list(self.axes) + for a in other.axes: + if a not in self.axes: + axes.append(a) + elif a.conflict(*self.axes): + raise RuntimeError( + "Axis {} conflicts with axes {}".\ + format(a, self.axes) + ) + return self.__class__(axes) + + diff --git a/namedtensor/torch_base.py b/namedtensor/torch_base.py index 7601bd2..dd615a3 100644 --- a/namedtensor/torch_base.py +++ b/namedtensor/torch_base.py @@ -31,22 +31,23 @@ def call(ntensor, *args, **kwargs): def dot(cls, dims, *tensors): names = make_tuple(dims) args = [] - ids = {} - seen_names = [] + + schema = None for t in tensors: + schema = t._schema if schema is None else schema.merge(t._schema) group = [] - for name in t._schema._names: - if name not in ids: - ids[name] = len(ids) - seen_names.append(name) - group.append(ids[name]) + for a in t.dims: + group.append(schema.get(a)) args.append(t._tensor) args.append(group) - keep = [n for n in seen_names if n not in names] + for n in names: - if n not in seen_names: + if n not in schema.axes: raise 
RuntimeError("No dimension %s to contract along" % n) - args.append([ids[n] for n in keep]) + + keep = [ a for a in schema.axes if a not in names] + args.append([ schema.get(a) for a in keep ]) + print(args) return cls.tensor(oe.contract(*args, backend="torch"), keep) @staticmethod @@ -72,18 +73,18 @@ def chunk(tensor, number_of_chunks, dim): @staticmethod def stack(tensors, name): - old_names = tensors[0]._schema._names + old_axes = tensors[0]._schema._axes for i in range(1, len(tensors)): - if tensors[i]._schema._names != old_names: - if set(tensors[i]._schema._names) != set( - tensors[0]._schema._names + if tensors[i]._schema._axes != old_axes: + if set(tensors[i]._schema._axes) != set( + tensors[0]._schema._axes ): raise RuntimeError( "Tensors to stack don't have matching dimension names" ) - tensors[i] = tensors[i]._force_order(tensors[0]._schema._names) + tensors[i] = tensors[i]._force_order(tensors[0]._schema._axes) to_stack = [tensor.values for tensor in tensors] - old_names = list(old_names) + old_names = list(tensors[0]._schema._names) old_names.insert(0, name) return ntorch.tensor(torch.stack(to_stack, dim=0), old_names) diff --git a/namedtensor/torch_helpers.py b/namedtensor/torch_helpers.py index a868c30..c82cd78 100644 --- a/namedtensor/torch_helpers.py +++ b/namedtensor/torch_helpers.py @@ -69,7 +69,7 @@ def copy_(self, other): return self._setter(other, "copy_") def _setter(self, other, method, vals=[]): - order = other._mask_broadcast_order(self._schema._names) + order = other._mask_broadcast_order(self) other = other._force_order(order) args = [other.values] + vals @@ -290,7 +290,7 @@ def call(*args, **kwargs): def call(other, *args): if isinstance(other, NamedTensor): b = other - order = self._broadcast_order(b._schema._names) + order = self._broadcast_order(b) a1 = self._force_order(order) b1 = b._force_order(order) method = getattr(a1._tensor, methodname) From 217fb8630b4f97c71b1ebfdc733b925c0a7f025b Mon Sep 17 00:00:00 2001 From: EndingCredits Date: Thu, 4 Jul 2019 21:52:56 +0100 Subject: [PATCH 3/3] fixed code and added acessors --- namedtensor/core.py | 100 +++++++++++++++++++-------------- namedtensor/schema.py | 104 ++++++++++++++++++----------------- namedtensor/torch_base.py | 26 ++++----- namedtensor/torch_helpers.py | 57 ++++++++++++++++++- 4 files changed, 181 insertions(+), 106 deletions(-) diff --git a/namedtensor/core.py b/namedtensor/core.py index 76d3ea4..9f41ee7 100644 --- a/namedtensor/core.py +++ b/namedtensor/core.py @@ -1,4 +1,5 @@ from .schema import _Schema +from collections import OrderedDict import operator import functools @@ -10,13 +11,20 @@ def prod(factors): def assert_match(*tensors): sizes = {} failure = False - axes = [] - for tensor in tensors: - axes = axes + list(tensor.dims) - for ax in axes: - assert not ax.conflict(*axes), "Overlapping dim names must match: " + " ".join( - [str(t.shape) for t in tensors] - ) + for t in tensors: + shape = t.vshape + for i, k in t._schema.enum_all(): + name = k.name + v = shape[i] + if v == 1: + continue + if name in sizes: + failure = failure or sizes[name] != v + else: + sizes[name] = v + assert not failure, "Overlapping dim names must match: " + " ".join( + [str(t.shape) for t in tensors] + ) class NamedTensorBase: @@ -32,6 +40,13 @@ class NamedTensorBase: def __init__(self, tensor, names, mask=0): self._tensor = tensor self._schema = _Schema.build(names, mask) + if self._tensor.dim() > 0: + assert len(self._tensor.shape) == len(self._schema.axes), ( + "Tensor has %d dim, but %d names" + % 
(len(self._tensor.shape), len(self._schema.axes)) + ) + else: + assert len(names) == 0, str(tensor) def __deepcopy__(self, memo): new_ntensor = self._new(self._tensor.__deepcopy__(memo)) @@ -51,13 +66,11 @@ def vshape(self): @property def shape(self): "The ordered dict of available dimensions." - return tuple(zip(self.dims, self._tensor.shape)) + #return tuple(zip(self.dims, self._tensor.shape)) + return OrderedDict(zip(self._schema._names, self._tensor.shape)) def __repr__(self): - return "NamedTensor(\n\t{},\n\t{})".format( - self._tensor, - self.dims, - ) + return "NamedTensor(\n\t{},\n\t{})".format(self._tensor, self.dims) def size(self, dim): "Return the raw shape of the tensor" @@ -79,7 +92,7 @@ def values(self): return self._tensor def _new(self, tensor, drop=None, add=None, updates={}, mask=None): - #raise RuntimeError("Err drop=%s" % drop + + # raise RuntimeError("Err drop=%s" % drop + # " add=%s" % add + "updates=%s" % updates ) return self.__class__( tensor, @@ -111,8 +124,8 @@ def rename(self, dim, name): def transpose(self, *dims): "Return a new DataArray object with transposed dimensions." - to_dims = ( tuple(d for d in self.dims if d not in dims) + dims ) - indices = [ self._schema.get(d) for d in to_dims ] + to_dims = tuple(d for d in self.dims if d not in dims) + dims + indices = [self._schema.get(d) for d in to_dims] tensor = self._tensor.permute(*indices) return self.__class__(tensor, to_dims) @@ -134,21 +147,23 @@ def _stack(self, names, dim): tensor = self.transpose(*trans)._tensor.contiguous().view(*view) return self.__class__(tensor, new_schema) - def _splitdim(self, dim, names, size_dict): + def _split(self, dim, names, size_dict): new_schema = [] view = [] + dim_num = self._schema.get(dim) for i, d in self._schema.enum_all(): - if d != dim: + if i != dim_num: new_schema.append(d) view.append(self.size(d)) else: for d2 in names: d2 = d2.split(":") - view.append(size_dict.get(d2[-1], size_dict.get(d2[0],-1))) + view.append( + size_dict.get(d2[-1], size_dict.get(d2[0], -1)) + ) new_schema += names return self.__class__(self._tensor.view(*view), new_schema) - def __len__(self): return len(self._tensor) @@ -156,18 +171,18 @@ def _promote(self, dims): """ Move dims to the front of the line """ raise RuntimeError("Err %s" % dims) - term = [ - d for d in self.dims if d not in dims - ] + dims.split()[1:] + term = [d for d in self.dims if d not in dims] + dims.split()[1:] return self.transpose(*term) - def _force_order(self, schema): + def _force_order(self, names): """ Forces self to take order in names, adds 1-size dims if needed """ + if isinstance(names, _Schema): + names = names.axes new_schema = [] view = [] trans = [] - for d in schema.axes: + for d in names: if d not in self.dims: new_schema.append(d) view.append(1) @@ -181,30 +196,31 @@ def _force_order(self, schema): def _broadcast_order(self, other): """ Outputs a shared order (list) that works for self and other """ + if isinstance(other, list): + return self._schema.merge(_Schema(other)) return self._schema.merge(other._schema) - #order = [] - #for d in other_names: + # order = [] + # for d in other_names: # if d not in self.dims: # order.append(d) - #for d in self.dims: + # for d in self.dims: # order.append(d) - #return order + # return order - def _mask_broadcast_order(self, main_names): + def _mask_broadcast_order(self, other): """ If broadcasting possible from self (mask) to main, outputs a shared order. Otherwise errors and prints dimensions that exist in mask but not in main. 
""" - raise RuntimeError("Err %s" % main_names) - to_be_broadcasted = set(self.dims) - broadcasted_to = set(main_names) - - diff = to_be_broadcasted.difference(broadcasted_to) - diff_string = ", ".join(diff) - - assert len(diff) == 0, ( - "Attemped to broadcast mask but unable to broadcast dimensions %s" - % diff_string - ) - - return main_names + if isinstance(other, list): + return self._schema.merge(_Schema(other)) + return self._schema.merge(other._schema) + #to_be_broadcasted = set(self.dims) + #broadcasted_to = set(main_names) + #diff = to_be_broadcasted.difference(broadcasted_to) + #diff_string = ", ".join(diff) + #assert len(diff) == 0, ( + # "Attemped to broadcast mask but unable to broadcast dimensions %s" + # % diff_string + #) + #return main_names diff --git a/namedtensor/schema.py b/namedtensor/schema.py index d3d0c3d..9274abf 100644 --- a/namedtensor/schema.py +++ b/namedtensor/schema.py @@ -1,21 +1,24 @@ from collections import OrderedDict from .utils import make_tuple + + class _Axis(object): "A dimension" - def __init__(self, name): + + def __init__(self, name): self._set_name(name) def conflict(self, *axes): for axis in axes: if self.name == axis.name and self.abbr != axis.abbr: return axis - #if self.abbr == axis.abbr and self.name != axis.name: + # if self.abbr == axis.abbr and self.name != axis.name: # return axis return False def _set_name(self, name): - names = name.split(':') + names = name.split(":") if len(names) == 1: self.name = names[0] self.abbr = self.name[0] @@ -23,18 +26,23 @@ def _set_name(self, name): self.name = names[1] self.abbr = names[0] if not len(self.abbr) == 1: - raise RuntimeError("Error setting axis name {}\n".format(name) + - "Abbreviations must be a single character") + raise RuntimeError( + "Error setting axis name {}\n".format(name) + + "Abbreviations must be a single character" + ) else: - raise RuntimeError("Error setting axis name {}\n".format(name) + - "Valid names are of the form 'name' or 'n:name'") + raise RuntimeError( + "Error setting axis name {}\n".format(name) + + "Valid names are of the form 'name' or 'n:name'" + ) def __eq__(self, other): if isinstance(other, _Axis): return self.name == other.name - elif isinstance(other, str): - return other == str(self) or \ - other == self.name or other == self.abbr + elif isinstance(other, str): + return ( + other == str(self) or other == self.name or other == self.abbr + ) def __str__(self): if self.abbr == self.name[0]: @@ -45,7 +53,8 @@ def __str__(self): def __repr__(self): return self.abbr + ":" + self.name - + def __hash__(self): + return self.name.__hash__() class _Schema: @@ -56,29 +65,31 @@ def __init__(self, names, mask=0): self._build_axes(names) def _build_axes(self, names): - #print(names) + # print(names) axes = [] - for name in names: + for name in make_tuple(names): if not isinstance(name, _Axis): name = _Axis(name) if name in axes: - raise RuntimeError("Tensor must have unique dims, " + - "dim '{}' is not unique, dims={}".format(name, names) + - "(Note: dimension names and dimension abbreviations" + - "must both be unique)") + raise RuntimeError( + "Tensor must have unique dims, " + + "dim '{}' is not unique, dims={}".format(name, names) + + "(Note: dimension names and dimension abbreviations" + + "must both be unique)" + ) axes.append(name) self.axes = tuple(axes) @property def _names(self): - return tuple(str(axis) for axis in self.axes) + return tuple(axis.name for axis in self.axes) @property def _abbrs(self): return tuple(axis.abbr for axis in self.axes) def 
_to_einops(self):
-        return " ".join(self._abbrs)
+        return " ".join(self._names)
 
     @staticmethod
     def build(names, mask=0):
@@ -91,47 +102,45 @@ def get(self, name):
         dim = None
         for i, n in self.enum_all():
             if name == n:
                 if i < self._masked:
-                    raise RuntimeError("Dimension {} is masked".format(name,))
+                    raise RuntimeError("Dimension {} is masked".format(name))
                 if dim is None:
                     dim = i
                 else:
                     raise RuntimeError(
-                        "Ambiguity in axis name, '{}'' matches '{}', ".\
-                        format(name, self.axes[dim]) +
-                        "and also '{}'".format(self.axes[i])
+                        "Ambiguity in axis name, '{}' matches '{}', ".format(
+                            name, self.axes[dim]
+                        )
+                        + "and also '{}'".format(self.axes[i])
                     )
         if dim is not None:
-            return dim 
+            return dim
         elif name not in self.axes:
             raise RuntimeError(
-                "Dimension {} does not exist. Available dimensions are {}".\
-                format(name, self._names)
+                "Dimension {} does not exist. Available dimensions are {}".format(
+                    name, self.axes
+                )
             )
         else:
-            raise RuntimeError(  # Not sure how we'd get here
-                "Something unexpected occured while searching for {} in {}".\
-                format(name, self._names)
+            raise RuntimeError(  # Not sure how we'd get here
+                "Something unexpected occurred while searching for {} in {}".format(
+                    name, self.axes
+                )
             )
-
-
+
     def drop(self, names):
-        #names = [_Axis(name) for name in make_tuple(names)]
-        new_axes = [ n for n in self.axes if n not in make_tuple(names)]
-        return _Schema(
-            [ str(a) for a in new_axes],
-            mask=self._masked )
+        names = make_tuple(names)
+        for n in names:
+            self.get(n)  # Check for ambiguity
+        new_axes = [n for n in self.axes if n not in names]
+        return _Schema([a for a in new_axes], mask=self._masked)
 
     def update(self, update):
         if not update:
             return self
-        raise RuntimeError("Update err %s" % update)
-        # fail = True
-        # for n in self._names:
-        #     if n in update:
-        #         fail = False
-        # if fail:
-        #     raise RuntimeError("Tried to update unknown dim %s" % update)
-        # return _Schema([update.get(n, n) for n in self._names], self._masked)
+        axes = list(self.axes)
+        for name in update:
+            axes[self.get(name)] = update[name]
+        return _Schema(axes, self._masked)
 
     def enum_masked(self):
         return enumerate(self.axes[self._masked :], self._masked)
 
     def enum_all(self):
         return enumerate(self.axes)
 
     def merge(self, other):
         axes = list(self.axes)
         for a in other.axes:
             if a not in self.axes:
                 axes.append(a)
             elif a.conflict(*self.axes):
-                raise RuntimeError(
-                    "Axis {} conflicts with axes {}".\
-                    format(a, self.axes)
+                raise RuntimeError(
+                    "Axis {} conflicts with axes {}".format(a, self.axes)
                 )
         return self.__class__(axes)
-
-
diff --git a/namedtensor/torch_base.py b/namedtensor/torch_base.py
index dd615a3..b452965 100644
--- a/namedtensor/torch_base.py
+++ b/namedtensor/torch_base.py
@@ -3,6 +3,7 @@
 from .utils import make_tuple
 from .nn import nn
 from .distributions import ndistributions
+from .schema import _Axis
 import opt_einsum as oe
@@ -40,14 +41,13 @@ def dot(cls, dims, *tensors):
             group.append(schema.get(a))
             args.append(t._tensor)
         args.append(group)
-
         for n in names:
             if n not in schema.axes:
                 raise RuntimeError("No dimension %s to contract along" % n)
-
-        keep = [ a for a in schema.axes if a not in names]
-        args.append([ schema.get(a) for a in keep ])
-        print(args)
+
+        keep = schema.drop(names)
+        args.append([schema.get(a) for a in keep.axes])
         return cls.tensor(oe.contract(*args, backend="torch"), keep)
 
     @staticmethod
@@ -73,25 +73,25 @@ def chunk(tensor, number_of_chunks, dim):
 
     @staticmethod
     def stack(tensors, name):
-        old_axes = tensors[0]._schema._axes
+        old_axes = tensors[0].dims
         for i in range(1, len(tensors)):
-            if tensors[i]._schema._axes != old_axes:
-                if 
set(tensors[i]._schema._axes) != set( - tensors[0]._schema._axes + if tensors[i].dims != old_axes: + if set(tensors[i]._schema._names) != set( + tensors[0]._schema._names ): raise RuntimeError( "Tensors to stack don't have matching dimension names" ) - tensors[i] = tensors[i]._force_order(tensors[0]._schema._axes) + tensors[i] = tensors[i]._force_order(tensors[0].dims) to_stack = [tensor.values for tensor in tensors] - old_names = list(tensors[0]._schema._names) + old_names = list(tensors[0].dims) old_names.insert(0, name) return ntorch.tensor(torch.stack(to_stack, dim=0), old_names) @staticmethod def cat(tensors, dim, name=None): "Concate a list of named tensors along dim." - if isinstance(dim, str): + if isinstance(dim, str) or isinstance(dim, _Axis): dim = [dim] * len(tensors) if name is not None: tensors = [t.rename(d, name) for t, d in zip(tensors, dim)] @@ -190,7 +190,7 @@ def gather(input, dim, index, index_dim): @staticmethod def masked_select(input, mask, name="on"): - order = mask._mask_broadcast_order(input._schema._names) + order = mask._mask_broadcast_order(input) a1 = input._force_order(order) b1 = mask._force_order(order) return NamedTensor(a1.values.masked_select(b1.values), name) diff --git a/namedtensor/torch_helpers.py b/namedtensor/torch_helpers.py index c82cd78..3a87c76 100644 --- a/namedtensor/torch_helpers.py +++ b/namedtensor/torch_helpers.py @@ -2,19 +2,72 @@ import torch from .core import NamedTensorBase, assert_match from .utils import make_tuple +import numbers class NamedTensor(NamedTensorBase): + + access_order = [] def __getitem__(self, index): - if isinstance(index, dict): + """ + Valid formats are: + - a dict of pairs name - slice/ntensor/index + - an ntensor with (matching names) + - a sequence of strings + - a sequence of slices/ntensors/indexes + """ + if isinstance(index, tuple): + cur = self + ops_dict = {} + count = 0 + restart = True + for x in index: + if isinstance(x, str): + if restart: + if ops_dict: + cur = cur[ops_dict] + self.access_order = [] + ops_dict = {} + count = 0 + restart = False + self.access_order.append(x) + elif isinstance(x, slice) and isinstance(x.start, str): + if x.step is None: + ops_dict[x.start] = x.stop + else: + ops_dict[x.start] = slice(x.stop, x.step) + restart = True + else: + ops_dict[self.access_order[count]] = x + count += 1 + restart = True + if ops_dict: + cur = cur[ops_dict] + return cur + + elif isinstance(index, str): + self.access_order = [ index, ] + return self + elif isinstance(index, slice) and isinstance(index.start, str): + return + elif isinstance(index, (slice, NamedTensor, numbers.Number)): + if len(self.access_order) != 1: + raise RuntimeError("Dimension name not set. Use v['name'].") + return self[dict([ (self.access_order[0], index), ])] + + elif isinstance(index, dict): cur = self for k, v in index.items(): if isinstance(v, slice): + cur = cur.narrow(k, v.start, v.stop - v.start) elif isinstance(v, NamedTensor): cur = cur.index_select(k, v) - else: + elif isinstance(v, numbers.Number): cur = cur.get(k, v) + else: + raise RuntimeError("Index must be number, slice, " + + "or NamedTensor, got {}".format(v)) return cur elif isinstance(index, NamedTensor): if (
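
Usage notes: the sketches below illustrate how the series above is meant to be used. They assume the three patches apply cleanly on top of harvardnlp/namedtensor and that ntorch.tensor and the reduction methods keep their current signatures; all dimension names and sizes are made up for illustration.

PATCH 1 introduces per-axis abbreviations, refined in PATCHES 2-3: a bare name like "batch" implicitly takes its first letter as abbreviation, while the "n:seqlen" form binds an explicit one. _Schema.get resolves either spelling and raises a RuntimeError when a lookup matches two axes.

    import torch
    from namedtensor import ntorch

    # "batch" implicitly gets the abbreviation "b" (its first letter);
    # "n:seqlen" explicitly binds the abbreviation "n" to the name "seqlen".
    x = ntorch.tensor(torch.randn(4, 10), ("batch", "n:seqlen"))

    x.sum("batch")   # full names resolve as before
    x.sum("b")       # so does the implicit one-letter abbreviation
    x.sum("n")       # and an explicit abbreviation

Note that, per _Axis.__str__ and __eq__ in PATCH 3, the combined "abbr:name" spelling only round-trips when the abbreviation differs from the name's first letter: "n:seqlen" matches its axis, but "b:batch" collapses to the bare name "batch" and the combined form is not recognized.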
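
PATCH 3's NamedTensor.__getitem__ adds the accessors from the commit subject. A dict maps dimension names to indexes (an integer routes to .get, a slice to .narrow, a NamedTensor to .index_select); a tuple takes names first and positional indexes after; string-keyed slices and chained lookups are shorthands for the dict form. A sketch, reusing the illustrative tensor from above:

    t = ntorch.tensor(torch.randn(4, 10), ("batch", "n:seqlen"))

    t[{"batch": 0}]              # integer -> t.get("batch", 0)
    t[{"seqlen": slice(2, 5)}]   # slice -> t.narrow("seqlen", 2, 3)
    t["batch", 0]                # tuple form: names first, then indexes
    t["batch":1, "seqlen":2:5]   # 'name':idx and 'name':start:stop sugar
    t["batch"][0]                # chained accessor: name, then index

One caveat visible in the code: a single bare string-keyed slice such as t["batch":1] (outside a tuple) falls into a branch that currently returns None, so only the dict, tuple, and chained spellings exercise the full path.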
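
PATCHES 2-3 also rework ntorch.dot to merge the operands' schemas and hand opt_einsum interleaved integer subscripts (positions in the merged schema) instead of an einops letter string. Again a hypothetical sketch:

    x = ntorch.tensor(torch.randn(4, 10), ("batch", "n:seqlen"))
    y = ntorch.tensor(torch.randn(10, 7), ("n:seqlen", "hidden"))

    # Contract along 'seqlen'; the result keeps ('batch', 'hidden').
    z = ntorch.dot("seqlen", x, y)
    z = ntorch.dot("n", x, y)   # the abbreviation names the same contraction

    # Internally this becomes roughly:
    #   oe.contract(x_tensor, [0, 1], y_tensor, [1, 2], [0, 2], backend="torch")
    # with the integer labels taken from the merged schema
    # ('batch' = 0, 'seqlen' = 1, 'hidden' = 2).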