
1. Bug fix. 2. Add fast long retention implementation #25

Open
wants to merge 2 commits into main

Conversation

veya2ztn

  1. Fix a bug for the mode that takes inputs_embeds rather than input_ids:

if inputs_embeds is None:
    inputs_embeds = self.forward_embedding(input_ids, forward_impl, inputs_embeds, past_key_values)
else:
    if forward_impl == 'recurrent':
        # in recurrent mode, only the newest token's embedding is consumed
        inputs_embeds = inputs_embeds[:, -1:]
  2. Add a fixed_seq_len argument for when the inputs are (additional_token, past_kv):

if fixed_seq_len: slen = fixed_seq_len
  3. Cache the fixed retnet_rel_pos (so it does not need to be generated at runtime).

  4. Add a fast retention implementation for when the sequence length >> D**2 (a single-head sketch of the idea follows the description below).
    See https://github.com/veya2ztn/fast_retention

5.1 I set use_glu default to False, to stay consistent with the old code.
5.2 The layer norm setting in the FFN seems wrong; self.embed_dim should be ffn_dim:

if subln:
    if use_rms_norm:
        self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps)
    else:
        self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps)
else:
    self.ffn_layernorm = None

Anyway, I rolled it back to self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
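
For intuition about item 4, here is a minimal single-head sketch of the two mathematically equivalent retention forms (an illustration only, not the linked fast_retention code; shapes and decay handling are simplified). The parallel form materializes a T x T decayed score matrix, while the recurrent form carries a d_k x d_v state, so its cost grows linearly with sequence length.

import torch

def parallel_retention(q, k, v, gamma):
    # O(T^2 * d): builds the full T x T causally masked, decayed score matrix.
    T = q.shape[0]
    idx = torch.arange(T)
    decay = gamma ** (idx[:, None] - idx[None, :]).clamp(min=0).float()
    causal = torch.tril(torch.ones(T, T))
    scores = (q @ k.T) * decay * causal
    return scores @ v                      # (T, d_v)

def recurrent_retention(q, k, v, gamma):
    # O(T * d_k * d_v): carries a d_k x d_v state instead of a T x T matrix,
    # which is why very long sequences become cheap in this form.
    state = torch.zeros(k.shape[-1], v.shape[-1])
    outs = []
    for t in range(q.shape[0]):
        state = gamma * state + k[t].unsqueeze(-1) * v[t].unsqueeze(0)
        outs.append(q[t] @ state)
    return torch.stack(outs)               # (T, d_v)

The two functions agree up to floating-point error for random q, k, v of shape (T, d); the point where the linear-in-T route overtakes the quadratic one is roughly the "sequence length >> D**2" regime the PR targets.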

…2. Add fixed-length seq argument when the inputs are (additional_token, past_kv) 3. Add fast retention implementation when the sequence length >> D**2
@@ -471,7 +687,6 @@ def forward(self, x):
x = self.dropout_module(x)
return x


syncdoth (Owner)

2 empty lines between classes

@@ -508,7 +723,6 @@ def forward(self, x):
x = self.dropout_module(x)
return x


syncdoth (Owner)

2 empty lines between classes

@@ -444,13 +666,7 @@ def __init__(
self.dropout_module = torch.nn.Dropout(dropout)
self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
if subln:
syncdoth (Owner)

I would like to keep the use_rms_norm. Also, I would prefer if-else instead of a ternary here. If you want a ternary, could you make it something like:

norm_class = RMSNorm if use_rms_norm else LayerNorm
self.ffn_layernorm = norm_class(ffn_dim, eps=layernorm_eps) if subln else None

veya2ztn (Author)

The embed_dim should be replaced by ffn_dim, I think:

if subln:
    if use_rms_norm:
        self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps)
    else:
        self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps)
else:
    self.ffn_layernorm = None

to

if subln:
    if use_rms_norm:
        self.ffn_layernorm = RMSNorm(ffn_dim, eps=layernorm_eps)
    else:
        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps)
else:
    self.ffn_layernorm = None

# multi-head
q, k, v = split_heads((q, k, v), B, T, self.num_heads)
k *= self.scaling # for scaled dot product
syncdoth (Owner)

what's the reasoning for this change?

@@ -315,7 +540,7 @@ def chunkwise_retention(self, q, k, v, decay_mask):
past_key_value:
- "prev_key_value" # bsz * num_head * v_dim * qk_dim
- "scale" # (1 or bsz) * num_head * 1 * 1
decay_mask, # 1 * num_head * chunk_size * chunk_size
decay_mask, # 1 * num_head * chunk_size * chunk_size
syncdoth (Owner)

let's keep the spaces consistent.

class RetNetDecoderLayer(nn.Module):

def __init__(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False):
super().__init__()
self.config = config
self.embed_dim = config.decoder_embed_dim
self.dropout_module = torch.nn.Dropout(config.dropout)
self.drop_path = DropPath(np.linspace(0, config.drop_path_rate, config.decoder_layers)[depth]) if config.drop_path_rate > 0 else None
syncdoth (Owner)

I prefer the previous code. This one-liner is too long and breaks the 100-character limit.

@@ -644,17 +847,17 @@ def _init_weights(self, module):
Following original retnet, weights are already initialized in their own
ways within their own init.
"""
pass
#pass
syncdoth (Owner)

Reason for adding this again?

@@ -841,13 +1046,15 @@ def forward(
    hidden_states = F.pad(hidden_states, (0, 0, 0, padding_len))
else:
    slen = seq_length
if fixed_seq_len:slen=fixed_seq_len
syncdoth (Owner)

no one-liner for if statement.

# relative position
if retention_rel_pos is None:
    retention_rel_pos = self.retnet_rel_pos(slen,
                                            forward_impl=forward_impl,
                                            recurrent_chunk_size=recurrent_chunk_size,
                                            retention_mask=retention_mask,
                                            get_decay_scale=not self.training)
syncdoth (Owner)

Why do we want decay scale during training?

veya2ztn (Author)

Below is an example with one parallel pass and one recurrent pass:

with torch.inference_mode():  # <-- almost equivalent to `torch.no_grad()`
    model.eval()  # <-- disables dropout, batchnorm, and other layers that behave differently during inference
    out = model(old_inputs,
                forward_impl='parallel',  # <-- parallel mode
                use_cached=True,          # <-- must have use_cached=True
                **args)
    past_kv = out.past_key_values
    model.train()  # if you want to train on later tokens, reactivate it here
    out = model(new_inputs,
                forward_impl='recurrent',  # <-- recurrent mode
                use_cached=True,           # <-- must have use_cached=True for further recurrent steps
                past_key_values=past_kv,   # <-- this argument is required
                **args)

If we don't call model.eval(), the recurrent mode fails to generate because it picks up scale=None. However, model.eval() also changes the behavior of some layers.

There is no other important reason here; it is just for my convenience.

Basically, the goal is to generate a cache once and reuse it many times:
cache --> task_1
cache --> task_2
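
As a usage sketch of that cache-reuse pattern (hypothetical; it just continues the example above with the same argument names, and task_1_inputs / task_2_inputs are made-up placeholders):

out_task_1 = model(task_1_inputs, forward_impl='recurrent',
                   use_cached=True, past_key_values=past_kv, **args)
out_task_2 = model(task_2_inputs, forward_impl='recurrent',
                   use_cached=True, past_key_values=past_kv, **args)

Both calls consume the same past_kv computed once in parallel mode, so the shared prefix is never re-encoded.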

@@ -1300,4 +1507,4 @@ def forward(
past_key_values=outputs.past_key_values,
hidden_states=outputs.hidden_states,
attentions=outputs.attentions,
)
)
syncdoth (Owner)

The file should ideally end with a trailing newline (some PEP standard).

@syncdoth (Owner)

Other than some formatting and refactoring issues, I love the fast-retention implementation! I was hoping to get into that. Thanks for your work!

@pkpro commented Oct 31, 2023

Will this be merged?

@syncdoth (Owner) commented Nov 3, 2023

There are some code styling issues and some things I don't fully understand. I think it's good to keep it on its own branch for now.
