From fdc87dfaba22c7b14177329257ff2899c5d820c5 Mon Sep 17 00:00:00 2001
From: xiangyuT
Date: Thu, 4 Jul 2024 10:03:46 +0800
Subject: [PATCH] format

---
 .../transformers/pipeline_parallel.py         | 42 +------------------
 1 file changed, 1 insertion(+), 41 deletions(-)

diff --git a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
index 9b066235cd9..7ba4cea1ae4 100644
--- a/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
+++ b/python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -468,13 +468,12 @@ def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
                                                        kv_cache_2.value_cache[layer_idx]], dim=0)
         return kv_cache_1
 
-
     def update_kv_cache(self, kv_cache, cur_id):
         layer_start = self.model.layer_start
         layer_end = self.model.layer_end
         num_layers = self.model.num_layers
-
+
         if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
             # for glm-4-9b-chat
             if self.past_key_values_dict.get(cur_id, None) is None:
@@ -492,7 +491,6 @@ def update_kv_cache(self, kv_cache, cur_id):
             kv_cache = tuple((value_placeholder, value_placeholder)) + \
                 tuple(None for _ in range(layer_start)) + \
                 (kv_cache)[layer_start:]
-
             # past_key_values_placeholder = tuple(
             #     (value_placeholder, value_placeholder) for _ in range(layer_start)
             # ) + (kv_cache)[layer_start:]
@@ -502,7 +500,6 @@ def update_kv_cache(self, kv_cache, cur_id):
 
         return kv_cache
 
-
     @torch.no_grad()
     def model_step(self, input, cur_batch):
         if cur_batch is None or cur_batch.stopped or input is None:
@@ -554,24 +551,6 @@ def model_step(self, input, cur_batch):
             if cur_batch.prefilled_index == cur_batch.batch_size:
                 tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id)
-                # # TODO: remove reduntent code here
-                # if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
-                #     # for glm-4-9b-chat
-                #     if self.past_key_values_dict.get(cur_id, None) is None:
-                #         value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
-                #         past_key_values_placeholder = tuple(
-                #             (value_placeholder, value_placeholder) for _ in range(layer_start)
-                #         ) + (output.past_key_values)[: layer_end - layer_start] + tuple(
-                #             (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
-                #         )
-                #         _past_key_values = past_key_values_placeholder
-                #     else:
-                #         _past_key_values = output.past_key_values
-                # elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
-                #     value_placeholder = torch.empty_like((tmp_past_key_values)[-1][0])
-                #     tmp_past_key_values = tuple((value_placeholder, value_placeholder)) + \
-                #         tuple(None for _ in range(layer_start)) + \
-                #         (tmp_past_key_values)[layer_start:]
 
             self.past_key_values_dict[cur_id] = tmp_past_key_values
@@ -585,25 +564,6 @@ def model_step(self, input, cur_batch):
                     _pre_output = torch.cat((_pre_output, tmp_output), dim=0)
                 self.partial_output_dict[cur_id] = _pre_output
             else:
-                # if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
-                #     # for glm-4-9b-chat
-                #     if self.past_key_values_dict.get(cur_id, None) is None:
-                #         value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
-                #         past_key_values_placeholder = tuple(
-                #             (value_placeholder, value_placeholder) for _ in range(layer_start)
-                #         ) + (output.past_key_values)[: layer_end - layer_start] + tuple(
-                #             (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
-                #         )
-                #         _past_key_values = past_key_values_placeholder
-                #     else:
-                #         _past_key_values = output.past_key_values
-                # elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
-                #     # for baichuan2 and chatglm3
-                #     value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
-                #     past_key_values_placeholder = tuple(
-                #         (value_placeholder, value_placeholder) for _ in range(layer_start)
-                #     ) + (output.past_key_values)[layer_start:]
-                #     _past_key_values = past_key_values_placeholder
                 _past_key_values = self.update_kv_cache(_past_key_values, cur_id)
             self.past_key_values_dict[cur_id] = _past_key_values
         torch.xpu.synchronize()
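
Both deleted comment blocks duplicated the placeholder construction that `update_kv_cache()` now owns, so the prefill and decode paths collapse to a single helper call. For context, here is a minimal, self-contained sketch of that placeholder pattern; `build_placeholder_kv` is a hypothetical stand-in for illustration only and is not part of this patch or of ipex_llm:

```python
import torch

def build_placeholder_kv(kv_cache, layer_start):
    # Mirrors the deleted decode-path branch: a pipeline rank that does not
    # own the first `layer_start` layers pads those slots with empty tensors
    # shaped like a real (key, value) pair, so the tuple keeps full length.
    value_placeholder = torch.empty_like(kv_cache[-1][0])
    return tuple(
        (value_placeholder, value_placeholder) for _ in range(layer_start)
    ) + tuple(kv_cache)[layer_start:]

# Tiny usage check with a fake two-layer KV cache.
kv = tuple((torch.zeros(1, 2, 4), torch.zeros(1, 2, 4)) for _ in range(2))
padded = build_placeholder_kv(kv, layer_start=1)
assert len(padded) == 2 and padded[1] is kv[1]
```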