@@ -175,7 +175,7 @@ def construct(
 
         if attention_mask is not None:
             # Apply the attention mask is (precomputed for all layers in RobertaModel forward() function)
-            attention_scores = attention_scores + attention_mask.astype(self.dense_dtype)
+            attention_scores = attention_scores + attention_mask.astype(attention_scores.dtype)
 
         # Normalize the attention scores to probabilities.
         # Use the trick of the CogView paper to stablize training
@@ -227,11 +227,8 @@ def __init__(self, config):
         self.has_relative_attention_bias = config.has_relative_attention_bias
         self.has_spatial_attention_bias = config.has_spatial_attention_bias
         self.patch_size = config.patch_size
-        self.use_float16 = config.use_float16
-        self.dense_dtype = mstype.float32
-        if self.use_float16 is True:
-            self.dense_dtype = mstype.float16
-        self.min = finfo(self.dense_dtype)
+        self.float32_min = finfo(mstype.float32)
+        self.float16_min = finfo(mstype.float16)
         self.out_channels = 1
         self.use_visual_backbone = True
@@ -342,7 +339,13 @@ def get_extended_attention_mask(self, attention_mask: Tensor, input_shape, dtype
         # Since we are adding it to the raw scores before the softmax, this is
         # effectively the same as removing these entirely. # fp16 compatibility
         extended_attention_mask = extended_attention_mask.astype(dtype)
-        extended_attention_mask = (1.0 - extended_attention_mask) * self.min
+
+        if dtype == mstype.float32:
+            minimum = self.float32_min
+        elif dtype == mstype.float16:
+            minimum = self.float16_min
+
+        extended_attention_mask = (1.0 - extended_attention_mask) * minimum
         return extended_attention_mask
 
     def get_head_mask(self, head_mask, num_hidden_layers: int, is_attention_chunked: bool = False):
@@ -518,7 +521,7 @@ def construct(
 
 
 @register_backbone
-def layoutlmv3(use_float16: bool = True, **kwargs):
-    pretrained_config = LayoutLMv3PretrainedConfig(use_float16)
+def layoutlmv3(**kwargs):
+    pretrained_config = LayoutLMv3PretrainedConfig()
     model = LayoutLMv3Model(pretrained_config)
     return model
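
Taken together, the change makes the additive attention mask dtype-aware: the extended mask is filled with the lowest value representable in the dtype actually used for the attention scores, and the mask is cast to attention_scores.dtype before the addition rather than to a configured dense dtype. A minimal NumPy sketch of the idea follows (hypothetical helper names, not the MindSpore implementation):

import numpy as np

def build_extended_mask(attention_mask: np.ndarray, dtype=np.float32) -> np.ndarray:
    """attention_mask: 1 for real tokens, 0 for padding, shape (batch, seq_len)."""
    mask = attention_mask[:, None, None, :].astype(dtype)   # broadcast to (batch, 1, 1, seq_len)
    minimum = np.finfo(dtype).min                            # dtype-specific lowest value
    return (1.0 - mask) * minimum                            # 0 where attended, "min" where masked

def masked_softmax(attention_scores: np.ndarray, extended_mask: np.ndarray) -> np.ndarray:
    # Mirrors `attention_mask.astype(attention_scores.dtype)` in the patch:
    # the additive mask is cast to the scores' dtype before the addition.
    scores = attention_scores + extended_mask.astype(attention_scores.dtype)
    scores = scores - scores.max(axis=-1, keepdims=True)     # standard softmax stabilization
    exp = np.exp(scores)
    return exp / exp.sum(axis=-1, keepdims=True)

# Usage: the last key position is padding, so its attention probability collapses to ~0.
mask = np.array([[1, 1, 0]])
scores = np.random.randn(1, 1, 3, 3).astype(np.float32)
probs = masked_softmax(scores, build_extended_mask(mask, dtype=np.float32))
print(probs[0, 0, :, -1])   # ~0 for every query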