Skip to content

Commit 14cad7a

Browse files
author
Eron Gjoni
committed
Added mistral support.
Made Dgenerate somewhate better conform to the pecularities of model.generate. Fixed a bug that was handicapping A_dose_theta in llama2 models moved away from the legacy cache format
1 parent d1ebe0a commit 14cad7a

File tree

12 files changed

+418
-100
lines changed

12 files changed

+418
-100
lines changed

base_ref.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,13 @@
11
#This file is just a sober baseline
2-
32
import time
43
import bitsandbytes
54
import sys
65
import torch
76
from transformers import AutoTokenizer, TextStreamer, GenerationConfig
87
from transformers import AutoModelForCausalLM
98

10-
model_id = "NousResearch/Llama-2-7b-chat-hf"
9+
model_id = "cognitivecomputations/dolphin-2.2.1-mistral-7b"
1110
tokenizer = AutoTokenizer.from_pretrained(model_id)
12-
tokenizer.pad_token_id = tokenizer.eos_token_id
1311
model = AutoModelForCausalLM.from_pretrained(
1412
model_id,
1513
device_map="auto",
@@ -39,14 +37,13 @@
3937
input_ids=tokenized_start.to('cuda'),
4038
generation_config=GenerationConfig(
4139
use_cache=True,
42-
min_new_tokens=20,
40+
min_new_tokens=2,
4341
max_new_tokens=500,
4442
temperature=1,
4543
do_sample=False,
46-
pad_token_id=tokenizer.pad_token_id,
4744
eos_token_id=tokenizer.eos_token_id,
4845
return_dict_in_generate=True,
49-
output_hidden_states=True,
46+
output_hidden_states=False,
5047
output_scores = True
5148
),
5249
streamer=streamer,

drugs/__init__.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,10 @@
1010
AutoModelForTokenClassification,
1111
LlamaForCausalLM,
1212
LlamaForSequenceClassification,
13-
LlamaModel
13+
LlamaModel,
14+
MistralForCausalLM,
15+
MistralForSequenceClassification,
16+
MistralModel,
1417
)
1518

