@@ -198,7 +198,7 @@ def __init__(self, *, image_size, patch_size, num_classes, dim, depth, heads, ml
             self.calc_token_dropout = token_dropout_prob

         elif isinstance(token_dropout_prob, (float, int)):
-            assert 0. < token_dropout_prob < 1.
+            assert 0. <= token_dropout_prob < 1.
             token_dropout_prob = float(token_dropout_prob)
             self.calc_token_dropout = lambda height, width: token_dropout_prob

@@ -249,7 +249,7 @@ def forward(
         group_images = False,
         group_max_seq_len = 2048
     ):
-        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout)
+        p, c, device, has_token_dropout = self.patch_size, self.channels, self.device, exists(self.calc_token_dropout) and self.training

         arange = partial(torch.arange, device = device)
         pad_sequence = partial(orig_pad_sequence, batch_first = True)
@@ -260,7 +260,7 @@ def forward(
             batched_images = group_images_by_max_seq_len(
                 batched_images,
                 patch_size = self.patch_size,
-                calc_token_dropout = self.calc_token_dropout,
+                calc_token_dropout = self.calc_token_dropout if self.training else None,
                 max_seq_len = group_max_seq_len
             )

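Not part of the commit, but a hypothetical standalone sketch of what the two hunks above change: gating on `exists(self.calc_token_dropout) and self.training` (and passing `None` to the grouping helper outside training) makes token dropout a training-only behaviour. The `Toy` module, its lambda and the 224 x 224 call are illustrative stand-ins, not the library's API.

import torch
from torch import nn

class Toy(nn.Module):
    def __init__(self, token_dropout_prob = 0.1):
        super().__init__()
        assert 0. <= token_dropout_prob < 1.
        self.calc_token_dropout = lambda height, width: token_dropout_prob

    def forward(self, tokens):                             # tokens: (seq, dim)
        # mirrors the diff: dropout is considered only while training
        has_token_dropout = self.calc_token_dropout is not None and self.training
        if not has_token_dropout:
            return tokens                                  # eval mode: keep every token
        keep_prob = 1. - self.calc_token_dropout(224, 224)
        num_keep = max(1, int(tokens.shape[0] * keep_prob))
        keep_indices = torch.randn(tokens.shape[0]).topk(num_keep).indices
        return tokens[keep_indices]

toy = Toy()
x = torch.randn(196, 64)
toy.train(); assert toy(x).shape[0] < x.shape[0]           # tokens dropped during training
toy.eval();  assert toy(x).shape[0] == x.shape[0]          # all tokens kept at inference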
@@ -314,8 +314,8 @@ def forward(
         # derive key padding mask

         lengths = torch.tensor([seq.shape[-2] for seq in batched_sequences], device = device, dtype = torch.long)
-        max_length = arange(lengths.amax().item())
-        key_pad_mask = rearrange(lengths, 'b -> b 1') <= rearrange(max_length, 'n -> 1 n')
+        seq_arange = arange(lengths.amax().item())
+        key_pad_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')

         # derive attention mask, and combine with key padding mask from above

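A small standalone sketch (again, not part of the commit) showing that the rewritten `key_pad_mask` is simply the complement of the old one: it now marks valid tokens as True instead of padding positions. The example lengths are made up.

import torch
from einops import rearrange

lengths = torch.tensor([3, 5, 2])                          # tokens per image sequence
seq_arange = torch.arange(lengths.amax().item())           # 0 .. max_len - 1

old_mask = rearrange(lengths, 'b -> b 1') <= rearrange(seq_arange, 'n -> 1 n')  # True at padding
new_mask = rearrange(seq_arange, 'n -> 1 n') < rearrange(lengths, 'b -> b 1')   # True at real tokens

assert torch.equal(new_mask, ~old_mask)                    # the two conventions are complements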