add batch based on xiangyu's commit
cyita committed Jan 31, 2024
1 parent c80d318 commit cddcc76
Showing 1 changed file with 35 additions and 14 deletions.
49 changes: 35 additions & 14 deletions python/llm/src/bigdl/llm/transformers/speculative.py
@@ -58,7 +58,7 @@ def generate(
for var in ['max_new_tokens', 'max_step_draft', 'th_stop_draft', 'do_sample',
'top_k', 'top_p', 'temperature', 'hf_adjust',
'auto_th_stop_draft', 'auto_parameters', 'repetition_penalty',
'attention_mask']:
'attention_mask', 'pad_token_id']:
value = kwargs.pop(var, None)
if value is not None:
new_speculative_kwargs[var] = value
@@ -116,6 +116,12 @@ def clear_benchmarks(self):
self.n_matched = 0


def _update_attention_mask(attention_mask, len=1):
return torch.cat(
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], len))],
dim=-1)


def _prepare_past_key_values_storage_cpu(self, past_key_values,
max_new_tokens, _enable_ipex=False):
past_key_values_storage = []
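
Note: the new _update_attention_mask helper simply pads the attention mask on the right with ones so it keeps pace with freshly generated tokens. A standalone sketch of its effect on a toy, left-padded batch (shapes and values are illustrative only):

    import torch

    # Illustrative only: a left-padded attention mask for a batch of 2 sequences
    # (0 marks padding, 1 marks real tokens).
    mask = torch.tensor([[0, 1, 1],
                         [1, 1, 1]])

    # The helper appends `len` columns of ones; this reproduces its effect
    # for one newly generated token.
    extended = torch.cat([mask, mask.new_ones((mask.shape[0], 1))], dim=-1)
    # extended: tensor([[0, 1, 1, 1],
    #                   [1, 1, 1, 1]])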
@@ -297,7 +303,22 @@ def speculative_generate(self,
batch_size = inputs_tensor.shape[0]

# 4. Define other model kwargs
# Removed not used
model_kwargs["output_attentions"] = generation_config.output_attentions
model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
# decoder-only models with inputs_embeds forwarding must use caching (otherwise we can't detect whether we are
# generating the first new token or not, and we only want to use the embeddings for the first new token)
if not self.config.is_encoder_decoder and model_input_name == "inputs_embeds":
model_kwargs["use_cache"] = True
else:
model_kwargs["use_cache"] = generation_config.use_cache

accepts_attention_mask = "attention_mask" in set(inspect.signature(self.forward).parameters.keys())
requires_attention_mask = "encoder_outputs" not in model_kwargs

if model_kwargs.get("attention_mask", None) is None and requires_attention_mask and accepts_attention_mask:
model_kwargs["attention_mask"] = self._prepare_attention_mask_for_generation(
inputs_tensor, generation_config.pad_token_id, generation_config.eos_token_id
)

# decoder-only models should use left-padding for generation
if not self.config.is_encoder_decoder:
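
Note: forwarding pad_token_id matters for batched inputs because the generation path above now derives an attention mask from it when none is supplied, via transformers' _prepare_attention_mask_for_generation. A rough, simplified sketch of what that mask looks like (illustrative only, not the exact transformers implementation):

    import torch

    def build_attention_mask(input_ids, pad_token_id):
        # 1 for real tokens, 0 for padding positions.
        return (input_ids != pad_token_id).long()

    # Left-padded batch; pad_token_id is assumed to be 0 here.
    ids = torch.tensor([[0, 0, 5, 6],
                        [3, 4, 5, 6]])
    build_attention_mask(ids, pad_token_id=0)
    # tensor([[0, 0, 1, 1],
    #         [1, 1, 1, 1]])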
@@ -368,6 +389,7 @@ def speculative_generate(self,
e2e_tic = 0.0

self.clear_benchmarks()
attention_mask = model_kwargs["attention_mask"]

# Example:
# Target model forward for the first token
@@ -398,9 +420,10 @@ def speculative_generate(self,
output_ids = sample(logits, do_sample=generation_config.do_sample,
top_k=generation_config.top_k, top_p=generation_config.top_p,
temperature=generation_config.temperature)
generate_ids[:, step] = output_ids
generate_ids[:, step] = output_ids.squeeze()
current_input_ids = output_ids
past_key_values = output['past_key_values']
attention_mask = _update_attention_mask(attention_mask, 1)
step += 1
if self.device.type == 'xpu':
torch.xpu.synchronize()
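
Note: with a batch dimension, sample(...) is assumed to return one token id per sequence with shape (batch_size, 1), while generate_ids[:, step] expects a 1-D tensor of length batch_size; that is what the added .squeeze() reconciles. A toy illustration:

    import torch

    batch_size, max_len = 2, 8
    generate_ids = torch.zeros(batch_size, max_len, dtype=torch.long)

    # Assumed shape (batch_size, 1): one sampled token per sequence.
    output_ids = torch.tensor([[11],
                               [42]])

    # Drop the trailing dimension so it matches the assignment target,
    # which has shape (batch_size,).
    generate_ids[:, 0] = output_ids.squeeze()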
@@ -424,17 +447,16 @@
original_draft_past_key_values = draft_past_key_values
else:
draft_past_key_values = past_key_values
draft_generate_ids[:, 0] = current_input_ids
draft_generate_ids[:, 0] = current_input_ids.squeeze()
tic = time.time()
# Draft model auto-regressively generate k tokens
# Early stop when prob is less than th_stop_draft
for step_draft in range(max_step_draft):
if attention_mask is None:
draft_attention_mask = None
else:
appended_len = step_draft + step
ones_to_append = torch.ones(attention_mask.size(0), appended_len)
draft_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
draft_attention_mask = _update_attention_mask(attention_mask,
step_draft + step - 1)
if self.config.model_type == "chatglm":
past_key_value_len = past_key_values[0][0].shape[0]
position_ids = torch.Tensor([[past_key_value_len + step_draft]]).long()
Expand Down Expand Up @@ -462,12 +484,12 @@ def speculative_generate(self,
top_k=generation_config.top_k,
top_p=generation_config.top_p,
temperature=generation_config.temperature)
draft_generate_ids[:, step_draft+1] = draft_output_ids
draft_generate_ids[:, step_draft+1] = draft_output_ids.squeeze()
draft_current_input_ids = draft_output_ids
draft_past_key_values = draft_output['past_key_values']
# check if draft prob is less than th_stop_draft
# Draft number + step >= max output token number
if draft_output_probs.item() < th_stop_draft or \
if min(draft_output_probs).item() < th_stop_draft or \
step + step_draft + 2 >= max_new_tokens:
break
if self.device.type == 'xpu':
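
Note: draft_output_probs now carries one probability per sequence in the batch, so the early-stop test takes the batch minimum: drafting stops for the whole batch as soon as the least confident sequence drops below th_stop_draft. A toy illustration of the check (values are made up):

    import torch

    th_stop_draft = 0.5
    # Assumed shape (batch_size,): the draft model's top-token probability
    # for each sequence at this draft step.
    draft_output_probs = torch.tensor([0.9, 0.4, 0.8])

    # The least confident sequence (0.4) is below the threshold, so the
    # whole batch stops drafting here.
    should_stop = min(draft_output_probs).item() < th_stop_draft   # True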
@@ -486,9 +508,8 @@
if attention_mask is None:
cur_attention_mask = None
else:
appended_len = drafted_input_ids.size(1) + step - 1
ones_to_append = torch.ones(attention_mask.size(0), appended_len)
cur_attention_mask = torch.cat((attention_mask, ones_to_append), dim=1)
cur_attention_mask = _update_attention_mask(attention_mask,
drafted_input_ids.size(1) + step - 2)
if _enable_ipex and hasattr(self, "trace_graph"):
if self.config.model_type == "baichuan":
output = self.trace_graph(input_ids=drafted_input_ids,
@@ -546,8 +567,8 @@ def speculative_generate(self,
# Drafts start from [1, k]
# Verified output start from [0, k - 1]
# including the one generated by the base model
max_matched = ((output_ids[:, :-1] != drafted_input_ids[:, 1:]).cumsum(-1) == 0)
max_matched = max_matched.sum(-1).item() + 1
max_matched = min(((output_ids[:, :-1] != drafted_input_ids[:, 1:]).cumsum(-1) == 0)
.sum(-1)).item() + 1
max_of_max_matched = output_ids.size(1)
# Accept number is max_matched, min is 1
self.accept_num.append(max_matched)
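
Note: with batching, the number of matched draft tokens is computed per sequence and then reduced with min, so every sequence in the batch accepts the same (smallest) matched prefix plus the one token produced by the target model itself. A worked toy example (token ids are made up):

    import torch

    # drafted_input_ids holds the last verified token followed by k draft tokens;
    # output_ids holds the target model's k + 1 verified outputs.
    drafted_input_ids = torch.tensor([[7, 1, 2, 3, 4],
                                      [9, 5, 6, 6, 8]])
    output_ids = torch.tensor([[1, 2, 3, 9, 9],
                               [5, 6, 7, 7, 7]])

    # Per-sequence length of the matching prefix between drafts [1:] and
    # verified outputs [:-1], then the batch-wide minimum (+1 for the token
    # generated by the target model).
    matches = ((output_ids[:, :-1] != drafted_input_ids[:, 1:]).cumsum(-1) == 0).sum(-1)
    max_matched = min(matches).item() + 1
    # matches == tensor([3, 2]) here, so max_matched == 3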
