# limited_nondeterministic_stack.py
import attr
import torch
from torch_semiring_einsum import compile_equation
from torch_rnn_tools import UnidirectionalRNN
from lib.pytorch_tools.set_slice import set_slice
from lib.semiring import log, log_viterbi
from lib.data_structures.linked_list import LinkedList
from .nondeterministic_stack import (
    NondeterministicStackRNN,
    TooManyUpdates,
    UpdateResult,
    next_alpha_column,
    next_gamma_column,
    ensure_not_negative,
    Operation,
    PushOperation,
    ReplaceOperation,
    PopOperation
)
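
# This module implements a window-limited version of the nondeterministic
# stack RNN: the inner weights gamma[i -> j] are only computed for
# j - D <= i <= j - 1, where D = window_size, so each timestep looks back at
# most D steps. This bounds the amount of state that has to be carried across
# chunks of a long sequence: only a (D-1) x (D-1) block of gamma and the last
# D entries of alpha need to be forwarded (see InitialStackState and
# get_forwarded_stack_state below).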
class LimitedNondeterministicStackRNN(NondeterministicStackRNN):

    def __init__(self, input_size, num_states, stack_alphabet_size,
            window_size, controller, normalize_operations=False,
            include_states_in_reading=True, **kwargs):
        super().__init__(
            input_size=input_size,
            num_states=num_states,
            stack_alphabet_size=stack_alphabet_size,
            controller=controller,
            normalize_operations=normalize_operations,
            include_states_in_reading=include_states_in_reading,
            **kwargs
        )
        self.window_size = window_size

    # Use a separate type for the initial state used with `initial_state()`,
    # `initial_state=` and `return_state=` that contains only the chunks of
    # gamma and alpha that need to be passed between batches.
    @attr.s
    class InitialState(UnidirectionalRNN.State):

        state = attr.ib()
        stack = attr.ib()

        def detach(self):
            return type(self)(self.state.detach(), self.stack.detach())

        def batch_size(self):
            return self.state.batch_size()

        def slice_batch(self, s):
            return type(self)(self.state.slice_batch(s), self.stack.slice_batch(s))

    class InitialStackState(UnidirectionalRNN.State):

        def __init__(self, gamma, alpha, semiring):
            super().__init__()
            self.gamma = gamma
            self.alpha = alpha
            self.semiring = semiring

        def transform_tensors(self, func):
            # This implements detach() and slice_batch().
            return type(self)(
                self.semiring.on_tensor(self.gamma, func),
                self.semiring.on_tensor(self.alpha, func),
                self.semiring
            )

        def batch_size(self):
            return self.gamma.size(0)
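
    # forward() accepts and returns InitialState objects rather than the
    # internal State type. It converts the incoming InitialState into a State,
    # skips NondeterministicStackRNN.forward() by calling the next
    # implementation in the MRO (note the explicit
    # super(NondeterministicStackRNN, self)), and, when return_state=True,
    # converts the final State back into an InitialState via
    # get_forwarded_stack_state().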
    def forward(self, input_sequence, block_size, initial_state=None,
            return_state=False, include_first=True, return_actions=False,
            return_readings=False):
        semiring = log
        batch_size = input_sequence.size(0)
        sequence_length = input_sequence.size(1)
        if initial_state is None:
            initial_state = self.initial_state(batch_size, semiring)
        else:
            if not isinstance(initial_state, self.InitialState):
                raise TypeError
        # Convert initial_state from the InitialState type to State.
        stack_kwargs = dict(
            sequence_length=sequence_length,
            semiring=semiring,
            block_size=block_size
        )
        if initial_state.stack is None:
            stack = None
            stack_args = []
        else:
            stack = self.get_new_stack(
                initial_stack_state=initial_state.stack,
                batch_size=batch_size,
                **stack_kwargs
            )
            stack_args = None
            stack_kwargs = None
        state = self.State(
            rnn=self,
            hidden_state=initial_state.state,
            previous_stack=stack,
            return_actions=return_actions,
            # It might be helpful to grab this somehow from initial_state,
            # but if include_first is False, this won't be used anyway.
            previous_actions=None,
            return_readings=return_readings,
            previous_reading=None,
            stack_args=stack_args,
            stack_kwargs=stack_kwargs
        )
        result = super(NondeterministicStackRNN, self).forward(
            input_sequence,
            initial_state=state,
            return_state=return_state,
            include_first=include_first
        )
        if return_state:
            output, last_state = result
            stack = last_state.previous_stack
            last_state = self.InitialState(
                last_state.hidden_state,
                self.get_forwarded_stack_state(stack, semiring)
            )
            result = output, last_state
        return result
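
    # The forwarded stack state keeps only the trailing (D-1) x (D-1) block of
    # gamma and the last D entries of alpha, which is exactly what the next
    # chunk needs and matches the size checks performed in
    # LimitedNondeterministicStack.__init__ when an initial state is given
    # (gamma: B x D-1 x D-1 x Q x S x Q x S, alpha: B x D x Q x S).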
    def get_forwarded_stack_state(self, stack, semiring):
        # Convert the last state from the State type to the InitialState type.
        # Save some slices of gamma and alpha to forward to the next batch.
        D = self.window_size
        return self.InitialStackState(
            gamma=semiring.on_tensor(stack.gamma, lambda x: x[:, -(D-1):, -(D-1):]),
            alpha=semiring.on_tensor(stack.alpha, lambda x: x[:, -D:]),
            semiring=semiring
        )

    def initial_state(self, batch_size, semiring=log):
        return self.InitialState(
            self.controller.initial_state(batch_size),
            None
        )

    def get_new_viterbi_stack(self, batch_size, sequence_length, semiring,
            block_size, initial_stack_state=None):
        tensor = next(self.parameters())
        return LimitedNondeterministicStack(
            batch_size=batch_size,
            num_states=self.num_states,
            stack_alphabet_size=self.stack_alphabet_size,
            sequence_length=sequence_length,
            window_size=self.window_size,
            include_states_in_reading=self.include_states_in_reading,
            initial_state=initial_stack_state,
            semiring=semiring,
            block_size=block_size,
            dtype=tensor.dtype,
            device=tensor.device
        )

    def get_viterbi_decoder(self, alpha_columns, gamma_j_nodes, alpha_j_nodes):
        return LimitedViterbiDecoder(
            alpha_columns,
            gamma_j_nodes,
            alpha_j_nodes,
            self.window_size
        )
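
# Tensor layout used by LimitedNondeterministicStack (shapes use B = batch
# size, T = sequence_length, D = window_size, Q = num_states,
# S = stack_alphabet_size):
#   alpha : B x (D + T) x Q x S, storing alpha[-D .. T-1]
#   gamma : B x (T + D - 1) x (T + D - 1) x Q x S x Q x S, storing
#           gamma[i -> j] for i in [-D, T-2] and j in [-(D-1), T-1]
# The helper functions gamma_i_index, gamma_j_index, and alpha_j_index below
# map these (possibly negative) logical indices to tensor indices.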
class LimitedNondeterministicStack:

    def __init__(self, batch_size, num_states, stack_alphabet_size,
            sequence_length, window_size, include_states_in_reading,
            initial_state, semiring, block_size, dtype, device):
        if not (window_size >= 1):
            raise ValueError('window_size must be at least 1')
        # Let j_1 be the first time step for which we will be computing gamma
        # and alpha, that is, gamma[i, j_1] and alpha[j_1].
        # We will be computing gamma and alpha for j = j_1 through j_1 + T - 1.
        # Since i >= j - D and i <= j - 1, computing gamma[i, j_1] only
        # requires computing gamma for i = j_1 - D through j_1 - 1.
        # Computing gamma[i -> j_1] for i = j_1 - D through j_1 - 1 only
        # requires the values of gamma[i, j] for j_1 - D + 1 <= j <= j_1 - 1.
        # Computing alpha[j_1] only requires alpha[j] for j = j_1 - D through
        # j_1 - 1 and gamma[i -> j_1] for i = j_1 - D through j_1 - 1.
        # This means that the initial stack state consists of:
        # (1) a square of gamma for gamma[j_1 - D + 1:j_1, j_1 - D + 1:j_1]
        # (2) a slice of alpha for alpha[j_1 - D:j_1].
        super().__init__()
        B = self.batch_size = batch_size
        Q = self.num_states = num_states
        S = self.stack_alphabet_size = stack_alphabet_size
        T = self.sequence_length = sequence_length
        D = self.window_size = window_size
        self.include_states_in_reading = include_states_in_reading
        self.semiring = semiring
        self.block_size = block_size
        self.device = device
        # Verify that pieces of alpha and gamma passed in from a previous batch
        # have the correct dimensions.
        if initial_state is not None:
            if semiring.get_tensor(initial_state.gamma).size() != (B, D-1, D-1, Q, S, Q, S):
                raise ValueError
            if semiring.get_tensor(initial_state.alpha).size() != (B, D, Q, S):
                raise ValueError
        # Initialize the current timestep to -1. It is incremented at the
        # beginning of update(), which computes column j of gamma. A subsequent
        # call to reading() should return entry j of alpha.
        # j is relative to the beginning of the current chunk and is reset at
        # the beginning of every chunk. For the first chunk, j = 0 corresponds
        # to timestep 1.
        self.j = -1
        self.alpha = semiring.zeros((B, D+T, Q, S), dtype=dtype, device=device)
        self.gamma = semiring.zeros((B, T+D-1, T+D-1, Q, S, Q, S), dtype=dtype, device=device)
        if initial_state is None:
            # If no initial state was given, automatically run the first
            # timestep, which populates the first entry of alpha. The stack is
            # in an invalid state otherwise.
            self._first_update()
        else:
            # Initialize alpha and gamma with the pieces passed in from the
            # previous batch.
            semiring.get_tensor(self.alpha)[:, :alpha_j_index(D, 0)] = semiring.get_tensor(initial_state.alpha)
            semiring.get_tensor(self.gamma)[:, :gamma_i_index(D, -1), :gamma_j_index(D, 0)] = semiring.get_tensor(initial_state.gamma)

    def _first_update(self):
        if self.j != -1:
            raise ValueError
        j = self.j = self.j + 1
        D = self.window_size
        semiring = self.semiring
        # Note that j is relative to the beginning of the chunk. For the first
        # chunk, j corresponds to timestep j (they are the same).
        semiring.get_tensor(self.gamma)[:, gamma_i_index(D, -1), gamma_j_index(D, 0), 0, 0, 0, 0] = semiring.one
        # alpha[-1] and alpha[0] are both initialized here.
        semiring.get_tensor(self.alpha)[:, (alpha_j_index(D, -1), alpha_j_index(D, 0)), 0, 0] = semiring.one
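
    # Each call to update() advances j by one and fills in column j of gamma
    # (the entries gamma[i -> j] for j - D <= i <= j - 1) and entry j of
    # alpha. Because of the window limit, the recurrence only reads the
    # previous D-1 columns of gamma and the previous D entries of alpha.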
    def update(self, push, repl, pop, return_gamma_prime=False):
        # push : B x Q x S x Q x S
        # repl : B x Q x S x Q x S
        # pop : B x Q x S x Q
        if not (self.j + 1 < self.sequence_length):
            raise TooManyUpdates(
                f'attempting to compute timestep {self.j+1} (0-indexed), but '
                f'only {self.sequence_length} timesteps were allocated with '
                f'sequence_length')
        semiring = self.semiring
        D = self.window_size
        block_size = self.block_size
        # j is initialized to -1. The first call to update() will compute
        # column 0 of gamma and alpha.
        self.j = j = self.j + 1
        # gamma_j : B x D x Q x S x Q x S
        gamma_j, gamma_prime_j = next_gamma_column(
            # B x D-1 x D-1 x Q x S x Q x S
            semiring.on_tensor(self.gamma, lambda x: x[
                :,
                gamma_i_index(D, j-1-(D-1)):gamma_i_index(D, j-1),
                gamma_j_index(D, j-(D-1)):gamma_j_index(D, j)
            ]),
            push,
            repl,
            pop,
            semiring,
            block_size,
            return_gamma_prime,
            gamma_prime_zero_grad=True
        )
        # This is just a long way of updating column j of gamma in-place.
        self.gamma = semiring.combine(
            [self.gamma, gamma_j],
            lambda args: set_slice(
                args[0],
                (
                    slice(None),
                    slice(gamma_i_index(D, j-D), gamma_i_index(D, j)),
                    gamma_j_index(D, j)
                ),
                args[1]))
        # alpha_j : B x Q x S
        alpha_j = next_alpha_column(
            # B x D x Q x S
            semiring.on_tensor(self.alpha, lambda x: x[:, alpha_j_index(D, j-D):alpha_j_index(D, j)]),
            # B x D x Q x S x Q x S
            gamma_j,
            semiring,
            block_size
        )
        # This is just a long way of updating entry j of alpha in-place.
        self.alpha = semiring.combine(
            [self.alpha, alpha_j],
            lambda args: set_slice(
                args[0],
                (slice(None), alpha_j_index(D, j)),
                args[1]))
        return UpdateResult(j, gamma_j, alpha_j, gamma_prime_j)

    def reading(self):
        # Return log P_j(r, y).
        # alpha[0...j] has already been computed.
        semiring = self.semiring
        D = self.window_size
        # alpha_j : B x Q x S
        alpha_j = self.alpha[:, alpha_j_index(D, self.j)]
        if self.include_states_in_reading:
            B = alpha_j.size(0)
            # result : B x (Q * S)
            result = semiring.on_tensor(alpha_j, lambda x: x.view(B, -1))
        else:
            # result : B x S
            result = semiring.sum(alpha_j, dim=1)
        # Using softmax, normalize the log-weights so they sum to 1.
        assert semiring is log
        return torch.nn.functional.softmax(result, dim=1)
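
# The following helpers convert logical indices, which can be negative for
# entries carried over from before the start of the current chunk, into
# non-negative tensor indices into the gamma and alpha tensors of
# LimitedNondeterministicStack.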
def gamma_i_index(D, i):
    # Note that i is relative to the beginning of the current chunk.
    return ensure_not_negative(D + i)

def gamma_j_index(D, j):
    # Note that j is the timestep relative to the beginning of the current
    # chunk. For the first chunk, j = 0 is timestep 1, not 0.
    return ensure_not_negative(D - 1 + j)

def alpha_j_index(D, j):
    return ensure_not_negative(D + j)
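
# For example, with D = 3, the base case written by _first_update() lands at
# gamma[:, gamma_i_index(3, -1), gamma_j_index(3, 0)] = gamma[:, 2, 2] and
# alpha[:, alpha_j_index(3, 0)] = alpha[:, 3], while entries carried over from
# a previous chunk (when initial_state is given) occupy gamma[:, 0:2, 0:2]
# and alpha[:, 0:3].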
class LimitedViterbiDecoder:

    def __init__(self, alpha_columns, gamma_j_nodes, alpha_j_nodes, window_size):
        self.alpha_columns = alpha_columns
        self.gamma_j_nodes = gamma_j_nodes
        self.alpha_j_nodes = alpha_j_nodes
        self.window_size = window_size

    def decode_timestep(self, j):
        """Return the best path leading up to the reading at timestep j.

        Timesteps are 0-indexed, where j = 0 is the first stack reading
        computed from the base case for alpha, and j = n-1 is the last valid
        timestep, corresponding to the stack reading just before the last
        output. The Viterbi path leading up to timestep j is always of length
        j."""
        if not (0 <= j < self.sequence_length):
            raise ValueError(f'timestep ({j}) must be in [0, {self.sequence_length})')
        # Sum over states, then stack symbols.
        alpha_j_sum_scores, alpha_j_sum_node = \
            log_viterbi.sum(log_viterbi.sum(self.get_alpha_j(j), dim=1), dim=1)
        batch_size = alpha_j_sum_scores.size(0)
        paths = [
            self.decode_alpha_j_sum(alpha_j_sum_node, b, j)
            for b in range(batch_size)
        ]
        return paths, alpha_j_sum_scores

    def decode_alpha_j_sum(self, alpha_j_sum_node, b, j):
        y = alpha_j_sum_node.backpointers[b]
        alpha_j_sum_states_node, = alpha_j_sum_node.children
        r = alpha_j_sum_states_node.backpointers[b, y]
        return self.decode_alpha_j(b, j, r, y)

    def decode_alpha_j(self, b, j, r, y):
        if j > 0:
            alpha_j_node = self.get_alpha_j_node(j)
            relative_i, q, x = alpha_j_node.backpointers[b, r, y]
            i = j - self.window_size + relative_i.item()
            # Recurse on alpha[i] and gamma[i, j]
            alpha_path = self.decode_alpha_j(b, i, q, x)
            gamma_path = self.decode_gamma_j(b, i, j, q, x, r, y)
            path = alpha_path
            path.extend(gamma_path)
            return path
        elif j == 0:
            return LinkedList([])
        else:
            raise ValueError(f'logic error: invalid value for j ({j})')
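
    # decode_gamma_j() reconstructs the sequence of stack operations that
    # produced gamma[i -> j] by case analysis on i: i == j - 1 corresponds to
    # a single push, i == j - 2 can only be a replace, and i < j - 2 is
    # resolved by a backpointer that chooses between a replace and a pop.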
    def decode_gamma_j(self, b, i, j, q, x, r, y):
        gamma_j_node = self.get_gamma_j_node(j)
        repl_pop_node, repl_node, push_node = gamma_j_node.children
        if i < j - self.window_size:
            raise ValueError
        elif i < j-2:
            relative_i = self.get_relative_i(i, j)
            is_pop = repl_pop_node.backpointers[b, relative_i, q, x, r, y].item()
            repl_node, pop_node = repl_pop_node.children
            if is_pop:
                return self.decode_pop(pop_node, b, i, j, q, x, r, y)
            else:
                return self.decode_repl(repl_node, b, i, j, q, x, r, y)
        elif i == j-2:
            return self.decode_repl(repl_node, b, i, j, q, x, r, y)
        elif i == j-1:
            return LinkedList([PushOperation(r.item(), y.item())])
        else:
            raise ValueError

    def decode_repl(self, repl_node, b, i, j, q, x, r, y):
        relative_i = self.get_relative_i(i, j)
        s, z = repl_node.backpointers[b, relative_i, q, x, r, y]
        path = self.decode_gamma_j(b, i, j-1, q, x, s, z)
        path.append(ReplaceOperation(r.item(), y.item()))
        return path

    def decode_pop(self, pop_node, b, i, j, q, x, r, y):
        relative_i = self.get_relative_i(i, j)
        relative_k, t = pop_node.backpointers[b, relative_i, q, x, r, y]
        gamma_1_node, gamma_prime_node = pop_node.children
        s, z = gamma_prime_node.backpointers[b, relative_k, t, y, r]
        k = j - (self.window_size - 1) + relative_k.item()
        gamma_1_path = self.decode_gamma_j(b, i, k, q, x, t, y)
        gamma_2_path = self.decode_gamma_j(b, k, j-1, t, y, s, z)
        path = gamma_1_path
        path.extend(gamma_2_path)
        path.append(PopOperation(r.item()))
        return path

    def get_alpha_j(self, j):
        # Let 0 be the first time step, where alpha[0] is the set of initial
        # weights where only alpha[0][0, 0] is set to 1.
        # self.alpha_columns[0] is actually alpha[1], so we need to adjust the
        # index accordingly.
        return self.alpha_columns[j-1]

    def get_alpha_j_node(self, j):
        return self.alpha_j_nodes[j-1]

    def get_gamma_j_node(self, j):
        # Return the node for computing all gamma entries of the form
        # gamma[i, j].
        return self.gamma_j_nodes[j-1]

    def get_relative_i(self, i, j):
        return i - (j - self.window_size)

    @property
    def sequence_length(self):
        return len(self.alpha_columns) + 1
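
# A minimal usage sketch (not part of the original module), assuming `model`
# is an already-constructed LimitedNondeterministicStackRNN and `chunks` is an
# iterable of input tensors of shape B x chunk_length x input_size; the
# `block_size` value here is an arbitrary assumption. It illustrates how the
# returned InitialState can be detached and fed back in to carry the stack
# across chunk boundaries (truncated backpropagation through time):
#
#     state = None
#     for chunk in chunks:
#         output, state = model(
#             chunk,
#             block_size=32,
#             initial_state=state,
#             return_state=True,
#         )
#         state = state.detach()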