fix according to reviewer comments
lvhan028 committed Sep 24, 2024
1 parent d5c5f39 commit db15e41
Showing 7 changed files with 99 additions and 210 deletions.
8 changes: 2 additions & 6 deletions docs/en/advance/long_context.md
@@ -96,11 +96,9 @@ This test takes approximately 364 seconds per round when conducted on A100-80G G
The following codes demonstrate how to use LMDeploy to calculate perplexity.

```python
from transformers import AutoTokenizer
from lmdeploy import TurbomindEngineConfig, pipeline
import numpy as np

# load model and tokenizer
# build pipeline
model_repoid_or_path = 'internlm/internlm2_5-7b-chat-1m'
backend_config = TurbomindEngineConfig(
rope_scaling_factor=2.5,
@@ -109,11 +107,9 @@ backend_config = TurbomindEngineConfig(
cache_max_entry_count=0.7,
tp=4)
pipe = pipeline(model_repoid_or_path, backend_config=backend_config)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)

# get perplexity
text = 'Use a long prompt to replace this sentence'
input_ids = tokenizer.encode(text)
ppl = pipe.get_ppl(input_ids)[0]
ppl = pipe.get_ppl(text)
print(ppl)
```
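The pipeline's `get_ppl` also accepts a batch of texts, as the example in `docs/en/llm/pipeline.md` below shows. A minimal sketch reusing the pipeline built above (the prompt strings are placeholders; one score is returned per input):

```python
# hedged sketch: batched perplexity with the pipeline built above;
# the texts are placeholders, and one score comes back per input
texts = [
    'Use a long prompt to replace this sentence',
    'Use another long prompt to replace this sentence'
]
ppls = pipe.get_ppl(texts)
for text, ppl in zip(texts, ppls):
    print(f'{ppl:.4f}  {text[:40]}')
```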
16 changes: 8 additions & 8 deletions docs/en/llm/pipeline.md
@@ -119,21 +119,21 @@ for item in pipe.stream_infer(prompts, gen_config=gen_config):
- **An example to calculate logits & ppl:**

```python
from transformers import AutoTokenizer
from lmdeploy import pipeline

model_repoid_or_path='internlm/internlm2_5-7b-chat'
pipe = pipeline(model_repoid_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)

# logits
messages = [
{"role": "user", "content": "Hello, how are you?"},
prompts = [
"Hello, I am an AI assistant named InternLM. I am developed by Shanghai AI Laboratory",
"How to use LMDeploy to deploy a LLM model?"
]
input_ids = tokenizer.apply_chat_template(messages)
logits = pipe.get_logits(input_ids)

# logits
logits = pipe.get_logits(prompts)

# ppl
ppl = pipe.get_ppl(input_ids)
ppl = pipe.get_ppl(prompts)
```
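Judging from the relocated implementation in `lmdeploy/serve/utils.py` further down, `get_ppl` returns a numpy array with one token-averaged cross-entropy value per prompt. A hedged sketch of consuming that output:

```python
# hedged sketch: interpreting the get_ppl output; assumes `ppl` holds one
# token-averaged cross-entropy value per prompt, as computed in serve/utils.py
import math

for prompt, loss in zip(prompts, ppl):
    print(f'avg loss {loss:.4f}, perplexity {math.exp(loss):.2f} | {prompt[:40]}')
```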

- **Below is an example for pytorch backend. Please install triton first.**
8 changes: 2 additions & 6 deletions docs/zh_cn/advance/long_context.md
@@ -96,11 +96,9 @@ passkey_retrieval(session_len, 5)
The following shows how to use LMDeploy to calculate perplexity.

```python
from transformers import AutoTokenizer
from lmdeploy import TurbomindEngineConfig, pipeline
import numpy as np

# load model and tokenizer
# build pipeline
model_repoid_or_path = 'internlm/internlm2_5-7b-chat-1m'
backend_config = TurbomindEngineConfig(
rope_scaling_factor=2.5,
@@ -109,11 +107,9 @@ backend_config = TurbomindEngineConfig(
cache_max_entry_count=0.7,
tp=4)
pipe = pipeline(model_repoid_or_path, backend_config=backend_config)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)

# get perplexity
text = 'Use a long prompt to replace this sentence'
input_ids = tokenizer.encode(text)
loss = pipe.get_ppl(input_ids)[0]
ppl = pipe.get_ppl(text)
print(ppl)
```
16 changes: 8 additions & 8 deletions docs/zh_cn/llm/pipeline.md
@@ -119,21 +119,21 @@ for item in pipe.stream_infer(prompts, gen_config=gen_config):
- **Calculate logits & ppl:**

```python
from transformers import AutoTokenizer
from lmdeploy import pipeline

model_repoid_or_path='internlm/internlm2_5-7b-chat'
pipe = pipeline(model_repoid_or_path)
tokenizer = AutoTokenizer.from_pretrained(model_repoid_or_path, trust_remote_code=True)

# logits
messages = [
{"role": "user", "content": "Hello, how are you?"},
prompts = [
"Hello, I am an AI assistant named InternLM. I am developed by Shanghai AI Laboratory",
"How to use LMDeploy to deploy a LLM model?"
]
input_ids = tokenizer.apply_chat_template(messages)
logits = pipe.get_logits(input_ids)

# logits
logits = pipe.get_logits(prompts)

# ppl
ppl = pipe.get_ppl(input_ids)
ppl = pipe.get_ppl(prompts)
```

- **An example for the pytorch backend**
99 changes: 1 addition & 98 deletions lmdeploy/pytorch/engine/engine_instance.py
@@ -1,8 +1,5 @@
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List, Union

import torch
from torch.nn.utils.rnn import pad_sequence
from typing import List

from lmdeploy.messages import EngineOutput, GenerationConfig
from lmdeploy.utils import get_logger
@@ -586,97 +583,3 @@ def __add_messages(session_ids, input_ids, adapter_names,
self.end(sid)

return ret

def get_ppl(self, input_ids: Union[List[int], List[List[int]]]):
"""Get perplexity scores given a list of input tokens.
Args:
input_ids (Union[List[int], List[List[int]]]): the batch of
input token ids
"""
assert isinstance(input_ids, List) and len(input_ids) > 0
if isinstance(input_ids[0], int):
input_ids = [input_ids]
assert all(len(_) > 1 for _ in input_ids)

def get_logits(input_ids,
sequence_start,
sequence_end,
pre_iter_logits=None):
logits = self.decode(input_ids=input_ids,
sequence_start=sequence_start,
sequence_end=sequence_end)
logits = logits.float().cpu()
padding_token_id = -100
if pre_iter_logits is None:
_logits = logits
target_ids = [(x + [padding_token_id])[1:] for x in input_ids]
target_ids = [
torch.Tensor(torch.LongTensor(_target_ids))
for _target_ids in target_ids
]
else:
# concat the logit of the last token in previous prefill iter,
# and shift the logit of the last token in this iter
_logits = torch.concat((pre_iter_logits[..., -1:, :], logits),
dim=1)
_logits = _logits[..., :-1, :]
target_ids = [
torch.Tensor(torch.LongTensor(_target_ids))
for _target_ids in input_ids
]
target_ids = pad_sequence(target_ids,
batch_first=True,
padding_value=padding_token_id)
target_ids = target_ids.to(logits.device)
target_mask = target_ids != padding_token_id

# compute cross entropy loss
bsz, seq_len, vocab_size = logits.shape
flat_logits = _logits.contiguous().view(-1, vocab_size)
flat_target_ids = target_ids.contiguous().view(-1)
flat_loss_matrix = torch.nn.functional.cross_entropy(
flat_logits,
flat_target_ids,
reduction='none',
ignore_index=padding_token_id)
return logits, flat_loss_matrix.view(bsz, seq_len), target_mask

bs = len(input_ids)
max_seq_len = max([len(input_id) for input_id in input_ids])

# TODO: a better way to determine `max_input_len`
# At most allocate 2G mem for logits with shape [bs, seq, vocab_size]
vocab_size = self.engine.model_config.vocab_size
max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)

all_loss_matrix = []
all_target_mask = []
# the 1st prefill iter
_input_ids = [input_id[0:max_input_len] for input_id in input_ids]
logits, loss_matrix, target_mask = get_logits(input_ids=_input_ids,
sequence_start=True,
sequence_end=False)
all_loss_matrix.append(loss_matrix)
all_target_mask.append(target_mask)

# the following prefill iters
for i in range(max_input_len, max_seq_len, max_input_len):
_input_ids = [
input_id[i:i + max_input_len] for input_id in input_ids
]
logits, loss_matrix, target_mask = get_logits(
input_ids=_input_ids,
sequence_start=False,
sequence_end=(i + max_input_len >= max_seq_len),
pre_iter_logits=logits)
all_loss_matrix.append(loss_matrix)
all_target_mask.append(target_mask)

all_loss_matrix = torch.cat(all_loss_matrix, dim=1)
all_target_mask = torch.cat(all_target_mask, dim=1)
target_count = torch.sum(all_target_mask, dim=-1)
loss_sum = torch.sum(all_loss_matrix * all_target_mask, dim=1)
loss_avg = loss_sum / target_count
loss_avg = loss_avg.cpu().numpy()
return loss_avg
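The chunked-prefill pattern deleted here (decode the prompt in `max_input_len`-sized slices, flagging only the first and last slice as sequence start/end) reappears in `lmdeploy/serve/utils.py` below. A condensed sketch of just the slicing, with illustrative values in place of the real budget-derived `max_input_len`:

```python
# illustrative values only; the real max_input_len is derived from a
# ~2 GB logits budget as in the surrounding code
input_ids = [list(range(10)), list(range(9))]
max_input_len = 4
max_seq_len = max(len(x) for x in input_ids)
for i in range(0, max_seq_len, max_input_len):
    chunk = [x[i:i + max_input_len] for x in input_ids]
    sequence_start = (i == 0)
    sequence_end = (i + max_input_len >= max_seq_len)
    print(i, sequence_start, sequence_end, chunk)
```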
86 changes: 77 additions & 9 deletions lmdeploy/serve/utils.py
@@ -4,6 +4,7 @@

import numpy as np
import torch
from torch.nn.utils.rnn import pad_sequence

from lmdeploy.utils import get_logger

@@ -64,7 +65,7 @@ def prepare_inputs(self, prompts: Union[PromptType, List[PromptType]]):

def get_logits(
self,
input_ids: Union[InputIdsType, List[InputIdsType]],
inputs: Union[str, List[str]],
input_embeddings: Union[InputEmbsType, List[InputEmbsType]] = None,
input_embedding_ranges: Union[InputEmbRngsType,
List[InputEmbRngsType]] = None):
@@ -74,13 +75,17 @@ def get_logits(
inputs (Union[str, List[str]]): the batch of
    input texts
"""
assert len(input_ids) > 0
if isinstance(input_ids[0], int):
input_ids = [input_ids]
for input_id in input_ids:
assert len(input_id) > 0
if isinstance(inputs, str):
inputs = [inputs]
assert all(len(_) > 0 for _ in inputs)

input_ids = [self.tokenizer.encode(text) for text in inputs]
bs = len(input_ids)
# TODO: a better way to determine `max_input_len`, at most allocate
# 2G mem for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.hf_tm_cfg.vocab_size
max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)

max_input_len = self.backend_config.max_prefill_token_num
n_max_iter = np.ceil(
max([len(input_id)
for input_id in input_ids]) / max_input_len).astype(int)
@@ -183,6 +188,69 @@ def get_ppl(self, inputs: List[str]) -> List[float]:
"""
if isinstance(inputs, str):
inputs = [inputs]
input_ids = [self.tokenizer.encode(text) for text in inputs]
assert all(len(_) > 0 for _ in inputs)

generator = self.engine.create_instance()
return generator.get_ppl(input_ids)
input_ids = [self.tokenizer.encode(text) for text in inputs]

bs = len(input_ids)
max_seq_len = len(input_ids[0])

# TODO: a better way to determine `max_input_len`, at most allocate
# 2G mem for logits with shape [bs, max_input_len, vocab_size]
vocab_size = self.hf_tm_cfg.vocab_size
max_input_len = 2 * 1024**3 // (bs * vocab_size * 4)

all_loss_matrix = []
all_target_mask = []
for i in range(0, max_seq_len, max_input_len):
token_ids = [
input_id[i:i + max_input_len] for input_id in input_ids
]
steps = [i] * bs
logits = generator.decode(
token_ids,
steps=steps,
sequence_start=(i == 0),
sequence_end=(i + max_input_len >= max_seq_len))
bsz, seq_len, vocab_size = logits.shape
logits = logits.float().cpu()
padding_token_id = -100
# meaning logits[..., :, :] corresponds to labels
# token_ids[1:] + predict_token_id, which is
# input_ids[:, i+max_input_len:i+max_input_len+1]
target_ids = [
input_id[i + 1:i + 1 + max_input_len] for input_id in input_ids
]
if len(target_ids[0]) < len(token_ids[0]):
target_ids = [x + [padding_token_id] for x in target_ids]
target_ids = [
torch.Tensor(torch.LongTensor(_target_ids))
for _target_ids in target_ids
]
target_ids = pad_sequence(target_ids,
batch_first=True,
padding_value=padding_token_id)
target_ids = target_ids.to(logits.device)
target_mask = target_ids != padding_token_id

# compute cross entropy loss
flat_logits = logits.contiguous().view(-1, vocab_size)
flat_target_ids = target_ids.contiguous().view(-1)
flat_loss_matrix = torch.nn.functional.cross_entropy(
flat_logits,
flat_target_ids,
reduction='none',
ignore_index=padding_token_id)

all_loss_matrix.append(flat_loss_matrix.view(bsz, seq_len))
all_target_mask.append(target_mask)

all_loss_matrix = torch.cat(all_loss_matrix, dim=1)
all_target_mask = torch.cat(all_target_mask, dim=1)
target_count = torch.sum(all_target_mask, dim=-1)
loss_sum = torch.sum(all_loss_matrix * all_target_mask, dim=1)
loss_avg = loss_sum / target_count
loss_avg = loss_avg.cpu().numpy()

return loss_avg
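Net effect: for each input sequence the relocated `get_ppl` returns the token-averaged cross-entropy accumulated over the chunked prefill passes. With $m_{i,t}$ the padding mask and $y_{i,t}$ the shifted target token from the code above (notation mine, not the source's):

$$
\mathrm{loss\_avg}_i = \frac{\sum_{t} m_{i,t}\,\mathrm{CE}\big(\mathrm{logits}_{i,t},\, y_{i,t}\big)}{\sum_{t} m_{i,t}}
$$

Exponentiating this average gives perplexity in the conventional sense; the function itself returns the average loss.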