From 297ac3e4c65f034de6ff3fa85008d871d6d786b2 Mon Sep 17 00:00:00 2001
From: Chris Ha
Date: Mon, 26 Jun 2023 20:28:24 +0900
Subject: [PATCH] implement direct setting of ff_inner_dim

---
 palm_rlhf_pytorch/palm.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/palm_rlhf_pytorch/palm.py b/palm_rlhf_pytorch/palm.py
index 3f3601e..3302d9c 100644
--- a/palm_rlhf_pytorch/palm.py
+++ b/palm_rlhf_pytorch/palm.py
@@ -124,6 +124,7 @@ def __init__(
         qk_rmsnorm = False,
         qk_scale = 8,
         ff_mult = 4,
+        ff_inner_dim = None,
         attn_dropout = 0.,
         ff_dropout = 0.,
         use_xpos = True,
@@ -134,7 +135,8 @@ def __init__(
         self.norm = LayerNorm(dim)
 
         attn_inner_dim = dim_head * heads
-        ff_inner_dim = dim * ff_mult
+        # silently ignores ff_mult if ff_inner_dim is provided in the arguments
+        ff_inner_dim = dim * ff_mult if ff_inner_dim is None else ff_inner_dim
         self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
 
         self.qk_rmsnorm = qk_rmsnorm
@@ -270,6 +272,7 @@ def __init__(
         dim_head = 64,
         heads = 8,
         ff_mult = 4,
+        ff_inner_dim = None,
         attn_dropout = 0.,
         ff_dropout = 0.,
         qk_rmsnorm = False,
@@ -297,6 +300,7 @@ def __init__(
                 heads = heads,
                 qk_rmsnorm = qk_rmsnorm,
                 ff_mult = ff_mult,
+                ff_inner_dim = ff_inner_dim,
                 attn_dropout = attn_dropout,
                 ff_dropout = ff_dropout,
                 xpos_scale_base = rotary_xpos_scale_base,
@@ -511,4 +515,4 @@ def forward(
             return ret
 
         logits = rearrange(logits, 'b n c -> b c n')
-        return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index)
\ No newline at end of file
+        return F.cross_entropy(logits, labels, ignore_index = self.cross_entropy_ignore_index)