Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

implement an argument to directly set ff_inner_dim #52

Open
wants to merge 1 commit into
base: main
Choose a base branch
from

Conversation

chris-ha458
Copy link

In NVIDIA nvidia/GPT-2B-001, a very PaLM like model is implemented.

However, instead of a ffn multiplier like ffn_mult the ffn_hidden_size (comparable to ffn_inner_dim of this codebase) is directly set as 5440.

This translates to a ffn_mult of 2.65625. However, trying this in this codebase does not work.

The error

TypeError: empty() received an invalid combination of arguments - got (tuple, dtype=NoneType, device=NoneType), but expected one of:
 * (tuple of ints size, *, tuple of names names, torch.memory_format memory_format, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)
 * (tuple of ints size, *, torch.memory_format memory_format, Tensor out, torch.dtype dtype, torch.layout layout, torch.device device, bool pin_memory, bool requires_grad)

So I implemented a way to directly set the ffn_inner_dim
please take a look!

@chris-ha458
Copy link
Author

Also, the formatter I use, changes the layout a lot so I had to manually modify the code.
What formatter does this repo use?

@GadiZimerman
Copy link

@CodiumAI-Agent /review

@CodiumAI-Agent
Copy link

PR Analysis

  • 🎯 Main theme: Adding an argument to directly set ff_inner_dim
  • 📌 Type of PR: Enhancement
  • 🧪 Relevant tests added: No
  • Focused PR: Yes, the PR is focused as it has a clear and coherent title and description, and all PR code diff changes are properly derived from the title and description.
  • 🔒 Security concerns: No, the changes made in this PR do not introduce any obvious security concerns.

PR Feedback

  • 💡 General PR suggestions: The PR is generally well-written and the changes are clear. However, it would be beneficial to include tests to ensure the new functionality works as expected. Additionally, it would be helpful to update the function's docstring to include the new parameter.

How to use

Tag me in a comment '@CodiumAI-Agent' and add one of the following commands:
/review - Request a review of the latest update to the PR.
/describe - Modify the PR title and description based on the contents of the PR.
/improve - Suggest improvements to the code in the PR. These will be provided as pull request comments, ready to commit.
/ask - Pose a question about the PR.

@@ -134,7 +135,8 @@ def __init__(
self.norm = LayerNorm(dim)

attn_inner_dim = dim_head * heads
ff_inner_dim = dim * ff_mult
# silently ignores ff_mult if ff_inner_dim is provided in the arguments
ff_inner_dim = dim * ff_mult if not ff_inner_dim else self.ff_inner_dim

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a check to ensure that ff_inner_dim is a positive integer if it is not None. This will prevent potential errors or unexpected behavior. [important]

@@ -134,7 +135,8 @@ def __init__(
self.norm = LayerNorm(dim)

attn_inner_dim = dim_head * heads
ff_inner_dim = dim * ff_mult
# silently ignores ff_mult if ff_inner_dim is provided in the arguments
ff_inner_dim = dim * ff_mult if not ff_inner_dim else self.ff_inner_dim
self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be beneficial to add a comment explaining why ff_inner_dim is multiplied by 2 in self.fused_dims. This would improve code readability and maintainability. [medium]

@@ -511,4 +515,4 @@ def forward(
return ret

logits = rearrange(logits, 'b n c -> b c n')
return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index)

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Consider adding a newline at the end of the file. This is a common convention that helps with file processing in various systems. [medium]

@lucidrains lucidrains force-pushed the main branch 2 times, most recently from 89ab8ba to f721db2 Compare January 6, 2025 16:39
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants