diff --git a/namedtensor/core.py b/namedtensor/core.py
index dbfa793..9f41ee7 100644
--- a/namedtensor/core.py
+++ b/namedtensor/core.py
@@ -1,4 +1,5 @@
 from .schema import _Schema
+from collections import OrderedDict
 import operator
 import functools
 
@@ -13,13 +14,14 @@ def assert_match(*tensors):
     for t in tensors:
         shape = t.vshape
         for i, k in t._schema.enum_all():
+            name = k.name
             v = shape[i]
             if v == 1:
                 continue
-            if k in sizes:
-                failure = failure or sizes[k] != v
+            if name in sizes:
+                failure = failure or sizes[name] != v
             else:
-                sizes[k] = v
+                sizes[name] = v
     assert not failure, "Overlapping dim names must match: " + " ".join(
         [str(t.shape) for t in tensors]
     )
@@ -39,9 +41,9 @@ def __init__(self, tensor, names, mask=0):
         self._tensor = tensor
         self._schema = _Schema.build(names, mask)
         if self._tensor.dim() > 0:
-            assert len(self._tensor.shape) == len(self._schema._names), (
+            assert len(self._tensor.shape) == len(self._schema.axes), (
                 "Tensor has %d dim, but %d names"
-                % (len(self._tensor.shape), len(self._schema._names))
+                % (len(self._tensor.shape), len(self._schema.axes))
             )
         else:
             assert len(names) == 0, str(tensor)
@@ -54,7 +56,7 @@ def __deepcopy__(self, memo):
     @property
     def dims(self):
         "Return the dim names for the tensor"
-        return tuple(self._schema._names)
+        return tuple(self._schema.axes)
 
     @property
     def vshape(self):
@@ -64,18 +66,15 @@ def vshape(self):
     @property
     def shape(self):
         "The ordered dict of available dimensions."
-        return self._schema.ordered_dict(self._tensor.size())
+        #return tuple(zip(self.dims, self._tensor.shape))
+        return OrderedDict(zip(self._schema._names, self._tensor.shape))
 
     def __repr__(self):
-        return "NamedTensor(\n\t%s,\n\t%s)" % (
-            self._tensor,
-            self._schema._names,
-        )
+        return "NamedTensor(\n\t{},\n\t{})".format(self._tensor, self.dims)
 
     def size(self, dim):
         "Return the raw shape of the tensor"
-        i = self._schema.get(dim)
-        return self._tensor.size(i)
+        return self._tensor.size(self._schema.get(dim))
 
     def assert_size(self, **kwargs):
         "Return the raw shape of the tensor"
@@ -93,9 +92,11 @@ def values(self):
         return self._tensor
 
     def _new(self, tensor, drop=None, add=None, updates={}, mask=None):
+        # raise RuntimeError("Err drop=%s" % drop +
+        #                    " add=%s" % add + "updates=%s" % updates )
         return self.__class__(
             tensor,
-            self._schema.drop(drop).update(updates)._names
+            self._schema.drop(drop).update(updates).axes
             + (() if not add else add),
             self._schema._masked if mask is None else mask,
         )
@@ -111,9 +112,7 @@ def mask_to(self, name):
 
     def stack(self, dims, name):
         "Stack any number of existing dimensions into a single new dimension."
-        for dim in dims:
-            self._schema.get(dim)
-        return self._merge(dims, name)
+        return self._stack(dims, name)
 
     def split(self, dim, names, **dim_sizes):
         "Split an of existing dimension into new dimensions."
@@ -125,103 +124,103 @@ def rename(self, dim, name):
 
     def transpose(self, *dims):
         "Return a new DataArray object with transposed dimensions."
-        for dim in dims:
-            self._schema.get(dim)
-        to_dims = (
-            tuple((d for d in self._schema._names if d not in dims)) + dims
-        )
+        to_dims = tuple(d for d in self.dims if d not in dims) + dims
         indices = [self._schema.get(d) for d in to_dims]
         tensor = self._tensor.permute(*indices)
         return self.__class__(tensor, to_dims)
 
-    # Todo: fix arg names
-    def _merge(self, names, dim):
-        s = []
-        ex = []
+    def _stack(self, names, dim):
+        trans = []
+        new_schema = []
         first = True
         view = []
-        for d in self._schema._names:
+        for d in self.dims:
             if d not in names:
-                s.append(d)
-                ex.append(d)
-                view.append(self.shape[d])
+                trans.append(d)
+                new_schema.append(d)
+                view.append(self.size(d))
             elif first:
-                s += names
-                view.append(prod([self.shape[d2] for d2 in names]))
-                ex.append(dim)
+                trans += names
+                view.append(prod([self.size(d2) for d2 in names]))
+                new_schema.append(dim)
                 first = False
-        tensor = self.transpose(*s)._tensor.contiguous().view(*view)
-        return self.__class__(tensor, ex)
+        tensor = self.transpose(*trans)._tensor.contiguous().view(*view)
+        return self.__class__(tensor, new_schema)
 
     def _split(self, dim, names, size_dict):
-        query = []
-        ex = []
+        new_schema = []
         view = []
+        dim_num = self._schema.get(dim)
         for i, d in self._schema.enum_all():
-            if d != dim:
-                query.append(d)
-                ex.append(d)
-                view.append(self.shape[d])
+            if i != dim_num:
+                new_schema.append(d)
+                view.append(self.size(d))
             else:
-                query += names
                 for d2 in names:
-                    view.append(size_dict.get(d2, -1))
-                ex += names
-        return self.__class__(self._tensor.view(*view), ex)
+                    d2 = d2.split(":")
+                    view.append(
+                        size_dict.get(d2[-1], size_dict.get(d2[0], -1))
+                    )
+                new_schema += names
+        return self.__class__(self._tensor.view(*view), new_schema)
 
     def __len__(self):
         return len(self._tensor)
 
     def _promote(self, dims):
-        "Move dims to the front of the line"
-        term = [
-            d for d in self._schema._names if d not in dims
-        ] + dims.split()[1:]
+        """ Move dims to the front of the line """
+        term = [d for d in self.dims if d not in dims] + dims.split()[1:]
         return self.transpose(*term)
 
     def _force_order(self, names):
         """ Forces self to take order in names,
            adds 1-size dims if needed """
-        ex = []
+        if isinstance(names, _Schema):
+            names = names.axes
+        new_schema = []
         view = []
         trans = []
         for d in names:
-            if d not in self._schema._names:
-                ex.append(d)
+            if d not in self.dims:
+                new_schema.append(d)
                 view.append(1)
             else:
-                ex.append(d)
-                view.append(self.shape[d])
+                new_schema.append(d)
+                view.append(self.size(d))
                 trans.append(d)
         return self.__class__(
-            self.transpose(*trans)._tensor.contiguous().view(*view), ex
+            self.transpose(*trans)._tensor.contiguous().view(*view), new_schema
         )
 
-    def _broadcast_order(self, other_names):
+    def _broadcast_order(self, other):
         """ Outputs a shared order (list) that works for self and other """
-        order = []
-        for d in other_names:
-            if d not in self._schema._names:
-                order.append(d)
-        for d in self._schema._names:
-            order.append(d)
-        return order
-
-    def _mask_broadcast_order(self, main_names):
+        if isinstance(other, list):
+            return self._schema.merge(_Schema(other))
+        return self._schema.merge(other._schema)
+        # order = []
+        # for d in other_names:
+        #     if d not in self.dims:
+        #         order.append(d)
+        # for d in self.dims:
+        #     order.append(d)
+        # return order
+
+    def _mask_broadcast_order(self, other):
         """ If broadcasting possible from self (mask) to main,
            outputs a shared order.
            Otherwise errors and prints dimensions that
            exist in mask but not in main.
""" - - to_be_broadcasted = set(self._schema._names) - broadcasted_to = set(main_names) - - diff = to_be_broadcasted.difference(broadcasted_to) - diff_string = ", ".join(diff) - - assert len(diff) == 0, ( - "Attemped to broadcast mask but unable to broadcast dimensions %s" - % diff_string - ) - - return main_names + if isinstance(other, list): + return self._schema.merge(_Schema(other)) + return self._schema.merge(other._schema) + #to_be_broadcasted = set(self.dims) + #broadcasted_to = set(main_names) + #diff = to_be_broadcasted.difference(broadcasted_to) + #diff_string = ", ".join(diff) + #assert len(diff) == 0, ( + # "Attemped to broadcast mask but unable to broadcast dimensions %s" + # % diff_string + #) + #return main_names diff --git a/namedtensor/schema.py b/namedtensor/schema.py index 2a00f0d..9274abf 100644 --- a/namedtensor/schema.py +++ b/namedtensor/schema.py @@ -2,67 +2,159 @@ from .utils import make_tuple + +class _Axis(object): + "A dimension" + + def __init__(self, name): + self._set_name(name) + + def conflict(self, *axes): + for axis in axes: + if self.name == axis.name and self.abbr != axis.abbr: + return axis + # if self.abbr == axis.abbr and self.name != axis.name: + # return axis + return False + + def _set_name(self, name): + names = name.split(":") + if len(names) == 1: + self.name = names[0] + self.abbr = self.name[0] + elif len(names) == 2: + self.name = names[1] + self.abbr = names[0] + if not len(self.abbr) == 1: + raise RuntimeError( + "Error setting axis name {}\n".format(name) + + "Abbreviations must be a single character" + ) + else: + raise RuntimeError( + "Error setting axis name {}\n".format(name) + + "Valid names are of the form 'name' or 'n:name'" + ) + + def __eq__(self, other): + if isinstance(other, _Axis): + return self.name == other.name + elif isinstance(other, str): + return ( + other == str(self) or other == self.name or other == self.abbr + ) + + def __str__(self): + if self.abbr == self.name[0]: + return self.name + else: + return self.abbr + ":" + self.name + + def __repr__(self): + return self.abbr + ":" + self.name + + def __hash__(self): + return self.name.__hash__() + + class _Schema: "Dimension names and order" def __init__(self, names, mask=0): - self._names = make_tuple(names) - - s = set() - for n in self._names: - assert n is not None - assert n.isalnum(), "dim name %s must be alphanumeric" % n - assert n not in s, ( - "Tensor must have unique dims, dim '%s' is non-unique" % n - ) - s.add(n) - self._masked = mask - self._axes = OrderedDict(((d, i) for i, d in enumerate(self._names))) + self._build_axes(names) + + def _build_axes(self, names): + # print(names) + axes = [] + for name in make_tuple(names): + if not isinstance(name, _Axis): + name = _Axis(name) + if name in axes: + raise RuntimeError( + "Tensor must have unique dims, " + + "dim '{}' is not unique, dims={}".format(name, names) + + "(Note: dimension names and dimension abbreviations" + + "must both be unique)" + ) + axes.append(name) + self.axes = tuple(axes) + + @property + def _names(self): + return tuple(axis.name for axis in self.axes) + + @property + def _abbrs(self): + return tuple(axis.abbr for axis in self.axes) def _to_einops(self): return " ".join(self._names) - def ordered_dict(self, size): - return OrderedDict(((d, size[i]) for i, d in self.enum_masked())) - @staticmethod - def build(names, mask): + def build(names, mask=0): if isinstance(names, _Schema): - return _Schema(names._names, mask) - names = make_tuple(names) + return _Schema(names.axes, mask) return 
diff --git a/namedtensor/schema.py b/namedtensor/schema.py
index 2a00f0d..9274abf 100644
--- a/namedtensor/schema.py
+++ b/namedtensor/schema.py
@@ -2,67 +2,159 @@
 from .utils import make_tuple
+
+
+class _Axis(object):
+    "A dimension"
+
+    def __init__(self, name):
+        self._set_name(name)
+
+    def conflict(self, *axes):
+        for axis in axes:
+            if self.name == axis.name and self.abbr != axis.abbr:
+                return axis
+            # if self.abbr == axis.abbr and self.name != axis.name:
+            #     return axis
+        return False
+
+    def _set_name(self, name):
+        names = name.split(":")
+        if len(names) == 1:
+            self.name = names[0]
+            self.abbr = self.name[0]
+        elif len(names) == 2:
+            self.name = names[1]
+            self.abbr = names[0]
+            if not len(self.abbr) == 1:
+                raise RuntimeError(
+                    "Error setting axis name {}\n".format(name)
+                    + "Abbreviations must be a single character"
+                )
+        else:
+            raise RuntimeError(
+                "Error setting axis name {}\n".format(name)
+                + "Valid names are of the form 'name' or 'n:name'"
+            )
+
+    def __eq__(self, other):
+        if isinstance(other, _Axis):
+            return self.name == other.name
+        elif isinstance(other, str):
+            return (
+                other == str(self) or other == self.name or other == self.abbr
+            )
+
+    def __str__(self):
+        if self.abbr == self.name[0]:
+            return self.name
+        else:
+            return self.abbr + ":" + self.name
+
+    def __repr__(self):
+        return self.abbr + ":" + self.name
+
+    def __hash__(self):
+        return self.name.__hash__()
+
+
 class _Schema:
     "Dimension names and order"
 
     def __init__(self, names, mask=0):
-        self._names = make_tuple(names)
-
-        s = set()
-        for n in self._names:
-            assert n is not None
-            assert n.isalnum(), "dim name %s must be alphanumeric" % n
-            assert n not in s, (
-                "Tensor must have unique dims, dim '%s' is non-unique" % n
-            )
-            s.add(n)
         self._masked = mask
-        self._axes = OrderedDict(((d, i) for i, d in enumerate(self._names)))
+        self._build_axes(names)
+
+    def _build_axes(self, names):
+        # print(names)
+        axes = []
+        for name in make_tuple(names):
+            if not isinstance(name, _Axis):
+                name = _Axis(name)
+            if name in axes:
+                raise RuntimeError(
+                    "Tensor must have unique dims, "
+                    + "dim '{}' is not unique, dims={} ".format(name, names)
+                    + "(Note: dimension names and dimension abbreviations "
+                    + "must both be unique)"
+                )
+            axes.append(name)
+        self.axes = tuple(axes)
+
+    @property
+    def _names(self):
+        return tuple(axis.name for axis in self.axes)
+
+    @property
+    def _abbrs(self):
+        return tuple(axis.abbr for axis in self.axes)
 
     def _to_einops(self):
         return " ".join(self._names)
 
-    def ordered_dict(self, size):
-        return OrderedDict(((d, size[i]) for i, d in self.enum_masked()))
-
     @staticmethod
-    def build(names, mask):
+    def build(names, mask=0):
         if isinstance(names, _Schema):
-            return _Schema(names._names, mask)
-        names = make_tuple(names)
+            return _Schema(names.axes, mask)
         return _Schema(names, mask)
 
     def get(self, name):
-        if name not in self._axes:
+        dim = None
+        for i, n in self.enum_all():
+            if name == n:
+                if i < self._masked:
+                    raise RuntimeError("Dimension {} is masked".format(name))
+                if dim is None:
+                    dim = i
+                else:
+                    raise RuntimeError(
+                        "Ambiguity in axis name, '{}' matches '{}', ".format(
+                            name, self.axes[dim]
+                        )
+                        + "and also '{}'".format(self.axes[i])
+                    )
+        if dim is not None:
+            return dim
+        elif name not in self.axes:
             raise RuntimeError(
-                "Dimension %s does not exist. Available dimensions are %s"
-                % (name, self._names)
+                "Dimension {} does not exist. Available dimensions are {}".format(
+                    name, self.axes
+                )
+            )
+        else:
+            raise RuntimeError(  # Not sure how we'd get here
+                "Something unexpected occurred while searching for {} in {}".format(
+                    name, self.axes
+                )
             )
-        i = self._axes[name]
-        if i < self._masked:
-            raise RuntimeError("Dimension %s is masked" % (name,))
-        return i
 
     def drop(self, names):
         names = make_tuple(names)
-        return _Schema(
-            [n for n in self._names if n not in names], self._masked
-        )
+        for n in names:
+            self.get(n)  # Check for ambiguity
+        new_axes = [n for n in self.axes if n not in names]
+        return _Schema([a for a in new_axes], mask=self._masked)
 
     def update(self, update):
         if not update:
             return self
-        fail = True
-        for n in self._names:
-            if n in update:
-                fail = False
-        if fail:
-            raise RuntimeError("Tried to update unknown dim %s" % update)
-        return _Schema([update.get(n, n) for n in self._names], self._masked)
+        axes = list(self.axes)
+        for name in update:
+            axes[self.get(name)] = update[name]
+        return _Schema(axes, self._masked)
 
     def enum_masked(self):
-        return enumerate(self._names[self._masked :], self._masked)
+        return enumerate(self.axes[self._masked :], self._masked)
 
     def enum_all(self):
-        return enumerate(self._names)
+        return enumerate(self.axes)
+
+    def merge(self, other):
+        axes = list(self.axes)
+        for a in other.axes:
+            if a not in self.axes:
+                axes.append(a)
+            elif a.conflict(*self.axes):
+                raise RuntimeError(
+                    "Axis {} conflicts with axes {}".format(a, self.axes)
+                )
+        return self.__class__(axes)
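
A small sketch of the abbreviation syntax the new _Axis class accepts ("name" or "a:name"); lookups resolve either form through _Schema.get as implemented above (dimension names here are made up):

    from namedtensor.schema import _Schema

    s = _Schema.build(("b:batch", "c:classes"))
    s.get("batch")    # -> 0, lookup by full name
    s.get("b")        # -> 0, lookup by single-character abbreviation
    s.get("classes")  # -> 1
    # If two axes answer to the same string, get() raises the ambiguity error above.
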
diff --git a/namedtensor/torch_base.py b/namedtensor/torch_base.py
index 7601bd2..b452965 100644
--- a/namedtensor/torch_base.py
+++ b/namedtensor/torch_base.py
@@ -3,6 +3,7 @@
 from .utils import make_tuple
 from .nn import nn
 from .distributions import ndistributions
+from .schema import _Axis
 
 import opt_einsum as oe
@@ -31,22 +32,22 @@ def call(ntensor, *args, **kwargs):
     def dot(cls, dims, *tensors):
         names = make_tuple(dims)
         args = []
-        ids = {}
-        seen_names = []
+
+        schema = None
         for t in tensors:
+            schema = t._schema if schema is None else schema.merge(t._schema)
             group = []
-            for name in t._schema._names:
-                if name not in ids:
-                    ids[name] = len(ids)
-                    seen_names.append(name)
-                group.append(ids[name])
+            for a in t.dims:
+                group.append(schema.get(a))
             args.append(t._tensor)
             args.append(group)
-        keep = [n for n in seen_names if n not in names]
+
         for n in names:
-            if n not in seen_names:
+            if n not in schema.axes:
                 raise RuntimeError("No dimension %s to contract along" % n)
-        args.append([ids[n] for n in keep])
+
+        keep = schema.drop(names)
+        args.append([schema.get(a) for a in keep.axes])
         return cls.tensor(oe.contract(*args, backend="torch"), keep)
 
     @staticmethod
@@ -72,25 +73,25 @@ def chunk(tensor, number_of_chunks, dim):
 
     @staticmethod
     def stack(tensors, name):
-        old_names = tensors[0]._schema._names
+        old_axes = tensors[0].dims
         for i in range(1, len(tensors)):
-            if tensors[i]._schema._names != old_names:
+            if tensors[i].dims != old_axes:
                 if set(tensors[i]._schema._names) != set(
                     tensors[0]._schema._names
                 ):
                     raise RuntimeError(
                         "Tensors to stack don't have matching dimension names"
                     )
-                tensors[i] = tensors[i]._force_order(tensors[0]._schema._names)
+                tensors[i] = tensors[i]._force_order(tensors[0].dims)
         to_stack = [tensor.values for tensor in tensors]
-        old_names = list(old_names)
+        old_names = list(tensors[0].dims)
         old_names.insert(0, name)
         return ntorch.tensor(torch.stack(to_stack, dim=0), old_names)
 
     @staticmethod
     def cat(tensors, dim, name=None):
         "Concate a list of named tensors along dim."
-        if isinstance(dim, str):
+        if isinstance(dim, str) or isinstance(dim, _Axis):
             dim = [dim] * len(tensors)
         if name is not None:
             tensors = [t.rename(d, name) for t, d in zip(tensors, dim)]
@@ -189,7 +190,7 @@ def gather(input, dim, index, index_dim):
 
     @staticmethod
    def masked_select(input, mask, name="on"):
-        order = mask._mask_broadcast_order(input._schema._names)
+        order = mask._mask_broadcast_order(input)
         a1 = input._force_order(order)
         b1 = mask._force_order(order)
         return NamedTensor(a1.values.masked_select(b1.values), name)
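
A hedged sketch of a contraction after the schema-merge change to dot above; it assumes dot is reachable through the ntorch namespace used throughout this module, and the names and shapes are illustrative:

    import torch
    from namedtensor import ntorch, NamedTensor

    a = NamedTensor(torch.randn(3, 4), ("batch", "hidden"))
    b = NamedTensor(torch.randn(4, 5), ("hidden", "classes"))
    out = ntorch.dot("hidden", a, b)  # contracts "hidden"; remaining dims ("batch", "classes")
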
diff --git a/namedtensor/torch_helpers.py b/namedtensor/torch_helpers.py
index a868c30..3a87c76 100644
--- a/namedtensor/torch_helpers.py
+++ b/namedtensor/torch_helpers.py
@@ -2,19 +2,72 @@
 import torch
 from .core import NamedTensorBase, assert_match
 from .utils import make_tuple
+import numbers
 
 
 class NamedTensor(NamedTensorBase):
+
+    access_order = []
 
     def __getitem__(self, index):
-        if isinstance(index, dict):
+        """
+        Valid formats are:
+        - a dict of pairs name -> slice/ntensor/index
+        - an ntensor (with matching names)
+        - a sequence of strings
+        - a sequence of slices/ntensors/indexes
+        """
+        if isinstance(index, tuple):
+            cur = self
+            ops_dict = {}
+            count = 0
+            restart = True
+            for x in index:
+                if isinstance(x, str):
+                    if restart:
+                        if ops_dict:
+                            cur = cur[ops_dict]
+                        self.access_order = []
+                        ops_dict = {}
+                        count = 0
+                        restart = False
+                    self.access_order.append(x)
+                elif isinstance(x, slice) and isinstance(x.start, str):
+                    if x.step is None:
+                        ops_dict[x.start] = x.stop
+                    else:
+                        ops_dict[x.start] = slice(x.stop, x.step)
+                    restart = True
+                else:
+                    ops_dict[self.access_order[count]] = x
+                    count += 1
+                    restart = True
+            if ops_dict:
+                cur = cur[ops_dict]
+            return cur
+
+        elif isinstance(index, str):
+            self.access_order = [index]
+            return self
+        elif isinstance(index, slice) and isinstance(index.start, str):
+            if index.step is None:
+                return self[{index.start: index.stop}]
+            return self[{index.start: slice(index.stop, index.step)}]
+        elif isinstance(index, (slice, NamedTensor, numbers.Number)):
+            if len(self.access_order) != 1:
+                raise RuntimeError("Dimension name not set. Use v['name'].")
+            return self[dict([(self.access_order[0], index)])]
+
+        elif isinstance(index, dict):
             cur = self
             for k, v in index.items():
                 if isinstance(v, slice):
+                    cur = cur.narrow(k, v.start, v.stop - v.start)
                 elif isinstance(v, NamedTensor):
                     cur = cur.index_select(k, v)
-                else:
+                elif isinstance(v, numbers.Number):
                     cur = cur.get(k, v)
+                else:
+                    raise RuntimeError(
+                        "Index must be number, slice, "
+                        + "or NamedTensor, got {}".format(v)
+                    )
             return cur
         elif isinstance(index, NamedTensor):
             if (
@@ -69,7 +122,7 @@ def copy_(self, other):
         return self._setter(other, "copy_")
 
     def _setter(self, other, method, vals=[]):
-        order = other._mask_broadcast_order(self._schema._names)
+        order = other._mask_broadcast_order(self)
         other = other._force_order(order)
 
         args = [other.values] + vals
@@ -290,7 +343,7 @@ def call(*args, **kwargs):
         def call(other, *args):
             if isinstance(other, NamedTensor):
                 b = other
-                order = self._broadcast_order(b._schema._names)
+                order = self._broadcast_order(b)
                 a1 = self._force_order(order)
                 b1 = b._force_order(order)
                 method = getattr(a1._tensor, methodname)
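
A short sketch of the indexing forms the new __getitem__ accepts (shapes and names are made up; the dict form maps each dimension name to an index, slice, or NamedTensor as listed in the docstring above):

    import torch
    from namedtensor import NamedTensor

    x = NamedTensor(torch.randn(4, 5), ("batch", "hidden"))
    x[{"batch": 0}]                        # dict form: index 0 along "batch"
    x[{"hidden": slice(1, 3)}]             # dict form: narrow "hidden" to [1, 3)
    x["batch", 0, "hidden", slice(1, 3)]   # sequence form: names followed by their indexes
    x["batch":2]                           # name:value slice form, same as x[{"batch": 2}]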