Skip to content

Commit

Permalink
let example inference pass
Browse files Browse the repository at this point in the history
  • Loading branch information
Spycsh committed Sep 3, 2024
1 parent acd6710 commit ed19006
Showing 1 changed file with 4 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -301,8 +301,8 @@ def forward(
past_key_values_length, past_key_values_length + input_shape[1], device=inputs_embeds.device
)

if position_ids is None:
position_ids = (token_idx-1).unsqueeze(0)
if position_ids is None and token_idx:
position_ids = (token_idx - 1).unsqueeze(0)
# embed positions
if input_ids is not None:
positions = self.embed_positions(
Expand Down Expand Up @@ -620,12 +620,12 @@ def prepare_inputs_for_generation(
decoder_attention_mask = (decoder_input_ids != self.config.pad_token_id).long()

# prepare the decoder_position_ids
if token_idx <= forced_decoder_ids_length:
if token_idx and token_idx <= forced_decoder_ids_length:
decoder_position_ids = decoder_attention_mask.cumsum(-1) - 1
else:
decoder_position_ids = None

if token_idx >= forced_decoder_ids_length + 1:
if token_idx and token_idx >= forced_decoder_ids_length + 1:
decoder_input_ids = torch.index_select(decoder_input_ids, 1, token_idx - 1)

return {
Expand Down

0 comments on commit ed19006

Please sign in to comment.