From 20c50d02d511d5a15b87dd253905dbee8f22c3e1 Mon Sep 17 00:00:00 2001
From: MrZhengXin <510546841@qq.com>
Date: Wed, 25 Dec 2024 16:05:11 +0800
Subject: [PATCH] add `logit_bias` parameter

---
 src/llamafactory/api/chat.py         | 2 ++
 src/llamafactory/api/protocol.py     | 1 +
 src/llamafactory/chat/hf_engine.py   | 8 ++++++++
 src/llamafactory/chat/vllm_engine.py | 8 ++++++++
 4 files changed, 19 insertions(+)

diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index c467a3e6b2..b60cc23b29 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -155,6 +155,7 @@ async def create_chat_completion_response(
         max_new_tokens=request.max_tokens,
         num_return_sequences=request.n,
         stop=request.stop,
+        logit_bias=request.logit_bias,
     )
 
     prompt_length, response_length = 0, 0
@@ -214,6 +215,7 @@ async def create_stream_chat_completion_response(
         top_p=request.top_p,
         max_new_tokens=request.max_tokens,
         stop=request.stop,
+        logit_bias=request.logit_bias,
     ):
         if len(new_token) != 0:
             yield _create_stream_chat_completion_chunk(
diff --git a/src/llamafactory/api/protocol.py b/src/llamafactory/api/protocol.py
index c6fe6f757b..a9d05f079b 100644
--- a/src/llamafactory/api/protocol.py
+++ b/src/llamafactory/api/protocol.py
@@ -103,6 +103,7 @@ class ChatCompletionRequest(BaseModel):
     max_tokens: Optional[int] = None
     stop: Optional[Union[str, List[str]]] = None
     stream: bool = False
+    logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = None
 
 
 class ChatCompletionResponseChoice(BaseModel):
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index d001386bc6..9fa77def43 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -117,9 +117,16 @@ def _process_args(
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = input_kwargs.pop("logit_bias", None)
 
         if stop is not None:
             logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")
+
+        if logit_bias is not None:
+            logit_bias = {
+                int(token): bias
+                for token, bias in logit_bias.items()
+            }
 
         generating_args = generating_args.copy()
         generating_args.update(
@@ -135,6 +142,7 @@ def _process_args(
                 length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                 eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                 pad_token_id=tokenizer.pad_token_id,
+                logit_bias=logit_bias,
             )
         )
 
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index fd54b5a90d..00c7623305 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -140,6 +140,7 @@ async def _generate(
         max_length: Optional[int] = input_kwargs.pop("max_length", None)
         max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
+        logit_bias: Optional[Union[Dict[int, float], Dict[str, float]]] = input_kwargs.pop("logit_bias", None)
 
         if length_penalty is not None:
             logger.warning_rank0("Length penalty is not supported by the vllm engine yet.")
@@ -158,6 +159,12 @@ async def _generate(
         if max_new_tokens:
             max_tokens = max_new_tokens
 
+        if logit_bias:
+            logit_bias = {
+                int(token): bias
+                for token, bias in logit_bias.items()
+            }
+
         sampling_params = SamplingParams(
             n=num_return_sequences,
             repetition_penalty=(
@@ -171,6 +178,7 @@ async def _generate(
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
             skip_special_tokens=self.generating_args["skip_special_tokens"],
+            logit_bias=logit_bias,
         )
 
         if images is not None:  # add image features
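
Reviewer note (not part of the patch): the sketch below shows how a client might exercise the new `logit_bias` field end to end. It assumes the API server is reachable at http://localhost:8000 via the OpenAI-compatible /v1/chat/completions route; the model name and token IDs are placeholders. Because the protocol accepts both Dict[int, float] and Dict[str, float] and both engines coerce keys with int(token), JSON object keys (which are always strings) work directly.

# Minimal client-side sketch for the new `logit_bias` request field.
# Assumptions (not established by this patch): server URL/route, model name,
# and the token IDs used below are placeholders for illustration only.
import requests

payload = {
    "model": "placeholder-model",
    "messages": [{"role": "user", "content": "Answer yes or no."}],
    "max_tokens": 8,
    # token_id -> additive bias, following the OpenAI-style logit_bias convention;
    # string keys are converted to int by hf_engine and vllm_engine in this patch.
    "logit_bias": {"9891": 5.0, "2201": -5.0},
}

resp = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"]["content"])

In the patch, this mapping is read from ChatCompletionRequest in chat.py and popped from input_kwargs in each engine, where it is forwarded to the generation arguments (hf_engine) or to SamplingParams (vllm_engine) after key coercion.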