Derived functions #11

Open
wants to merge 6 commits into master
Changes from 1 commit
Add functions like tensordot and matmul derived from einsum
davidweichiang committed Jan 4, 2022

commit 08b8096d3878ce3c3d9ad0c04f67d64e9d1698b2
5 changes: 5 additions & 0 deletions docs/index.rst
@@ -88,6 +88,11 @@ scratch every time einsum is called.

In addition to `einsum`, the module also exposes a differentiable `log_einsum` and a non-differentiable `log_viterbi_einsum`.

Derived Functions
-----------------

For convenience, the module also implements the functions `tensordot`, `matmul`, `inner`, `dot`, `mm`, `bmm`, `mv`, and `outer` in terms of einsum. All of these functions take a `block_size` argument and an `einsum` argument; the latter defaults to `torch_semiring_einsum.einsum` but can be set to another einsum variant, such as `log_einsum`.
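For example (an illustrative sketch, not an exhaustive reference)::

    import torch
    import torch_semiring_einsum as semiring

    a = torch.randn(2, 3, 4)
    b = torch.randn(4, 5)

    # same result as torch.tensordot(a, b, 1), shape (2, 3, 5)
    c = semiring.tensordot(a, b, 1, block_size=10)

    # the same contraction in the log semiring
    c_log = semiring.tensordot(a, b, 1, block_size=10, einsum=semiring.log_einsum)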

API Documentation
-----------------

79 changes: 79 additions & 0 deletions tests/test_derived.py
@@ -0,0 +1,79 @@
import unittest
import torch
import torch_semiring_einsum as semiring
import numpy

class TestDerived(unittest.TestCase):
def setUp(self):
self.device = torch.device('cpu')
self.generator = torch.manual_seed(123)

def test_tensordot(self):
A, B, C, D, E, F = 2, 3, 5, 7, 11, 13
for i, (x_size, y_size, inner_dims) in enumerate([
((A, B, C, D), (C, D, E, F), 2),
((A, B, C, D), (D, E, F), 1),
((A, B, C, D), (E, F), 0),
]):
with self.subTest(i):
x = torch.empty(x_size, device=self.device)
x.uniform_(-10., 10., generator=self.generator)
y = torch.empty(y_size, device=self.device)
y.uniform_(-10., 10., generator=self.generator)

semiring_out = semiring.tensordot(x, y, inner_dims, block_size=1)
torch_out = torch.tensordot(x, y, inner_dims)
numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-3)

semiring_out = semiring.tensordot(x, y, inner_dims, block_size=1, einsum=semiring.log_einsum)
torch_out = torch.tensordot(x.exp(), y.exp(), inner_dims).log()
numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-2)

def test_matmul(self):
J, K, M, N, P = 2, 3, 5, 7, 11

for i, (x_size, y_size) in enumerate([
# ((J, 1, N, M), (K, M, P)), # from torch.matmul docs
((J, K, N, M), (K, M, P)),
((J, K, N, M), (M,)),
            ((M,), (K, M, P)),
]):
with self.subTest(i):
x = torch.empty(x_size)
x.uniform_(-10., 10., generator=self.generator)
y = torch.empty(y_size)
y.uniform_(-10., 10., generator=self.generator)

semiring_out = semiring.matmul(x, y, block_size=1)
torch_out = torch.matmul(x, y)
numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-3)

semiring_out = semiring.matmul(x, y, block_size=1, einsum=semiring.log_einsum)
torch_out = torch.matmul(x.exp(), y.exp()).log()
numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-2)

def test_inner(self):
A, B, C, D, E = 2, 3, 5, 7, 11

for i, (x_size, y_size) in enumerate([
((A, B), (C, D, B)),
((A, B), ()),
((), (A, B)),
]):
with self.subTest(i):
x = torch.empty(x_size)
x.uniform_(-10., 10., generator=self.generator)
y = torch.empty(y_size)
y.uniform_(-10., 10., generator=self.generator)

semiring_out = semiring.inner(x, y, block_size=1)
torch_out = torch.inner(x, y)
numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-3)

semiring_out = semiring.inner(x, y, block_size=1, einsum=semiring.log_einsum)
torch_out = torch.inner(x.exp(), y.exp()).log()
numpy.testing.assert_allclose(semiring_out, torch_out, rtol=1e-2)


if __name__ == '__main__':
unittest.main()
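These tests rely on a log-semiring identity: einsum with (+, logsumexp) in place of (×, +) equals the log of the ordinary einsum applied to exponentiated inputs. A minimal sketch of that identity for a plain matrix product (standard torch only, independent of this package):

import torch

x = torch.randn(2, 3)
y = torch.randn(3, 4)

# log(sum_j exp(x_ij) * exp(y_jk)) == logsumexp_j(x_ij + y_jk)
lhs = torch.matmul(x.exp(), y.exp()).log()
rhs = torch.logsumexp(x.unsqueeze(-1) + y.unsqueeze(0), dim=1)
print(torch.allclose(lhs, rhs, atol=1e-5))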
1 change: 1 addition & 0 deletions torch_semiring_einsum/__init__.py
@@ -6,6 +6,7 @@
from .log_forward import log_einsum_forward
from .log_backward import log_einsum_backward
from .log_viterbi_forward import log_viterbi_einsum_forward
from .derived import *

einsum = combine(real_einsum_forward, real_einsum_backward)
r"""Differentiable version of ordinary (real) einsum.
94 changes: 94 additions & 0 deletions torch_semiring_einsum/derived.py
@@ -0,0 +1,94 @@
__all__ = ['tensordot', 'matmul', 'inner', 'dot', 'mm', 'bmm', 'mv', 'outer']

from .real_forward import real_einsum_forward
from .real_backward import real_einsum_backward
from .equation import compile_equation
from .function import combine

default_einsum = combine(real_einsum_forward, real_einsum_backward)

def index_range(start, stop):
    """Return consecutive index letters, e.g. index_range(0, 4) == 'abcd'."""
    return ''.join(chr(ord('a') + i) for i in range(start, stop))

def tensordot(a, b, ndim, *, block_size, einsum=default_einsum):
    """Like torch.tensordot, for an integer number of contracted dimensions."""
    if isinstance(ndim, (tuple, list)):
        raise NotImplementedError('explicit dimension lists are not supported; pass an int')
    # Contract the last `ndim` dimensions of a with the first `ndim` dimensions
    # of b; e.g. a.ndim=4, b.ndim=4, ndim=2 gives 'abcd,cdef->abef'.
    e = (index_range(0, a.ndim) +
         ',' +
         index_range(a.ndim-ndim, a.ndim+b.ndim-ndim) +
         '->' +
         index_range(0, a.ndim-ndim) +
         index_range(a.ndim, a.ndim+b.ndim-ndim))
    e = compile_equation(e)
    return einsum(e, a, b, block_size=block_size)

def matmul(a, b, *, block_size, einsum=default_einsum):
    """Like torch.matmul (size-1 broadcasting of batch dimensions is not supported)."""
    if a.ndim == 0 or b.ndim == 0:
        raise ValueError('matmul of 0-dimensional tensors is not allowed')

    ndim = max(a.ndim, b.ndim)

    # 'a' and 'c' are the row/column indices, 'b' is the contracted index,
    # and batch indices start at 'd'.
    oi = index_range(3, ndim+1)
    if a.ndim == 1:
        ai = 'b'  # a 1-d first argument contracts away entirely
    else:
        ai = index_range(ndim+1-(a.ndim-2), ndim+1) + 'ab'
        oi += 'a'

    if b.ndim == 1:
        bi = 'b'  # a 1-d second argument contracts away entirely
    else:
        bi = index_range(ndim+1-(b.ndim-2), ndim+1) + 'bc'
        oi += 'c'

    e = compile_equation(ai + ',' + bi + '->' + oi)
    return einsum(e, a, b, block_size=block_size)

def inner(a, b, *, block_size, einsum=default_einsum):
    """Like torch.inner: contracts the last dimension of a with the last dimension of b."""
    if a.ndim == 0:
        # scalar times tensor
        e = ',' + index_range(0, b.ndim) + '->' + index_range(0, b.ndim)
    elif b.ndim == 0:
        # tensor times scalar
        e = index_range(0, a.ndim) + ',->' + index_range(0, a.ndim)
    else:
        # 'a' is the shared last index of both arguments
        ai = index_range(1, a.ndim)
        bi = index_range(a.ndim+1, a.ndim+b.ndim)
        e = ai + 'a,' + bi + 'a->' + ai + bi
    e = compile_equation(e)
    return einsum(e, a, b, block_size=block_size)

dot_equation = compile_equation('i,i->')
def dot(a, b, *, block_size, einsum=default_einsum):
    if a.ndim != 1 or b.ndim != 1:
        raise ValueError('arguments must be 1-dimensional')
    return einsum(dot_equation, a, b, block_size=block_size)

mm_equation = compile_equation('ij,jk->ik')
def mm(a, b, *, block_size, einsum=default_einsum):
if a.ndim != 2 or b.ndim != 2:
raise ValueError('arguments must be 2-dimensional')
return einsum(mm_equation, a, b, block_size=block_size)

mv_equation = compile_equation('ij,j->i')
def mv(a, b, *, block_size, einsum=default_einsum):
    if a.ndim != 2 or b.ndim != 1:
        raise ValueError('arguments must be 2-dimensional and 1-dimensional, respectively')
    return einsum(mv_equation, a, b, block_size=block_size)

bmm_equation = compile_equation('bij,bjk->bik')
def bmm(a, b, *, block_size, einsum=default_einsum):
if a.ndim != 3 or b.ndim != 3:
raise ValueError('arguments must be 3-dimensional')
return einsum(bmm_equation, a, b, block_size=block_size)

outer_equation = compile_equation('i,j->ij')
def outer(a, b, *, block_size, einsum=default_einsum):
if a.ndim != 1 or b.ndim != 1:
raise ValueError('arguments must be 1-dimensional')
return einsum(outer_equation, a, b, block_size=block_size)
ger = outer  # torch.ger is an older alias of torch.outer
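A quick usage sketch (illustrative only, mirroring the tests above; assumes the package is importable as `torch_semiring_einsum`):

import torch
import torch_semiring_einsum as semiring

x = torch.randn(3, 4)
v = torch.randn(4)

# real semiring: agrees with torch.mv
print(torch.allclose(semiring.mv(x, v, block_size=1), torch.mv(x, v), atol=1e-5))

# log semiring: the same contraction with multiply/add replaced by add/logsumexp
out = semiring.mv(x, v, block_size=1, einsum=semiring.log_einsum)
print(torch.allclose(out, torch.mv(x.exp(), v.exp()).log(), atol=1e-5))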