From 10df75a7b3392946b323dbb9ed058748fb6e6987 Mon Sep 17 00:00:00 2001
From: Alex Rush <sasha.rush@gmail.com>
Date: Mon, 27 Jan 2020 22:24:24 -0500
Subject: [PATCH] .

---
 torch_struct/__init__.py        |  2 ++
 torch_struct/factorial_hmm.py   | 56 +++++++++++++++++++++++++++++++++
 torch_struct/test_algorithms.py | 11 +++++++
 3 files changed, 69 insertions(+)
 create mode 100644 torch_struct/factorial_hmm.py

diff --git a/torch_struct/__init__.py b/torch_struct/__init__.py
index 2e58fbbc..d7c43b48 100644
--- a/torch_struct/__init__.py
+++ b/torch_struct/__init__.py
@@ -14,6 +14,7 @@
 from .cky_crf import CKY_CRF
 from .deptree import DepTree
 from .linearchain import LinearChain
+from .factorial_hmm import FactorialHMM
 from .semimarkov import SemiMarkov
 from .alignment import Alignment
 from .rl import SelfCritical
@@ -43,6 +44,7 @@
     CKY_CRF,
     DepTree,
     LinearChain,
+    FactorialHMM,
     SemiMarkov,
     LogSemiring,
     StdSemiring,
diff --git a/torch_struct/factorial_hmm.py b/torch_struct/factorial_hmm.py
new file mode 100644
index 00000000..46ecab62
--- /dev/null
+++ b/torch_struct/factorial_hmm.py
@@ -0,0 +1,56 @@
+import torch
+from .helpers import _Struct, Chart
+import math
+
+
+class FactorialHMM(_Struct):
+    def _dp(self, scores, lengths=None, force_grad=False):
+        transition, emission = scores
+        semiring = self.semiring
+        transition.requires_grad_(True)
+        emission.requires_grad_(True)
+        batch, L, K, K2 = transition.shape
+        batch, N, K, K, K = emission.shape
+        assert L == 3
+        assert K == K2
+
+        transition = semiring.convert(transition)
+        emission = semiring.convert(emission)
+
+
+        ssize = semiring.size()
+
+        state_out = Chart((batch, N, L, K), transition, semiring)
+        state_in = Chart((batch, N, L, K), transition, semiring)
+        emit = Chart((batch, N, K, K, K), transition, semiring)
+
+        emit[0, :] = emission[:, :, 0]
+
+        def make_out(val, i):
+            state_out[i, 0] = semiring.sum(semiring.sum(val, 4), 2)
+            state_out[i, 1] = semiring.sum(semiring.sum(val, 4), 3)
+            state_out[i, 2] = semiring.sum(semiring.sum(val, 3), 2)
+
+        make_out(emit[0, :], 0)
+
+        for i in range(1, N):
+            # print(transition.shape, state_out[i-1, :].unsqueeze(-2).shape)
+            state_in = semiring.dot(state_out[i-1, :].unsqueeze(-2), transition)
+            # print(state_in[..., None, :, None].shape,  emission[:, :, i].shape)
+            emit[i, :] = semiring.times(state_in[..., 0, :, None, None],
+                                        state_in[..., 1, None, :, None],
+                                        state_in[..., 2, None, None, :],
+                                        emission[:, :, i])
+            make_out(emit[i, :], i)
+
+        log_Z = semiring.sum(emit[N-1, :])
+        return log_Z, [scores], None
+
+    @staticmethod
+    def _rand():
+        batch = torch.randint(2, 5, (1,))
+        K = torch.randint(2, 5, (1,))
+        N = torch.randint(2, 5, (1,))
+        transition = torch.rand(batch, 3, K, K)
+        emission = torch.rand(batch, N, K, K, K)
+        return (transition, emission), (batch.item(), N.item())
diff --git a/torch_struct/test_algorithms.py b/torch_struct/test_algorithms.py
index 980ae3c2..e88fc8a1 100644
--- a/torch_struct/test_algorithms.py
+++ b/torch_struct/test_algorithms.py
@@ -2,6 +2,7 @@
 from .cky_crf import CKY_CRF
 from .deptree import DepTree, deptree_nonproj, deptree_part
 from .linearchain import LinearChain
+from .factorial_hmm import FactorialHMM
 from .semimarkov import SemiMarkov
 from .alignment import Alignment
 from .semirings import (
@@ -325,6 +326,15 @@ def test_params(data, seed):
         c = vals.grad.detach()
         assert torch.isclose(b, c).all()
 
+def test_factorial_hmm():
+    model = FactorialHMM
+    semiring = StdSemiring
+    struct = model(semiring)
+    vals, (batch, N) = model._rand()
+    alpha = struct.sum(vals)
+    print(alpha)
+    assert False
+
 
 @given(data())
 @settings(max_examples=50, deadline=None)
@@ -416,6 +426,7 @@ def test_hmm():
     LinearChain().sum(out)
 
 
+
 @given(data())
 def test_sparse_max(data):
     model = data.draw(sampled_from([LinearChain]))