diff --git a/paddleformers/transformers/cache_utils.py b/paddleformers/transformers/cache_utils.py
index fec0c1e6758..67e5daddbbd 100644
--- a/paddleformers/transformers/cache_utils.py
+++ b/paddleformers/transformers/cache_utils.py
@@ -94,12 +94,15 @@ class DynamicLayer(CacheLayerMixin):
 
     is_sliding = False
 
-    def lazy_initialization(self, key_states: paddle.Tensor):
+    def lazy_initialization(self, key_states: paddle.Tensor, value_states: paddle.Tensor):
         self.dtype, self.place = key_states.dtype, key_states.place
-        B, N, _, H = key_states.shape
-        initial_shape = [B, N, 0, H]
-        self.keys = paddle.empty(initial_shape, dtype=self.dtype, device=self.place)
-        self.values = paddle.empty(initial_shape, dtype=self.dtype, device=self.place)
+        B, N, _, H_k = key_states.shape
+        _, _, _, H_v = value_states.shape
+        initial_keys_shape = [B, N, 0, H_k]
+        initial_values_shape = [B, N, 0, H_v]
+
+        self.keys = paddle.empty(initial_keys_shape, dtype=self.dtype, device=self.place)
+        self.values = paddle.empty(initial_values_shape, dtype=self.dtype, device=self.place)
         self.is_initialized = True
 
     def update(
@@ -121,7 +124,7 @@ def update(
         """
         # Lazy initialization
         if not self.is_initialized:
-            self.lazy_initialization(key_states)
+            self.lazy_initialization(key_states, value_states)
         # the shape of the key and value states is [B,N,S,H].
         self.keys = paddle.concat([self.keys, key_states], axis=-2)
         self.values = paddle.concat([self.values, value_states], axis=-2)
@@ -432,7 +435,7 @@ class DynamicCache(Cache):
        >>> model = AutoModelForCausalLM.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
        >>> tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
 
-        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pt")
+        >>> inputs = tokenizer(text="My name is Qwen2", return_tensors="pd")
 
         >>> # Prepare a cache class and pass it to model's forward
         >>> past_key_values = DynamicCache(config=model.config)
@@ -521,8 +524,8 @@ def __init__(self, sliding_window: int):
         self.cumulative_length = 0
         self._sliding_window_tensor = paddle.to_tensor(self.sliding_window, dtype=paddle.int64)
 
-    def lazy_initialization(self, key_states: paddle.Tensor) -> None:
-        super().lazy_initialization(key_states)
+    def lazy_initialization(self, key_states: paddle.Tensor, value_states: paddle.Tensor) -> None:
+        super().lazy_initialization(key_states, value_states)
         self._sliding_window_tensor = self._sliding_window_tensor.to(self.place)
 
     def update(
@@ -544,7 +547,7 @@ def update(
         """
         # Lazy initialization
         if not self.is_initialized:
-            self.lazy_initialization(key_states)
+            self.lazy_initialization(key_states, value_states)
 
         self.cumulative_length += key_states.shape[-2]
 
diff --git a/paddleformers/transformers/configuration_utils.py b/paddleformers/transformers/configuration_utils.py
index 4a658e9af97..abca14b2385 100644
--- a/paddleformers/transformers/configuration_utils.py
+++ b/paddleformers/transformers/configuration_utils.py
@@ -606,7 +606,8 @@ def __init__(self, **kwargs):
         self.output_hidden_states = kwargs.pop("output_hidden_states", False)
         self.output_attentions = kwargs.pop("output_attentions", False)
         self.dtype = kwargs.pop("dtype", None)
-        self.use_cache = kwargs.pop("use_cache", False)
+        default_cache = getattr(self, "use_cache", False)
+        self.use_cache = kwargs.pop("use_cache", default_cache)
         self.tie_word_embeddings = kwargs.pop("tie_word_embeddings", True)
 
         # for transformers fuse
diff --git a/paddleformers/transformers/deepseek_v3/modeling.py b/paddleformers/transformers/deepseek_v3/modeling.py
index 40cd34f355f..4c58c59290e 100644
--- a/paddleformers/transformers/deepseek_v3/modeling.py
+++ b/paddleformers/transformers/deepseek_v3/modeling.py
@@ -554,9 +554,10 @@ def forward(self, hidden_states):
 class DeepseekV3Attention(nn.Layer):
     """Multi-headed attention from 'Attention Is All You Need' paper"""
 
-    def __init__(self, config: DeepseekV3Config):
+    def __init__(self, config: DeepseekV3Config, layer_idx: int):
         super().__init__()
         self.config = config
+        self.layer_idx = layer_idx
         self.attention_dropout = config.attention_dropout
         self.hidden_size = config.hidden_size
         self.num_heads = config.num_attention_heads
@@ -825,7 +826,7 @@ def __init__(self, config: DeepseekV3Config, layer_idx: int):
         self.sequence_parallel = config.sequence_parallel
         self.hidden_size = config.hidden_size
 
-        self.self_attn = DeepseekV3Attention(config=config)
+        self.self_attn = DeepseekV3Attention(config=config, layer_idx=layer_idx)
 
         try:
             moe_group = fleet.get_hybrid_communicate_group().get_expert_parallel_group()
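Note on the cache_utils.py change: lazy_initialization now also receives value_states so the key and value caches can be allocated with different head dimensions, which MLA-style attention such as DeepseekV3 produces (the RoPE-extended key head dim is larger than the value head dim). The sketch below only illustrates that intent; the import path, the no-argument DynamicLayer() constructor, and update(key_states, value_states) returning the concatenated caches are assumptions based on the hunks above, not verified API.

# Sketch only: exercises DynamicLayer.update() with mismatched key/value head
# dims, which the new lazy_initialization(key_states, value_states) permits.
# The import path, the bare DynamicLayer() constructor, and update() returning
# (keys, values) are assumptions drawn from the diff, not verified API.
import paddle

from paddleformers.transformers.cache_utils import DynamicLayer  # assumed path

layer = DynamicLayer()

B, N = 2, 16          # batch size, number of KV heads
H_K, H_V = 192, 128   # key head dim != value head dim (MLA-style)

# Prefill with 8 tokens: keys and values get independently shaped caches.
k = paddle.randn([B, N, 8, H_K], dtype="float32")
v = paddle.randn([B, N, 8, H_V], dtype="float32")
keys, values = layer.update(k, v)
assert keys.shape == [B, N, 8, H_K] and values.shape == [B, N, 8, H_V]

# Decode one more token: both caches grow along the sequence axis only.
keys, values = layer.update(
    paddle.randn([B, N, 1, H_K], dtype="float32"),
    paddle.randn([B, N, 1, H_V], dtype="float32"),
)
assert keys.shape == [B, N, 9, H_K] and values.shape == [B, N, 9, H_V]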
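Note on the configuration_utils.py change: a use_cache default that a config subclass already defines (for example as a class attribute) is now preserved instead of being silently overridden with False; only an explicit kwarg changes it. Below is a minimal self-contained sketch of the pattern, using hypothetical BaseConfigSketch and MyConfig classes rather than the real PretrainedConfig.

# Sketch of the getattr-based default: a class-level use_cache defined by a
# config subclass survives __init__ instead of being forced to False.
# BaseConfigSketch and MyConfig are hypothetical stand-ins for illustration.
class BaseConfigSketch:
    def __init__(self, **kwargs):
        # Old: self.use_cache = kwargs.pop("use_cache", False)
        # New: fall back to any default already set on the (sub)class.
        default_cache = getattr(self, "use_cache", False)
        self.use_cache = kwargs.pop("use_cache", default_cache)


class MyConfig(BaseConfigSketch):
    use_cache = True  # class-level default that should be preserved


assert MyConfig().use_cache is True                   # default preserved
assert MyConfig(use_cache=False).use_cache is False   # explicit kwarg still wins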