From 10df75a7b3392946b323dbb9ed058748fb6e6987 Mon Sep 17 00:00:00 2001 From: Alex Rush <sasha.rush@gmail.com> Date: Mon, 27 Jan 2020 22:24:24 -0500 Subject: [PATCH] . --- torch_struct/__init__.py | 2 ++ torch_struct/factorial_hmm.py | 56 +++++++++++++++++++++++++++++++++ torch_struct/test_algorithms.py | 11 +++++++ 3 files changed, 69 insertions(+) create mode 100644 torch_struct/factorial_hmm.py diff --git a/torch_struct/__init__.py b/torch_struct/__init__.py index 2e58fbbc..d7c43b48 100644 --- a/torch_struct/__init__.py +++ b/torch_struct/__init__.py @@ -14,6 +14,7 @@ from .cky_crf import CKY_CRF from .deptree import DepTree from .linearchain import LinearChain +from .factorial_hmm import FactorialHMM from .semimarkov import SemiMarkov from .alignment import Alignment from .rl import SelfCritical @@ -43,6 +44,7 @@ CKY_CRF, DepTree, LinearChain, + FactorialHMM, SemiMarkov, LogSemiring, StdSemiring, diff --git a/torch_struct/factorial_hmm.py b/torch_struct/factorial_hmm.py new file mode 100644 index 00000000..46ecab62 --- /dev/null +++ b/torch_struct/factorial_hmm.py @@ -0,0 +1,56 @@ +import torch +from .helpers import _Struct, Chart +import math + + +class FactorialHMM(_Struct): + def _dp(self, scores, lengths=None, force_grad=False): + transition, emission = scores + semiring = self.semiring + transition.requires_grad_(True) + emission.requires_grad_(True) + batch, L, K, K2 = transition.shape + batch, N, K, K, K = emission.shape + assert L == 3 + assert K == K2 + + transition = semiring.convert(transition) + emission = semiring.convert(emission) + + + ssize = semiring.size() + + state_out = Chart((batch, N, L, K), transition, semiring) + state_in = Chart((batch, N, L, K), transition, semiring) + emit = Chart((batch, N, K, K, K), transition, semiring) + + emit[0, :] = emission[:, :, 0] + + def make_out(val, i): + state_out[i, 0] = semiring.sum(semiring.sum(val, 4), 2) + state_out[i, 1] = semiring.sum(semiring.sum(val, 4), 3) + state_out[i, 2] = semiring.sum(semiring.sum(val, 3), 2) + + make_out(emit[0, :], 0) + + for i in range(1, N): + # print(transition.shape, state_out[i-1, :].unsqueeze(-2).shape) + state_in = semiring.dot(state_out[i-1, :].unsqueeze(-2), transition) + # print(state_in[..., None, :, None].shape, emission[:, :, i].shape) + emit[i, :] = semiring.times(state_in[..., 0, :, None, None], + state_in[..., 1, None, :, None], + state_in[..., 2, None, None, :], + emission[:, :, i]) + make_out(emit[i, :], i) + + log_Z = semiring.sum(emit[N-1, :]) + return log_Z, [scores], None + + @staticmethod + def _rand(): + batch = torch.randint(2, 5, (1,)) + K = torch.randint(2, 5, (1,)) + N = torch.randint(2, 5, (1,)) + transition = torch.rand(batch, 3, K, K) + emission = torch.rand(batch, N, K, K, K) + return (transition, emission), (batch.item(), N.item()) diff --git a/torch_struct/test_algorithms.py b/torch_struct/test_algorithms.py index 980ae3c2..e88fc8a1 100644 --- a/torch_struct/test_algorithms.py +++ b/torch_struct/test_algorithms.py @@ -2,6 +2,7 @@ from .cky_crf import CKY_CRF from .deptree import DepTree, deptree_nonproj, deptree_part from .linearchain import LinearChain +from .factorial_hmm import FactorialHMM from .semimarkov import SemiMarkov from .alignment import Alignment from .semirings import ( @@ -325,6 +326,15 @@ def test_params(data, seed): c = vals.grad.detach() assert torch.isclose(b, c).all() +def test_factorial_hmm(): + model = FactorialHMM + semiring = StdSemiring + struct = model(semiring) + vals, (batch, N) = model._rand() + alpha = struct.sum(vals) + print(alpha) + assert False + @given(data()) @settings(max_examples=50, deadline=None) @@ -416,6 +426,7 @@ def test_hmm(): LinearChain().sum(out) + @given(data()) def test_sparse_max(data): model = data.draw(sampled_from([LinearChain]))