Commit

seeing immediate improvements when investigating a new iclr 2025 paper
lucidrains committed Oct 26, 2024
1 parent d1ca975 commit 0f18c78
Showing 5 changed files with 55 additions and 3 deletions.
9 changes: 9 additions & 0 deletions README.md
@@ -2310,4 +2310,13 @@ ids_out, num_out, is_number_mask = model.generate(start_ids, start_nums, 17)
}
```

```bibtex
@inproceedings{Zhou2024ValueRL,
title = {Value Residual Learning For Alleviating Attention Concentration In Transformers},
author = {Zhanchao Zhou and Tianyi Wu and Zhiyun Jiang and Zhenzhong Lan},
year = {2024},
url = {https://api.semanticscholar.org/CorpusID:273532030}
}
```
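
A minimal usage sketch of the `add_value_residual` option that this citation accompanies, mirroring the new test added later in this commit (the flag name and output shape are taken from the diff; everything else is standard x-transformers usage):

```python
import torch
from x_transformers import TransformerWrapper, Decoder

model = TransformerWrapper(
    num_tokens = 20000,
    max_seq_len = 1024,
    attn_layers = Decoder(
        dim = 128,
        depth = 6,
        heads = 8,
        add_value_residual = True  # add the first self attention layer's values to all later layers' values
    )
)

x = torch.randint(0, 20000, (2, 1024))
logits = model(x)  # (2, 1024, 20000)
```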

*solve intelligence... then use that to solve everything else.* - Demis Hassabis
2 changes: 1 addition & 1 deletion setup.py
@@ -3,7 +3,7 @@
setup(
  name = 'x-transformers',
  packages = find_packages(exclude=['examples']),
-  version = '1.40.6',
+  version = '1.40.7',
  license='MIT',
  description = 'X-Transformers - Pytorch',
  author = 'Phil Wang',
19 changes: 18 additions & 1 deletion tests/test_x_transformers.py
@@ -313,6 +313,23 @@ def test_reinject_input():
        )
    )

-    x = torch.randint(0, 256, (1, 12))
+    x = torch.randint(0, 256, (1, 1024))

    model(x) # (1, 1024, 20000)

def test_value_residual():

    model = TransformerWrapper(
        num_tokens = 20000,
        max_seq_len = 1024,
        attn_layers = Decoder(
            dim = 128,
            depth = 6,
            heads = 8,
            add_value_residual = True,
        )
    )

    x = torch.randint(0, 20000, (2, 1024))

    model(x)
1 change: 1 addition & 0 deletions x_transformers/attend.py
@@ -22,6 +22,7 @@ class Intermediates:
    qk_similarities: Tensor | None = None
    pre_softmax_attn: Tensor | None = None
    post_softmax_attn: Tensor | None = None
    values: Tensor | None = None
    cached_kv: Tuple[Tensor, Tensor] | None = None
    layer_type: str | None = None

27 changes: 26 additions & 1 deletion x_transformers/x_transformers.py
@@ -1114,6 +1114,7 @@ def forward(
        mem_mask = None,
        return_intermediates = False,
        cache: Intermediates | None = None,
        value_residual = None
    ):
        b, n, h, kv_h, head_scale, num_mem_kv, device, has_context = x.shape[0], x.shape[1], self.heads, self.kv_heads, self.head_scale, self.num_mem_kv, x.device, exists(context)

@@ -1243,6 +1244,12 @@ def forward(
            attn_bias = rel_pos(i, j)
            attn_bias = pad_at_dim(attn_bias, (num_mem_kv, 0), value = 0.) # handle memory key / values

        # previous values passed in
        # https://arxiv.org/abs/2410.17897v1

        if exists(value_residual):
            v = v + value_residual

        # attention is all we need

        out, intermediates = self.attend(
@@ -1252,6 +1259,10 @@
            prev_attn = prev_attn
        )

        # store the values for resformer from Zhou et al. https://arxiv.org/abs/2410.17897v1

        intermediates.values = v

        # https://arxiv.org/abs/2208.06061 proposes to add a residual for better gradients

        if exists(r):
@@ -1354,6 +1365,7 @@ def __init__(
        layerscale_init_value = 0.,
        unet_skips = False,
        reinject_input = False, # seen first in DEQ paper https://arxiv.org/abs/1909.01377, but later used in a number of papers trying to achieve depthwise generalization https://arxiv.org/abs/2410.03020v1
        add_value_residual = False, # resformer from Zhou et al - https://arxiv.org/abs/2410.17897v1 | TODO: also add NeuTRENO from Nguyen et al. https://arxiv.org/abs/2312.00751
        **kwargs
    ):
        super().__init__()
@@ -1588,6 +1600,10 @@ def __init__(
        self.reinject_input = reinject_input
        self.reinject_input_proj = nn.Linear(dim, dim, bias = False) if reinject_input else None

        # add the value from the first self attention block to all later projected self attention values as a residual

        self.add_value_residual = add_value_residual

        # iterate and construct layers

        for ind, (layer_type, layer_shift_tokens) in enumerate(zip(self.layer_types, shift_tokens)):
@@ -1787,6 +1803,8 @@ def forward(

        skip_hiddens = []

        first_self_attn_inter = None

        # go through the attention and feedforward layers

        for ind, (layer_type, skip_combine, (norm, block, residual_fn), layer_dropout) in enumerate(zip(*layer_variables)):
@@ -1838,13 +1856,20 @@

            block = partial(block, **block_forward_kwargs)

            maybe_value_residual = None
            if self.add_value_residual and exists(first_self_attn_inter):
                maybe_value_residual = first_self_attn_inter.values

            if layer_type == 'a':
-                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, return_intermediates = True)
+                out, inter = block(x, mask = mask, context_mask = self_attn_kv_mask, attn_mask = attn_mask, rel_pos = self.rel_pos, rotary_pos_emb = rotary_pos_emb, prev_attn = prev_attn, cache = next(iter_attn_cache, None), mem = layer_mem, mem_mask = layer_mem_mask, attn_bias = attn_bias, value_residual = maybe_value_residual, return_intermediates = True)
            elif layer_type == 'c':
                out, inter = block(x, context = context, mask = mask, context_mask = context_mask, prev_attn = prev_cross_attn, cache = next(iter_attn_cache, None), return_intermediates = True)
            elif layer_type == 'f':
                out = block(x)

            if not exists(first_self_attn_inter) and layer_type == 'a':
                first_self_attn_inter = inter

            if self.resi_dual:
                outer_residual = outer_residual + out * self.resi_dual_scale
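
Taken together, these changes thread the first self attention layer's value tensor (surfaced via the new `Intermediates.values` field) through to every later attention layer. Below is a simplified, self-contained sketch of that flow under stated assumptions: `ToyValueResidualAttention`, its variable names, and the plain training-free loop are illustrative stand-ins, not the library's actual classes.

```python
import torch
from torch import nn
import torch.nn.functional as F

class ToyValueResidualAttention(nn.Module):
    # illustrative only: a bare-bones attention block that accepts a value residual

    def __init__(self, dim, heads = 8):
        super().__init__()
        self.heads = heads
        self.to_qkv = nn.Linear(dim, dim * 3, bias = False)
        self.to_out = nn.Linear(dim, dim, bias = False)

    def forward(self, x, value_residual = None):
        b, n, d = x.shape
        q, k, v = self.to_qkv(x).chunk(3, dim = -1)

        # mirror `v = v + value_residual` from the diff: add the first layer's values
        if value_residual is not None:
            v = v + value_residual

        values = v  # returned so the caller can reuse them, like `intermediates.values = v`

        q, k, v = (t.reshape(b, n, self.heads, -1).transpose(1, 2) for t in (q, k, v))
        out = F.scaled_dot_product_attention(q, k, v)
        out = out.transpose(1, 2).reshape(b, n, d)
        return self.to_out(out), values

# the loop mirrors `first_self_attn_inter` / `maybe_value_residual` above
layers = [ToyValueResidualAttention(128) for _ in range(6)]
x = torch.randn(2, 1024, 128)
first_values = None

for layer in layers:
    out, values = layer(x, value_residual = first_values)
    x = x + out
    if first_values is None:
        first_values = values  # values from the first attention block, reused by all later ones
```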

