File tree Expand file tree Collapse file tree 1 file changed +7
-5
lines changed Expand file tree Collapse file tree 1 file changed +7
-5
lines changed Original file line number Diff line number Diff line change @@ -77,10 +77,12 @@ def forward(self, input_ids):
77
77
inputs_embeds = self .word_embeddings (input_ids )
78
78
embeds_shape = list (inputs_embeds .size ())
79
79
80
- inputs_embeds = inputs_embeds .view ([embeds_shape [0 ], - 1 , embeds_shape [- 1 ]])
81
- inputs_embeds = inputs_embeds .permute (1 , 0 , 2 )
82
- inputs_embeds = self .transformer_encoder (inputs_embeds )
83
- inputs_embeds = inputs_embeds .permute (1 , 0 , 2 ).contiguous ()
84
- inputs_embeds = self .lin_proj (inputs_embeds .view (embeds_shape [:- 2 ] + [- 1 ]))
80
+ inputs_embeds = inputs_embeds .view ([- 1 ] + embeds_shape [- 2 :])
81
+ inputs_embeds = inputs_embeds .permute (1 , 0 , 2 )
82
+ inputs_embeds = self .transformer_encoder (inputs_embeds )
83
+ inputs_embeds = inputs_embeds .permute (1 , 0 , 2 )
84
+ inputs_embeds = inputs_embeds .contiguous ().view (embeds_shape [0 :2 ]+ [- 1 ])
85
+
86
+ inputs_embeds = self .lin_proj (inputs_embeds )
85
87
86
88
return inputs_embeds
You can’t perform that action at this time.
0 commit comments