1619
"""
@@ -31,9 +34,6 @@
3134
GPTNeoXForTokenClassification,
3235
GPTNeoXModel,
3336
GPTNeoXPreTrainedModel,
34-
MistralForCausalLM,
35-
MistralForSequenceClassification,
36-
MistralModel,
3737
MptForCausalLM,
3838
MptForQuestionAnswering,
3939
MptForSequenceClassification,

drugs/dgenerate.py

Lines changed: 160 additions & 44 deletions
Large diffs are not rendered by default.

drugs/inject_mixin.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
#"mpt": "MptModel",
1818
#"gpt_neox": "GPTNeoXModel",
1919
#"gptj": "GPTJModel",
20-
#"mistral": "MistralModel",
20+
"mistral": "MistralModel",
2121
#"qwen": "QWenModel",
2222
#"stablelm_epoch": "StableLMEpochModel"
2323
}
@@ -29,7 +29,7 @@
2929
#"mpt": "MptAttention",
3030
#"gpt_neox": "GPTNeoXAttention",
3131
#"gptj": "GPTJAttention",
32-
#"mistral": "MistralAttention",
32+
"mistral": "MistralAttention",
3333
#"qwen": "QWenAttention",
3434
#"stablelm_epoch": "Attention",
3535
}
@@ -56,7 +56,7 @@ def _inject_drugged_attention(cls, model: PreTrainedModel, **kwargs) -> Optional
5656
#mpt_drugged_attention_forward,
5757
#gptj_drugged_attention_forward,
5858
llama_drugged_attention_forward,
59-
#mistral_drugged_attention_forward,
59+
mistral_drugged_attention_forward,
6060
#qwen_drugged_attention_forward,
6161
#stablelm_epoch_drugged_attention_forward,
6262
)
@@ -67,14 +67,14 @@ def _inject_drugged_attention(cls, model: PreTrainedModel, **kwargs) -> Optional
6767
#"mpt": None,
6868
#"gpt_neox": gpt_neox_drugged_attention_forward,
6969
#"gptj": gptj_drugged_attention_forward,
70-
#"mistral": mistral_drugged_attention_forward,
70+
"mistral": mistral_drugged_attention_forward,
7171
#"qwen": qwen_drugged_attention_forward,
7272
#"stablelm_epoch": stablelm_epoch_drugged_attention_forward,
7373
}
7474

7575

7676
# Not all models require updated attention forwards
77-
if ATTENTION_FORWARD_MAPPING[model_type] is None:
77+
if model_type not in ATTENTION_FORWARD_MAPPING or ATTENTION_FORWARD_MAPPING[model_type] is None:
7878
return
7979

8080
#TODO: support flash attention 2 and sdpa

drugs/models/__init__.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,13 +37,13 @@
3737
LlamaModel,
3838
llama_drugged_attention_forward
3939
)
40-
"""from .mistral import (
40+
from .mistral import (
4141
MistralForCausalLM,
4242
MistralForSequenceClassification,
4343
MistralModel,
4444
mistral_drugged_attention_forward,
4545
)
46-
from .mpt import (
46+
"""from .mpt import (
4747
MptAttention,
4848
MptForCausalLM,
4949
MptForQuestionAnswering,

drugs/models/llama/drugged.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -53,8 +53,9 @@ def llama_drugged_attention_forward(
5353

5454
"""applying noise before rotatry embeddings because that just feels right (expecially with RoPE)
5555
and not touching the attention sink because that just feels wrong"""
56-
if self.quelude_theta > 0:
57-
query_states[:,:,6:,:] = get_perturbed_vectors(query_states[:,:,6:,:], self.quelude_theta)
56+
sink_protect = (position_ids < 6).sum().item()
57+
if self.quaalude_theta > 0:
58+
query_states[:,:,sink_protect:, :] = get_perturbed_vectors(query_states[:,:,sink_protect:,:], self.quaalude_theta)
5859

5960
kv_seq_len = key_states.shape[-2]
6061
if past_key_value is not None:
@@ -77,6 +78,7 @@ def llama_drugged_attention_forward(
7778
dkey_states = key_states
7879
dvalue_states = value_states
7980

81+
"""sliced at 6th sequence vector to not touch any attention sinks. bad juju"""
8082
if self.ketamine_theta > 0:
8183
dkey_states = torch.clone(key_states)
8284
dkey_states[:,:,6:,:] = get_perturbed_vectors(key_states[:,:,6:,:], self.ketamine_theta)
@@ -105,17 +107,16 @@ def llama_drugged_attention_forward(
105107
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
106108
attn_output = torch.matmul(attn_weights, dvalue_states)
107109

110+
if self.adderall_theta > 0:
111+
attn_output[:,:,sink_protect:, :] = get_perturbed_vectors(attn_output[:,:,sink_protect:, :], self.adderall_theta)
112+
108113
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
109114
raise ValueError(
110115
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
111116
f" {attn_output.size()}"
112117
)
113118

114119
attn_output = attn_output.transpose(1, 2).contiguous()
115-
116-
if self.adderall_theta > 0:
117-
attn_output[:,:,6:,:] = get_perturbed_vectors(attn_output[:,:,6:,:], self.adderall_theta)
118-
119120
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
120121

121122
if self.config.pretraining_tp > 1:

drugs/models/llama/modeling_llama.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,6 @@
1010
import torch
1111
from typing import Optional, Tuple, List, Union
1212

13-
#TODO: self.model._modules['layers']._modules exposes decoder layers, maybe better way to access layer idx?
14-
1513
class LlamaPreTrainedModel(InjectDrugsMixin, TLlamaPreTrainedModel):
1614
pass
1715

drugs/models/mistral/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
from .modeling_mistral import (
2+
MistralForCausalLM,
3+
MistralForSequenceClassification,
4+
MistralModel,
5+
)
6+
from .drugged import mistral_drugged_attention_forward

drugs/models/mistral/drugged.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import math
2+
from typing import Optional, Tuple
3+
import warnings
4+
5+
import torch
6+
import torch.utils.checkpoint
7+
from torch import nn
8+
from transformers.models.mistral.modeling_mistral import repeat_kv, rotate_half, apply_rotary_pos_emb
9+
from drugs.generation.utils import get_perturbed_vectors
10+
from transformers.cache_utils import Cache
11+
12+
__all__ = ["mistral_drugged_attention_forward"]
13+
14+
15+
def mistral_drugged_attention_forward(
16+
self,
17+
hidden_states: torch.Tensor,
18+
attention_mask: Optional[torch.Tensor] = None,
19+
position_ids: Optional[torch.LongTensor] = None,
20+
past_key_value: Optional[Cache] = None,
21+
output_attentions: bool = False,
22+
use_cache: bool = False,
23+
padding_mask: Optional[torch.Tensor] = None,
24+
**kwargs,
25+
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
26+
if "padding_mask" in kwargs:
27+
warnings.warn(
28+
"Passing `padding_mask` is deprecated and will be removed in v4.37. Please make sure use `attention_mask` instead.`"
29+
)
30+
bsz, q_len, _ = hidden_states.size()
31+
32+
query_states = self.q_proj(hidden_states)
33+
key_states = self.k_proj(hidden_states)
34+
value_states = self.v_proj(hidden_states)
35+
36+
query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
37+
key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
38+
value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2)
39+
40+
sink_protect = (position_ids < 6).sum().item()
41+
if self.quaalude_theta > 0:
42+
query_states[:,:,sink_protect:, :] = get_perturbed_vectors(query_states[:,:,sink_protect:,:], self.quaalude_theta)
43+
44+
kv_seq_len = key_states.shape[-2]
45+
if past_key_value is not None:
46+
if self.layer_idx is None:
47+
raise ValueError(
48+
f"The cache structure has changed since version v4.36. If you are using {self.__class__.__name__} "
49+
"for auto-regressive decoding with k/v caching, please make sure to initialize the attention class "
50+
"with a layer index."
51+
)
52+
kv_seq_len += past_key_value.get_usable_length(kv_seq_len, self.layer_idx)
53+
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
54+
query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
55+
56+
if past_key_value is not None:
57+
cache_kwargs = {"sin": sin, "cos": cos} # Specific to RoPE models
58+
key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs)
59+
60+
# repeat k/v heads if n_kv_heads < n_heads
61+
key_states = repeat_kv(key_states, self.num_key_value_groups)
62+
value_states = repeat_kv(value_states, self.num_key_value_groups)
63+
dkey_states = key_states
64+
dvalue_states = value_states
65+
66+
"""sliced at 6th sequence vector to not touch any attention sinks. bad juju"""
67+
if self.ketamine_theta > 0:
68+
dkey_states = torch.clone(key_states)
69+
dkey_states[:,:,6:,:] = get_perturbed_vectors(key_states[:,:,6:,:], self.ketamine_theta)
70+
71+
if self.valium_theta > 0:
72+
dvalue_states = torch.clone(value_states)
73+
dvalue_states[:,:,6:,:] = get_perturbed_vectors(value_states[:,:,6:,:], self.valium_theta)
74+
75+
attn_weights = torch.matmul(query_states, dkey_states.transpose(2, 3)) / math.sqrt(self.head_dim)
76+
77+
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
78+
raise ValueError(
79+
f"Attention weights should be of size {(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
80+
f" {attn_weights.size()}"
81+
)
82+
83+
if attention_mask is not None:
84+
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
85+
raise ValueError(
86+
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
87+
)
88+
89+
attn_weights = attn_weights + attention_mask
90+
91+
# upcast attention to fp32
92+
attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
93+
attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training)
94+
attn_output = torch.matmul(attn_weights, value_states)
95+
96+
if self.adderall_theta > 0:
97+
attn_output[:,:,sink_protect:, :] = get_perturbed_vectors(attn_output[:,:,sink_protect:, :], self.adderall_theta)
98+
99+
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
100+
raise ValueError(
101+
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
102+
f" {attn_output.size()}"
103+
)
104+
105+
attn_output = attn_output.transpose(1, 2).contiguous()
106+
107+
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
108+
109+
attn_output = self.o_proj(attn_output)
110+
111+
if not output_attentions:
112+
attn_weights = None
113+
114+
return attn_output, attn_weights, past_key_value
Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,22 @@
1+
from transformers import MistralForCausalLM as TMistralForCausalLM
2+
from transformers import MistralForSequenceClassification as TMistralForSequenceClassification
3+
from transformers import MistralModel as TMistralModel
4+
from transformers import MistralPreTrainedModel as TMistralPreTrainedModel
5+
6+
from drugs.inject_mixin import InjectDrugsMixin
7+
8+
9+
class MistralPreTrainedModel(InjectDrugsMixin, TMistralPreTrainedModel):
10+
pass
11+
12+
13+
class MistralModel(MistralPreTrainedModel, TMistralModel):
14+
pass
15+
16+
17+
class MistralForCausalLM(MistralPreTrainedModel, TMistralForCausalLM):
18+
pass
19+
20+
21+
class MistralForSequenceClassification(MistralPreTrainedModel, TMistralForSequenceClassification):
22+
pass

0 commit comments

Comments
 (0)