Skip to content

Commit 600d4e9

Browse files
committed
minor fix to #14
1 parent 09b4273 commit 600d4e9

File tree

1 file changed

+7
-5
lines changed

1 file changed

+7
-5
lines changed

models/hierarchical.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -77,10 +77,12 @@ def forward(self, input_ids):
7777
inputs_embeds = self.word_embeddings(input_ids)
7878
embeds_shape = list(inputs_embeds.size())
7979

80-
inputs_embeds = inputs_embeds.view([embeds_shape[0], -1, embeds_shape[-1]])
81-
inputs_embeds = inputs_embeds.permute(1, 0, 2)
82-
inputs_embeds = self.transformer_encoder(inputs_embeds)
83-
inputs_embeds = inputs_embeds.permute(1, 0, 2).contiguous()
84-
inputs_embeds = self.lin_proj(inputs_embeds.view(embeds_shape[:-2] + [-1]))
80+
inputs_embeds = inputs_embeds.view([-1] + embeds_shape[-2:])
81+
inputs_embeds = inputs_embeds.permute(1, 0, 2)
82+
inputs_embeds = self.transformer_encoder(inputs_embeds)
83+
inputs_embeds = inputs_embeds.permute(1, 0, 2)
84+
inputs_embeds = inputs_embeds.contiguous().view(embeds_shape[0:2]+[-1])
85+
86+
inputs_embeds = self.lin_proj(inputs_embeds)
8587

8688
return inputs_embeds

0 commit comments

Comments
 (0)