Add alignment CRF test. Fix missing fill_() #109

Open · wants to merge 14 commits into master
24 changes: 24 additions & 0 deletions tests/test_alignment_crf.py
@@ -0,0 +1,24 @@
import torch
import torch_struct
import pytest


@pytest.mark.skipif(not torch.cuda.is_available(), reason='needs CUDA')
def test_alignment_crf_shapes():
    batch, N, M = 2, 4, 5
    log_potentials = torch.rand(batch, N, M, 3).cuda()

    dist = torch_struct.AlignmentCRF(log_potentials)
    assert (batch, N, M, 3) == dist.argmax.shape
    assert (batch, N, M, 3) == dist.marginals.shape
    assert (batch,) == dist.partition.shape

    # Fails due to AttributeError: 'BandedMatrix' object has no attribute
    # 'unsqueeze'
    assert (batch,) == dist.entropy.shape
    # assert (9, batch, N, M, 3) == dist.sample([9]).shape

    # Fails due to: RuntimeError: Expected condition, x and y to be on
    # the same device, but condition is on cpu and x and y are on
    # cuda:0 and cuda:0 respectively
    # assert (8, batch,) == dist.topk(8).shape
14 changes: 14 additions & 0 deletions tests/test_matmul.py
@@ -0,0 +1,14 @@
import torch
from hypothesis import given
from hypothesis.strategies import integers
import genbmm

bint = integers(min_value=1, max_value=4)
mint = integers(min_value=6, max_value=8)
nint = integers(min_value=3, max_value=5)
kint = integers(min_value=9, max_value=11)


@given(bint, mint, nint, kint)
def test_matmul(batch, m, n, k):
    a, b = torch.rand((m, n)), torch.rand((n, k))
3 changes: 2 additions & 1 deletion tests/test_semirings.py
@@ -1,5 +1,5 @@
import torch
from hypothesis import given
from hypothesis import given, settings
from hypothesis.strategies import integers


@@ -17,6 +17,7 @@


@given(lint, lint, lint)
@settings(deadline=None) # Avoid spurious warnings when first run
def test_max(a, b, c):
    torch.manual_seed(0)
    t1 = torch.rand(a, 1, c).requires_grad_(True)
11 changes: 7 additions & 4 deletions torch_struct/alignment.py
@@ -1,11 +1,14 @@
import torch
from .helpers import _Struct
import math
import warnings

try:
    import genbmm

except ImportError:
    pass
    warnings.warn('Could not import genbmm. '
                  'However, genbmm is only used for CUDA operations.')

from .semirings import LogSemiring
from .semirings.fast_semirings import broadcast
@@ -97,9 +100,9 @@ def _dp_scan(self, log_potentials, lengths=None, force_grad=False):
            # Create finalizing paths.
            point = (l + M) // 2

            charta[1][:, b, point:, 1, ind, :, :, Mid] = semiring.one_(
                charta[1][:, b, point:, 1, ind, :, :, Mid]
            )
            init = torch.zeros(charta[1].shape, device=charta[1].device).bool()
            init[:, b, point:, 1, ind, :, :, Mid].fill_(True)
            charta[1] = semiring.fill(charta[1], init, semiring.one)

        for b in range(lengths.shape[0]):
            point = (lengths[b] + M) // 2
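The alignment.py change above swaps the in-place `semiring.one_` write for an explicit boolean mask passed to `semiring.fill` (the missing `fill_()` from the PR title). A minimal sketch of that mask-and-fill pattern on a standalone tensor, assuming `semiring.fill` behaves like a masked overwrite; `chart`, `one`, and the slice are illustrative stand-ins, not the real `charta` layout:

import torch

# Illustrative stand-ins: a small chart and a log-semiring identity value.
chart = torch.randn(2, 4, 5)
one = torch.tensor(0.0)

# Old pattern: write the identity into a slice in place.
chart[:, 2:, 1] = one

# New pattern: build a boolean mask and overwrite through it,
# roughly what semiring.fill(charta[1], init, semiring.one) is assumed to do.
init = torch.zeros(chart.shape, device=chart.device, dtype=torch.bool)
init[:, 2:, 1].fill_(True)
chart = torch.where(init, one.to(chart.device), chart)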
3 changes: 3 additions & 0 deletions torch_struct/semirings/sample.py
@@ -167,6 +167,9 @@ def forward(ctx, input, dim):
    def backward(ctx, grad_output):

        logits, part, dim = ctx.saved_tensors
        # Replace infinite logits with max float, otherwise softmax gives NaNs
        # Perhaps this could be done earlier (during forward pass)?
        logits[logits == float('inf')] = torch.finfo(logits.dtype).max
        grad_input = None
        if ctx.needs_input_grad[0]:

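The clamp added above guards against +inf entries before the softmax in backward: with an inf logit the whole softmax row turns into NaN, while clamping to the dtype's maximum keeps it finite. A small standalone illustration (not part of the diff):

import torch

logits = torch.tensor([0.0, float('inf'), 1.0])
print(torch.softmax(logits, dim=0))  # contains NaNs: inf - max(inf) is nan inside softmax

clamped = logits.clone()
clamped[clamped == float('inf')] = torch.finfo(clamped.dtype).max
print(torch.softmax(clamped, dim=0))  # finite, essentially one-hot at the clamped entry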
3 changes: 1 addition & 2 deletions torch_struct/semirings/semirings.py
@@ -140,7 +140,6 @@ class LogSemiring(_BaseLog):

    Gradients give marginals.
    """

    @classmethod
    def matmul(cls, a, b):
        if has_genbmm and isinstance(a, genbmm.BandedMatrix):
@@ -192,7 +191,7 @@ def convert(cls, orig_potentials):
            dtype=orig_potentials.dtype,
            device=orig_potentials.device,
        )
        potentials = cls.fill(potentials, torch.tensor(True), cls.zero)
        potentials = cls.fill(
            potentials,
            torch.tensor(True, device=potentials.device),
            cls.zero.to(potentials.device),
        )
        potentials[0] = orig_potentials
        return potentials

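The device fix in convert above matches the error quoted in tests/test_alignment_crf.py: torch.where (which cls.fill presumably uses under the hood) wants the condition and both value tensors on the same device. A minimal reproduction of that constraint, using throwaway tensors:

import torch

if torch.cuda.is_available():
    x = torch.rand(3).cuda()
    y = torch.zeros(3).cuda()

    cond = torch.tensor(True)        # lives on the CPU
    # torch.where(cond, x, y)        # RuntimeError: Expected condition, x and y to be on the same device

    cond = torch.tensor(True, device=x.device)
    print(torch.where(cond, x, y))   # fine once everything is on cuda:0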