
Commit

format
xiangyuT committed Jul 4, 2024
1 parent cf1e9e2 · commit fdc87df
Showing 1 changed file with 1 addition and 41 deletions.
42 changes: 1 addition & 41 deletions python/llm/src/ipex_llm/transformers/pipeline_parallel.py
@@ -468,13 +468,12 @@ def cat_kv_cache(self, model_type, kv_cache_1, kv_cache_2):
kv_cache_2.value_cache[layer_idx]], dim=0)

return kv_cache_1


def update_kv_cache(self, kv_cache, cur_id):
layer_start = self.model.layer_start
layer_end = self.model.layer_end
num_layers = self.model.num_layers

if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
# for glm-4-9b-chat
if self.past_key_values_dict.get(cur_id, None) is None:
@@ -492,7 +491,6 @@ def update_kv_cache(self, kv_cache, cur_id):
kv_cache = tuple((value_placeholder, value_placeholder)) + \
tuple(None for _ in range(layer_start)) + \
(kv_cache)[layer_start:]

# past_key_values_placeholder = tuple(
# (value_placeholder, value_placeholder) for _ in range(layer_start)
# ) + (kv_cache)[layer_start:]
@@ -502,7 +500,6 @@ def update_kv_cache(self, kv_cache, cur_id):

return kv_cache


@torch.no_grad()
def model_step(self, input, cur_batch):
if cur_batch is None or cur_batch.stopped or input is None:
@@ -554,24 +551,6 @@ def model_step(self, input, cur_batch):

if cur_batch.prefilled_index == cur_batch.batch_size:
tmp_past_key_values = self.update_kv_cache(tmp_past_key_values, cur_id)
# # TODO: remove redundant code here
# if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
# # for glm-4-9b-chat
# if self.past_key_values_dict.get(cur_id, None) is None:
# value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
# past_key_values_placeholder = tuple(
# (value_placeholder, value_placeholder) for _ in range(layer_start)
# ) + (output.past_key_values)[: layer_end - layer_start] + tuple(
# (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
# )
# _past_key_values = past_key_values_placeholder
# else:
# _past_key_values = output.past_key_values
# elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
# value_placeholder = torch.empty_like((tmp_past_key_values)[-1][0])
# tmp_past_key_values = tuple((value_placeholder, value_placeholder)) + \
# tuple(None for _ in range(layer_start)) + \
# (tmp_past_key_values)[layer_start:]

self.past_key_values_dict[cur_id] = tmp_past_key_values

@@ -585,25 +564,6 @@ def model_step(self, input, cur_batch):
_pre_output = torch.cat((_pre_output, tmp_output), dim=0)
self.partial_output_dict[cur_id] = _pre_output
else:
# if self.model.config.model_type == "chatglm" and self.model.config.num_layers == 40:
# # for glm-4-9b-chat
# if self.past_key_values_dict.get(cur_id, None) is None:
# value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
# past_key_values_placeholder = tuple(
# (value_placeholder, value_placeholder) for _ in range(layer_start)
# ) + (output.past_key_values)[: layer_end - layer_start] + tuple(
# (value_placeholder, value_placeholder) for _ in range(layer_end, num_layers)
# )
# _past_key_values = past_key_values_placeholder
# else:
# _past_key_values = output.past_key_values
# elif self.model.config.model_type in ["baichuan", "chatglm"] and self.rank > 0:
# # for baichuan2 and chatglm3
# value_placeholder = torch.empty_like((output.past_key_values)[-1][0])
# past_key_values_placeholder = tuple(
# (value_placeholder, value_placeholder) for _ in range(layer_start)
# ) + (output.past_key_values)[layer_start:]
# _past_key_values = past_key_values_placeholder
_past_key_values = self.update_kv_cache(_past_key_values, cur_id)
self.past_key_values_dict[cur_id] = _past_key_values
torch.xpu.synchronize()
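For reference, the per-layer placeholder pattern that this commit consolidates into update_kv_cache (and removes as duplicated commented-out code from model_step) looks roughly like the sketch below. This is a simplified, illustrative reconstruction based only on the code visible in this diff; the function name is hypothetical, and the real logic lives in update_kv_cache above.

    import torch

    def pad_kv_cache_with_placeholders(kv_cache, layer_start):
        # Illustrative sketch, not the ipex_llm implementation: a pipeline
        # stage that only owns layers from `layer_start` onward still returns
        # a KV cache with one (key, value) pair per model layer, so the
        # leading slots are filled with empty tensors of matching shape.
        value_placeholder = torch.empty_like(kv_cache[-1][0])
        return tuple(
            (value_placeholder, value_placeholder) for _ in range(layer_start)
        ) + tuple(kv_cache)[layer_start:]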
