1. Bug fix. 2. Add fast long-retention implementation #25
base: main
Conversation
…2. Add a fixed-length sequence argument for when the inputs are (additional_token, past_kv). 3. Add a fast retention implementation for when the sequence length >> D**2.
@@ -471,7 +687,6 @@ def forward(self, x):
        x = self.dropout_module(x)
        return x
2 empty lines between classes
@@ -508,7 +723,6 @@ def forward(self, x):
        x = self.dropout_module(x)
        return x
2 empty lines between classes
@@ -444,13 +666,7 @@ def __init__(
        self.dropout_module = torch.nn.Dropout(dropout)
        self.fc1 = nn.Linear(self.embed_dim, ffn_dim)
        self.fc2 = nn.Linear(ffn_dim, self.embed_dim)
        if subln:
I would like to keep use_rms_norm. Also, I would prefer if-else instead of a ternary here. If you want a ternary, could you make it something like:
norm_class = RMSNorm if use_rms_norm else LayerNorm
self.ffn_layernorm = norm_class(ffn_dim, eps=layernorm_eps) if subln else None
The embed_dim should be replaced by ffn_dim, I think. Change:
if subln:
    if use_rms_norm:
        self.ffn_layernorm = RMSNorm(self.embed_dim, eps=layernorm_eps)
    else:
        self.ffn_layernorm = LayerNorm(self.embed_dim, eps=layernorm_eps)
else:
    self.ffn_layernorm = None
to
if subln:
    if use_rms_norm:
        self.ffn_layernorm = RMSNorm(ffn_dim, eps=layernorm_eps)
    else:
        self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps)
else:
    self.ffn_layernorm = None
# multi-head
q, k, v = split_heads((q, k, v), B, T, self.num_heads)
k *= self.scaling  # for scaled dot product
what's the reasoning for this change?
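(As a hypothetical aside, not from the PR: since q @ k^T is bilinear, applying the scaling to k instead of q should leave the scores unchanged. A quick sanity check:)

import torch

# Hypothetical check: scaling k rather than q gives the same dot-product scores.
torch.manual_seed(0)
B, H, T, D = 2, 4, 8, 16
q = torch.randn(B, H, T, D)
k = torch.randn(B, H, T, D)
scaling = D ** -0.5

scores_scaled_q = (q * scaling) @ k.transpose(-1, -2)
scores_scaled_k = q @ (k * scaling).transpose(-1, -2)
print(torch.allclose(scores_scaled_q, scores_scaled_k, atol=1e-6))  # True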
retnet/modeling_retnet.py
@@ -315,7 +540,7 @@ def chunkwise_retention(self, q, k, v, decay_mask):
        past_key_value:
            - "prev_key_value"  # bsz * num_head * v_dim * qk_dim
            - "scale"  # (1 or bsz) * num_head * 1 * 1
        decay_mask,  # 1 * num_head * chunk_size * chunk_size
        decay_mask, # 1 * num_head * chunk_size * chunk_size
let's keep the spaces consistent.
class RetNetDecoderLayer(nn.Module):

    def __init__(self, config: RetNetConfig, depth: int, tensor_parallel: bool = False):
        super().__init__()
        self.config = config
        self.embed_dim = config.decoder_embed_dim
        self.dropout_module = torch.nn.Dropout(config.dropout)
        self.drop_path = DropPath(np.linspace(0, config.drop_path_rate, config.decoder_layers)[depth]) if config.drop_path_rate > 0 else None
I prefer the previous code. This one-liner is too long and breaks the 100-character limit.
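For reference, one way to split it while keeping the same behavior (just a sketch; drop_path_prob is only an illustrative temporary name):

if config.drop_path_rate > 0:
    # evenly spaced drop-path probabilities across decoder layers, as in the original one-liner
    drop_path_prob = np.linspace(0, config.drop_path_rate, config.decoder_layers)[depth]
    self.drop_path = DropPath(drop_path_prob)
else:
    self.drop_path = None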
@@ -644,17 +847,17 @@ def _init_weights(self, module):
        Following original retnet, weights are already initialized in their own
        ways within their own init.
        """
        pass
        #pass
Reason for adding this again?
@@ -841,13 +1046,15 @@ def forward(
            hidden_states = F.pad(hidden_states, (0, 0, 0, padding_len))
        else:
            slen = seq_length
        if fixed_seq_len:slen=fixed_seq_len
No one-liners for if statements, please.
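For example, spelled out over two lines:

if fixed_seq_len:
    slen = fixed_seq_len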
# relative position
if retention_rel_pos is None:
    retention_rel_pos = self.retnet_rel_pos(slen,
                                            forward_impl=forward_impl,
                                            recurrent_chunk_size=recurrent_chunk_size,
                                            retention_mask=retention_mask,
                                            get_decay_scale=not self.training)
Why do we want the decay scale only when not training?
Below is an example of one parallel pass followed by one recurrent pass:

with torch.inference_mode():  # <-- almost equal to `torch.no_grad()`
    model.eval()  # <-- disables dropout, batchnorm, and other layers that behave differently during inference
    out = model(old_inputs,
                forward_impl='parallel',  # <-- this line indicates parallel mode
                use_cached=True,          # <-- must have use_cached=True
                **args)
    past_kv = out.past_key_values

model.train()  # <-- if you want to train on later tokens, you need to reactivate training mode here
out = model(new_inputs,
            forward_impl='recurrent',  # <-- this line indicates recurrent mode
            use_cached=True,           # <-- must keep use_cached=True for further recurrent mode
            past_key_values=past_kv,   # <-- this line is required
            **args)
If we don't have the model.eval(), the recurrent mode fails to generate because it takes scale=None. However, model.eval() will change the behavior of some layers. There is no other important reason here; it is just for my convenience.
Basically, the goal is to generate a cache first and reuse it many times:
cache --> task_1
cache --> task_2
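For illustration, reusing the cached state for several tasks might look like this (task1_inputs / task2_inputs are hypothetical placeholders, same args as above):

# Hypothetical reuse of the same cached past_key_values for two downstream tasks
out_task1 = model(task1_inputs,
                  forward_impl='recurrent',
                  use_cached=True,
                  past_key_values=past_kv,
                  **args)
out_task2 = model(task2_inputs,
                  forward_impl='recurrent',
                  use_cached=True,
                  past_key_values=past_kv,
                  **args)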
@@ -1300,4 +1507,4 @@ def forward(
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
        )
The file should ideally end with a trailing newline (per PEP 8).
Other than some formatting and refactoring issues, I love the fast-retention implementation! I was hoping to get into that. Thanks for your work!
Will this be merged?
There are some code styling issues and some things I don't understand fully. I think it's great to have its own branch for now.
Cached the fixed retnet_rel_pos (so it does not need to be generated at runtime).
Added a fast retention implementation for when the sequence length >> D**2; see
https://github.com/veya2ztn/fast_retention (a rough single-head sketch of the idea is included at the end of this comment).
5.1 I set use_glu to default to False, for consistency with the old code.
5.2 The layer-norm setting in the FFN seems wrong: self.embed_dim should be ffn_dim. Anyway, I rolled back to
self.ffn_layernorm = LayerNorm(ffn_dim, eps=layernorm_eps) if subln else None
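Rough sketch (my own generic illustration, not the code in veya2ztn/fast_retention) of why a state-based recurrence wins once the sequence length T far exceeds D**2: the parallel form materializes a T x T decay-masked score matrix (O(T^2 * D)), while the recurrence keeps only a D x D state (O(T * D^2)).

import torch

# Illustrative single-head recurrent retention with a scalar per-head decay gamma.
def recurrent_retention(q, k, v, gamma):
    # q, k: (T, d_qk); v: (T, d_v); gamma: scalar decay in (0, 1)
    T, d_qk = q.shape
    d_v = v.shape[-1]
    state = q.new_zeros(d_qk, d_v)  # running sum of decayed k_t v_t^T outer products
    outputs = []
    for t in range(T):
        state = gamma * state + k[t].unsqueeze(-1) * v[t].unsqueeze(0)  # decay + rank-1 update
        outputs.append(q[t] @ state)  # (d_qk,) @ (d_qk, d_v) -> (d_v,)
    return torch.stack(outputs)  # (T, d_v)

# Example: T = 512 tokens, d_qk = 8, d_v = 16
q, k, v = torch.randn(512, 8), torch.randn(512, 8), torch.randn(512, 16)
out = recurrent_retention(q, k, v, gamma=0.97)  # shape (512, 16)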