
Commit 96f66d2

address #306
1 parent 12249dc commit 96f66d2

2 files changed: 6 additions & 6 deletions

setup.py

Lines changed: 1 addition & 1 deletion
@@ -6,7 +6,7 @@
 setup(
   name = 'vit-pytorch',
   packages = find_packages(exclude=['examples']),
-  version = '1.6.6',
+  version = '1.6.7',
   license='MIT',
   description = 'Vision Transformer (ViT) - Pytorch',
   long_description=long_description,

vit_pytorch/na_vit.py

Lines changed: 5 additions & 5 deletions
@@ -198,7 +198,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
             self.calc_token_dropout = token_dropout_prob
 
         elif isinstance(token_dropout_prob, (float, int)):
-            assert 0. < token_dropout_prob < 1.
+            assert 0. <= token_dropout_prob < 1.
             token_dropout_prob = float(token_dropout_prob)
             self.calc_token_dropout = lambda height, width: token_dropout_prob
 
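The relaxed assert means a token dropout probability of exactly zero is now accepted instead of being rejected in the constructor. A minimal sketch of what this permits, assuming the usual NaViT constructor arguments from the repository README (the concrete sizes and mlp_dim below are illustrative, not taken from this commit):

    from vit_pytorch.na_vit import NaViT

    # token_dropout_prob = 0. previously tripped `assert 0. < token_dropout_prob < 1.`
    v = NaViT(
        image_size = 256,        # illustrative values
        patch_size = 32,
        num_classes = 1000,
        dim = 512,
        depth = 4,
        heads = 8,
        mlp_dim = 1024,
        token_dropout_prob = 0.  # now allowed: token dropout explicitly disabled
    )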

@@ -249,7 +249,7 @@ def forward(
         group_images = False,
         group_max_seq_len = 2048
     ):
-        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
+        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout) and self.training
 
         arange = partial(torch.arange, device = device)
         pad_sequence = partial(orig_pad_sequence, batch_first = True)
@@ -260,7 +260,7 @@ def forward(
             batched_images = group_images_by_max_seq_len(
                 batched_images,
                 patch_size = self.patch_size,
-                calc_token_dropout = self.calc_token_dropout,
+                calc_token_dropout = self.calc_token_dropout if self.training else None,
                 max_seq_len = group_max_seq_len
             )
 
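Taken together, the two hunks above gate token dropout on the module's training mode: `has_token_dropout` and the `calc_token_dropout` handed to `group_images_by_max_seq_len` are only active while `self.training` is True, so inference keeps every patch. A minimal sketch of the underlying PyTorch mechanism (the Demo module below is hypothetical; exists mirrors the repo's small helper):

    from torch import nn

    def exists(val):
        return val is not None

    class Demo(nn.Module):
        def __init__(self, calc_token_dropout = None):
            super().__init__()
            self.calc_token_dropout = calc_token_dropout

        def forward(self):
            # mirrors the patched line: dropout only counts as enabled in training mode
            return exists(self.calc_token_dropout) and self.training

    m = Demo(calc_token_dropout = lambda height, width: 0.1)
    print(m())   # True  - nn.Module starts in training mode
    m.eval()
    print(m())   # False - token dropout is skipped at inference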

@@ -314,8 +314,8 @@ def forward(
         # derive key padding mask
 
         lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
-        max_length = arange(lengths.amax().item())
-        key_pad_mask = rearrange(lengths, 'b -> b 1') <= rearrange(max_length, 'n -> 1 n')
+        seq_arange = arange(lengths.amax().item())
+        key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')
 
         # derive attention mask, and combine with key padding mask from above
 
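The rewritten mask also flips its convention: the old expression evaluated to True on padding positions, while the new `key_pad_mask` is True exactly on real (non-padded) tokens. A minimal sketch of the new construction with made-up sequence lengths:

    import torch
    from einops import rearrange

    # hypothetical patch counts for three packed sequences
    lengths = torch.tensor([2, 4, 3])

    seq_arange = torch.arange(lengths.amax().item())

    # True wherever the position index falls inside the sequence, i.e. on real tokens
    key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')

    print(key_pad_mask)
    # tensor([[ True,  True, False, False],
    #         [ True,  True,  True,  True],
    #         [ True,  True,  True, False]])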
