From 4dd51594ee8782e07967f1b9850ff8fdbca068f8 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Wed, 8 May 2024 13:27:49 +0800 Subject: [PATCH 01/37] add the types for code completions --- xinference/types.py | 15 ++++++++++++++- 1 file changed, 14 insertions(+), 1 deletion(-) diff --git a/xinference/types.py b/xinference/types.py index 53473f26f2..e9c71becc2 100644 --- a/xinference/types.py +++ b/xinference/types.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Dict, ForwardRef, Iterable, List, Optional, Union -from typing_extensions import Literal, NotRequired, TypedDict +from typing_extensions import Literal, Mapping, NotRequired, TypedDict from ._compat import ( BaseModel, @@ -507,3 +507,16 @@ def from_dict(cls, data: Dict): image_lora_load_kwargs=data.get("image_lora_load_kwargs"), image_lora_fuse_kwargs=data.get("image_lora_fuse_kwargs"), ) + + +class CreateCodeCompletion(CreateCompletion): + mode: Literal["completion"] = "completion" + + +class CreateCodeInFill(CreateCompletion): + mode: Literal["infill"] = "infill" + + +class CreateCodeRepoLevelCompletion(CreateCompletion): + mode: Literal["repo-level-completion"] = "repo-level-completion" + files: Mapping[str, str] From 2d893b3fcbb2bb88ecea9eb036c8b5e6a342a5ba Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Wed, 8 May 2024 17:00:49 +0800 Subject: [PATCH 02/37] add prompt style definition for code completions. --- xinference/model/llm/llm_family.py | 20 ++++++++++++++++++++ 1 file changed, 20 insertions(+) diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index e7b8561d3e..2ce7a12f83 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -109,6 +109,25 @@ class PromptStyleV1(BaseModel): stop_token_ids: Optional[List[int]] +class FIMSpecV1(BaseModel): + style: Literal["PSM", "PMS"] + prefix: str + middle: str + suffix: str + + +class RepoLevelCodeCompletionSpecV1(BaseModel): + repo_name: Optional[str] + file_type: Literal["filepath", "filename"] + file_separator: str + + +class CodePromptStyleV1(BaseModel): + style_name: str + fim_spec: Optional[FIMSpecV1] + repo_level_spec: Optional[RepoLevelCodeCompletionSpecV1] + + class LLMFamilyV1(BaseModel): version: Literal[1] context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH @@ -120,6 +139,7 @@ class LLMFamilyV1(BaseModel): model_family: Optional[str] model_specs: List["LLMSpecV1"] prompt_style: Optional["PromptStyleV1"] + code_prompt_style: Optional[CodePromptStyleV1] class CustomLLMFamilyV1(LLMFamilyV1): From 81a77cbe67a290b9a57dde12836cb9bc4c4d8384 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Wed, 8 May 2024 18:38:48 +0800 Subject: [PATCH 03/37] add code completion mixin. 
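CodeModelMixin.get_code_prompt assembles the final prompt from a
CodePromptStyleV1. A minimal, illustrative sketch of the fill-in-the-middle
path added in this patch is shown below; the <PRE>/<SUF>/<MID> tokens and the
"EXAMPLE" style name are placeholders, not the tokens of any real model
(each model declares its own fim_spec in llm_family.json).

    from xinference.model.llm.llm_family import CodePromptStyleV1, FIMSpecV1
    from xinference.model.llm.utils import CodeModelMixin

    # Placeholder FIM tokens; real models register their own fim_spec.
    style = CodePromptStyleV1(
        style_name="EXAMPLE",
        fim_spec=FIMSpecV1(style="PSM", prefix="<PRE>", middle="<MID>", suffix="<SUF>"),
    )

    # Complete the body of a function given the code before and after the gap.
    prompt = CodeModelMixin.get_code_prompt(
        "fim",                     # generate_type
        "def add(a, b):\n    ",    # code before the cursor
        style,
        "\n    return result",     # code after the cursor
        None,                      # repo_name (repo-level completion only)
        None,                      # files (repo-level completion only)
    )
    # PSM ordering produces:
    # "<PRE>def add(a, b):\n    <SUF>\n    return result<MID>"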
--- xinference/model/llm/llm_family.json | 11 +++- xinference/model/llm/utils.py | 87 +++++++++++++++++++++++++++- 2 files changed, 96 insertions(+), 2 deletions(-) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 082fa33c12..57fdfdd03c 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -2291,7 +2291,16 @@ "model_id": "TheBloke/starcoder-GGML", "model_file_name_template": "starcoder.ggmlv3.{quantization}.bin" } - ] + ], + "code_prompt_style": { + "style_name": "STARCODER", + "fim_spec": { + "style": "PMS", + "prefix": "", + "middle": "", + "suffix": "" + } + } }, { "version": 1, diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 51d8354570..9683ccb88f 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -17,7 +17,17 @@ import os import time import uuid -from typing import AsyncGenerator, Dict, Iterator, List, Optional, Tuple, cast +from typing import ( + AsyncGenerator, + Dict, + Iterator, + List, + Literal, + Mapping, + Optional, + Tuple, + cast, +) from ...types import ( SPECIAL_TOOL_PROMPT, @@ -28,6 +38,7 @@ CompletionChunk, ) from .llm_family import ( + CodePromptStyleV1, GgmlLLMSpecV1, LLMFamilyV1, LLMSpecV1, @@ -714,6 +725,80 @@ def _tool_calls_completion(cls, model_family, model_uid, c, tools): } +class CodeModelMixin: + @staticmethod + def get_code_prompt( + generate_type: Literal["completion", "fim"], + prompt: str, + code_prompt_style: CodePromptStyleV1, + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], + ) -> str: + if generate_type == "completion": + if suffix is not None: + logger.warning( + "Suffix is only required on generate type is infill, ignored" + ) + + if files is None or len(files) == 0: + return prompt + + spec = code_prompt_style.repo_level_spec + if spec is None: + raise ValueError( + "The model does not support repository level code completion" + ) + + ret_val = "" + spec = code_prompt_style.repo_level_spec + if spec.repo_name is None: + if repo_name is not None: + logger.warning( + "repo_name is provided but it will not be used in this model, ignored" + ) + else: + if repo_name is None: + raise ValueError( + "The repo_name is required for repository level code completion for this model" + ) + ret_val = f"{spec.repo_name}{repo_name}\n" + + for filepath, content in files.items(): + repo_file = ( + filepath + if spec.file_type == "filepath" + else CodeModelMixin._path_to_name(filepath) + ) + ret_val += f"{spec.file_separator}{repo_file}\n{content}\n" + return ret_val + prompt + + elif generate_type == "fim": + spec = code_prompt_style.fim_spec + if spec is None: + raise ValueError("This model is not support FIM mode generate") + + if suffix is None: + raise ValueError("suffix is required in FIM mode") + + if files is not None and len(files) > 0: + logger.warning( + "files is only required in repository level code completion, ignored" + ) + + if spec.style == "PSM": + return f"{spec.prefix}{prompt}{spec.suffix}{suffix}{spec.middle}" + else: + return f"{spec.prefix}{prompt}{spec.middle}{suffix}{spec.suffix}" + + else: + raise ValueError(f"Unsupported generate type: {generate_type}") + + @staticmethod + def _path_to_name(filepath: str) -> str: + return os.path.split(filepath)[1] + + def get_file_location( llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str ) -> Tuple[str, bool]: From b3c0e3abf2b1c352b4afcf976f6e1f9b0babccef Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Thu, 
9 May 2024 15:36:07 +0800 Subject: [PATCH 04/37] refactor the code prompt style and add the unit test for code prompt style check. --- xinference/api/restful_api.py | 99 +++++++ xinference/core/model.py | 331 ++++++++++++++--------- xinference/core/supervisor.py | 11 + xinference/model/llm/__init__.py | 25 +- xinference/model/llm/ggml/llamacpp.py | 58 +++- xinference/model/llm/llm_family.json | 127 ++++++++- xinference/model/llm/llm_family.py | 4 +- xinference/model/llm/pytorch/core.py | 63 ++++- xinference/model/llm/tests/test_utils.py | 234 +++++++++++++++- xinference/model/llm/utils.py | 53 ++-- xinference/model/llm/vllm/core.py | 61 ++++- xinference/types.py | 14 +- 12 files changed, 896 insertions(+), 184 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index f627521dc1..c01bce70cf 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -62,6 +62,7 @@ ChatCompletionMessage, Completion, CreateChatCompletion, + CreateCodeCompletion, CreateCompletion, ImageList, PeftModelConfig, @@ -229,6 +230,9 @@ def serve(self, logging_conf: Optional[dict] = None): self._router.add_api_route( "/v1/models/prompts", self._get_builtin_prompts, methods=["GET"] ) + self._router.add_api_route( + "/v1/models/code_prompts", self._get_builtin_code_prompts, methods=["GET"] + ) self._router.add_api_route( "/v1/models/families", self._get_builtin_families, methods=["GET"] ) @@ -452,6 +456,18 @@ def serve(self, logging_conf: Optional[dict] = None): ), ) + self._router.add_api_route( + "v1/code/completions", + self.create_code_completion, + methods=["POST"], + response_model=Completion, + dependencies=( + [Security(self._auth_service, scopes=["models:read"])] + if self.is_authenticated() + else None + ), + ) + # for custom models self._router.add_api_route( "/v1/model_registrations/{model_type}", @@ -575,6 +591,18 @@ async def _get_builtin_prompts(self) -> JSONResponse: logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) + async def _get_builtin_code_prompts(self) -> JSONResponse: + """ + For internal usage + :return: + """ + try: + data = await (await self._get_supervisor_ref()).get_builtin_code_prompts() + return JSONResponse(content=data) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + async def _get_builtin_families(self) -> JSONResponse: """ For internal usage @@ -1392,6 +1420,77 @@ async def stream_results(): self.handle_request_limit_error(e) raise HTTPException(status_code=500, detail=str(e)) + async def create_code_completion(self, request: Request) -> Response: + json_data = await request.json() + if "mode" not in json_data: + raise HTTPException( + status_code=400, detail="mode is required in code completion request" + ) + + if json_data["mode"] not in ("completion", "infill", "repo-level-completion"): + raise HTTPException( + status_code=400, + detail="mode must be one of 'completion', 'infill' or 'repo-level-completion", + ) + + if json_data.get("stream", False): + json_data["stream"] = False + + body = CreateCodeCompletion.parse_obj(json_data) + exclude = { + "generate_mode", + "prompt", + "suffix", + "repo_name", + "files", + "model", + "n", + "messages", + "logit_bias", + "logit_bias_type", + "user", + } + + kwargs = body.dict(exclude_unset=True, exclude=exclude) + + # TODO: Decide if this default value override is necessary #1061 + if body.max_tokens is None: + kwargs["max_tokens"] = max_tokens_field.default + + if body.logit_bias is not 
None: + raise HTTPException(status_code=501, detail="Not implemented") + + model_uid = body.model + + try: + model = await (await self._get_supervisor_ref()).get_model(model_uid) + except ValueError as ve: + logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) + raise HTTPException(status_code=500, detail=str(e)) + + assert not body.stream + + try: + data = await model.code_generate( + body.generate_mode, + body.prompt, + body.suffix, + body.repo_name, + body.files, + kwargs, + ) + return Response(content=data, media_type="application/json") + except Exception as e: + logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) + self.handle_request_limit_error(e) + raise HTTPException(status_code=500, detail=str(e)) + async def query_engines_by_model_name(self, model_name: str) -> JSONResponse: try: content = await ( diff --git a/xinference/core/model.py b/xinference/core/model.py index f31989a9c4..ed79058644 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -28,6 +28,7 @@ Generator, Iterator, List, + Mapping, Optional, Union, ) @@ -35,6 +36,8 @@ import sse_starlette.sse import xoscar as xo +from ..types import CodeGenerateMode + if TYPE_CHECKING: from .worker import WorkerActor from ..model.llm.core import LLM @@ -290,9 +293,9 @@ async def _to_json_async_gen(self, gen: types.AsyncGeneratorType): if time_to_first_token is None: time_to_first_token = (time.time() - start_time) * 1000 final_usage = v.pop("usage", None) - v = await asyncio.to_thread(json.dumps, v) + v = await asyncio.to_thread(json.dumps, v) # type: ignore[attr-defined] v = dict(data=v) # noqa: F821 - yield await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None) + yield await asyncio.to_thread(sse_starlette.sse.ensure_bytes, v, None) # type: ignore[attr-defined] except OutOfMemoryError: logger.exception( "Model actor is out of memory, model id: %s", self.model_uid() @@ -324,7 +327,7 @@ async def _call_wrapper(self, fn: Callable, *args, **kwargs): if inspect.iscoroutinefunction(fn): ret = await fn(*args, **kwargs) else: - ret = await asyncio.to_thread(fn, *args, **kwargs) + ret = await asyncio.to_thread(fn, *args, **kwargs) # type: ignore[attr-defined] else: async with self._lock: if inspect.iscoroutinefunction(fn): @@ -343,7 +346,7 @@ async def _call_wrapper(self, fn: Callable, *args, **kwargs): gen = self._to_json_async_gen(ret) self._current_generator = weakref.ref(gen) return gen - return await asyncio.to_thread(json_dumps, ret) + return await asyncio.to_thread(json_dumps, ret) # type: ignore[attr-defined] @log_async(logger=logger) @request_limit @@ -397,143 +400,207 @@ async def chat(self, prompt: str, *args, **kwargs): @log_async(logger=logger) @request_limit - async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs): - if hasattr(self._model, "create_embedding"): - return await self._call_wrapper( - self._model.create_embedding, input, *args, **kwargs - ) - - raise AttributeError( - f"Model {self._model.model_spec} is not for creating embedding." 
- ) - - @log_async(logger=logger) - @request_limit - async def rerank( + @xo.generator + async def code_generate( self, - documents: List[str], - query: str, - top_n: Optional[int], - max_chunks_per_doc: Optional[int], - return_documents: Optional[bool], + generate_mode: CodeGenerateMode, + prompt: str, + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], *args, **kwargs, ): - if hasattr(self._model, "rerank"): - return await self._call_wrapper( - self._model.rerank, - documents, - query, - top_n, - max_chunks_per_doc, - return_documents, - *args, - **kwargs, + start_time = time.time() + response = None + try: + if hasattr(self._model, "code_generate"): + response = await self._call_wrapper( + self._model.code_generate, + generate_mode, + prompt, + suffix, + repo_name, + files, + *args, + **kwargs, + ) + return response + if hasattr(self._model, "async_code_generate"): + response = await self._call_wrapper( + self._model.async_code_generate, + generate_mode, + prompt, + suffix, + repo_name, + files, + *args, + **kwargs, + ) + return response + raise AttributeError( + f"Model {self._model.model_spec} is not for code generate." ) - raise AttributeError(f"Model {self._model.model_spec} is not for reranking.") + finally: + # For the non stream result. + record = None + if isinstance(response, (Generator, AsyncGenerator)): + record = response + elif isinstance(response, bytes): + record = json.loads(response) + if record and isinstance(record, dict): + usage = record["usage"] + # Some backends may not have a valid usage, we just skip them. + completion_tokens = usage["completion_tokens"] + prompt_tokens = usage["prompt_tokens"] + await self._record_completion_metrics( + time.time() - start_time, + completion_tokens, + prompt_tokens, + ) - @log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio")) - @request_limit - async def transcriptions( - self, - audio: bytes, - language: Optional[str] = None, - prompt: Optional[str] = None, - response_format: str = "json", - temperature: float = 0, - timestamp_granularities: Optional[List[str]] = None, - ): - if hasattr(self._model, "transcriptions"): - return await self._call_wrapper( - self._model.transcriptions, - audio, - language, - prompt, - response_format, - temperature, - timestamp_granularities, - ) - raise AttributeError( - f"Model {self._model.model_spec} is not for creating transcriptions." - ) - @log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio")) - @request_limit - async def translations( - self, - audio: bytes, - language: Optional[str] = None, - prompt: Optional[str] = None, - response_format: str = "json", - temperature: float = 0, - timestamp_granularities: Optional[List[str]] = None, - ): - if hasattr(self._model, "translations"): - return await self._call_wrapper( - self._model.translations, - audio, - language, - prompt, - response_format, - temperature, - timestamp_granularities, - ) - raise AttributeError( - f"Model {self._model.model_spec} is not for creating translations." 
+@log_async(logger=logger) +@request_limit +async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs): + if hasattr(self._model, "create_embedding"): + return await self._call_wrapper( + self._model.create_embedding, input, *args, **kwargs ) - @log_async(logger=logger) - @request_limit - async def text_to_image( - self, - prompt: str, - n: int = 1, - size: str = "1024*1024", - response_format: str = "url", - *args, - **kwargs, - ): - if hasattr(self._model, "text_to_image"): - return await self._call_wrapper( - self._model.text_to_image, - prompt, - n, - size, - response_format, - *args, - **kwargs, - ) - raise AttributeError( - f"Model {self._model.model_spec} is not for creating image." + raise AttributeError( + f"Model {self._model.model_spec} is not for creating embedding." + ) + + +@log_async(logger=logger) +@request_limit +async def rerank( + self, + documents: List[str], + query: str, + top_n: Optional[int], + max_chunks_per_doc: Optional[int], + return_documents: Optional[bool], + *args, + **kwargs, +): + if hasattr(self._model, "rerank"): + return await self._call_wrapper( + self._model.rerank, + documents, + query, + top_n, + max_chunks_per_doc, + return_documents, + *args, + **kwargs, ) - - async def image_to_image( - self, - image: "PIL.Image", - prompt: str, - negative_prompt: str, - n: int = 1, - size: str = "1024*1024", - response_format: str = "url", - *args, - **kwargs, - ): - if hasattr(self._model, "image_to_image"): - return await self._call_wrapper( - self._model.image_to_image, - image, - prompt, - negative_prompt, - n, - size, - response_format, - *args, - **kwargs, - ) - raise AttributeError( - f"Model {self._model.model_spec} is not for creating image." + raise AttributeError(f"Model {self._model.model_spec} is not for reranking.") + + +@log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio")) +@request_limit +async def transcriptions( + self, + audio: bytes, + language: Optional[str] = None, + prompt: Optional[str] = None, + response_format: str = "json", + temperature: float = 0, + timestamp_granularities: Optional[List[str]] = None, +): + if hasattr(self._model, "transcriptions"): + return await self._call_wrapper( + self._model.transcriptions, + audio, + language, + prompt, + response_format, + temperature, + timestamp_granularities, ) + raise AttributeError( + f"Model {self._model.model_spec} is not for creating transcriptions." + ) + + +@log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio")) +@request_limit +async def translations( + self, + audio: bytes, + language: Optional[str] = None, + prompt: Optional[str] = None, + response_format: str = "json", + temperature: float = 0, + timestamp_granularities: Optional[List[str]] = None, +): + if hasattr(self._model, "translations"): + return await self._call_wrapper( + self._model.translations, + audio, + language, + prompt, + response_format, + temperature, + timestamp_granularities, + ) + raise AttributeError( + f"Model {self._model.model_spec} is not for creating translations." 
+ ) + + +@log_async(logger=logger) +@request_limit +async def text_to_image( + self, + prompt: str, + n: int = 1, + size: str = "1024*1024", + response_format: str = "url", + *args, + **kwargs, +): + if hasattr(self._model, "text_to_image"): + return await self._call_wrapper( + self._model.text_to_image, + prompt, + n, + size, + response_format, + *args, + **kwargs, + ) + raise AttributeError(f"Model {self._model.model_spec} is not for creating image.") + + +async def image_to_image( + self, + image: "PIL.Image", + prompt: str, + negative_prompt: str, + n: int = 1, + size: str = "1024*1024", + response_format: str = "url", + *args, + **kwargs, +): + if hasattr(self._model, "image_to_image"): + return await self._call_wrapper( + self._model.image_to_image, + image, + prompt, + negative_prompt, + n, + size, + response_format, + *args, + **kwargs, + ) + raise AttributeError(f"Model {self._model.model_spec} is not for creating image.") + - async def record_metrics(self, name, op, kwargs): - worker_ref = await self._get_worker_ref() - await worker_ref.record_metrics(name, op, kwargs) +async def record_metrics(self, name, op, kwargs): + worker_ref = await self._get_worker_ref() + await worker_ref.record_metrics(name, op, kwargs) diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 6352e0aee2..724f1ff713 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -273,10 +273,20 @@ async def get_builtin_prompts() -> Dict[str, Any]: data[k] = v.dict() return data + @staticmethod + async def get_builtin_code_prompts() -> Dict[str, Any]: + from ..model.llm.llm_family import BUILTIN_LLM_CODE_PROMPT_STYLE + + data = {} + for k, v in BUILTIN_LLM_CODE_PROMPT_STYLE.items(): + data[k] = v.dict() + return data + @staticmethod async def get_builtin_families() -> Dict[str, List[str]]: from ..model.llm.llm_family import ( BUILTIN_LLM_MODEL_CHAT_FAMILIES, + BUILTIN_LLM_MODEL_CODE_FAMILIES, BUILTIN_LLM_MODEL_GENERATE_FAMILIES, BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES, ) @@ -285,6 +295,7 @@ async def get_builtin_families() -> Dict[str, List[str]]: "chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES), "generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES), "tools": list(BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES), + "code": list(BUILTIN_LLM_MODEL_CODE_FAMILIES), } async def get_devices_count(self) -> int: diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index 5541a07762..d741d6d489 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -25,8 +25,10 @@ get_llm_model_descriptions, ) from .llm_family import ( + BUILTIN_LLM_CODE_PROMPT_STYLE, BUILTIN_LLM_FAMILIES, BUILTIN_LLM_MODEL_CHAT_FAMILIES, + BUILTIN_LLM_MODEL_CODE_FAMILIES, BUILTIN_LLM_MODEL_GENERATE_FAMILIES, BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES, BUILTIN_LLM_PROMPT_STYLE, @@ -37,6 +39,7 @@ SGLANG_CLASSES, SUPPORTED_ENGINES, VLLM_CLASSES, + CodePromptStyleV1, CustomLLMFamilyV1, GgmlLLMSpecV1, LLMFamilyV1, @@ -97,10 +100,10 @@ def generate_engine_config_by_model_family(model_family): def _install(): from .ggml.chatglm import ChatglmCppChatModel - from .ggml.llamacpp import LlamaCppChatModel, LlamaCppModel + from .ggml.llamacpp import LlamaCppChatModel, LlamaCppCodeModel, LlamaCppModel from .pytorch.baichuan import BaichuanPytorchChatModel from .pytorch.chatglm import ChatglmPytorchChatModel - from .pytorch.core import PytorchChatModel, PytorchModel + from .pytorch.core import PytorchChatModel, PytorchCodeModel, PytorchModel from .pytorch.deepseek_vl import 
DeepSeekVLChatModel from .pytorch.falcon import FalconPytorchChatModel, FalconPytorchModel from .pytorch.internlm2 import Internlm2PytorchChatModel @@ -109,7 +112,7 @@ def _install(): from .pytorch.vicuna import VicunaPytorchChatModel from .pytorch.yi_vl import YiVLChatModel from .sglang.core import SGLANGChatModel, SGLANGModel - from .vllm.core import VLLMChatModel, VLLMModel + from .vllm.core import VLLMChatModel, VLLMCodeModel, VLLMModel try: from .pytorch.omnilmm import OmniLMMModel @@ -124,11 +127,12 @@ def _install(): [ ChatglmCppChatModel, LlamaCppChatModel, + LlamaCppCodeModel, LlamaCppModel, ] ) SGLANG_CLASSES.extend([SGLANGModel, SGLANGChatModel]) - VLLM_CLASSES.extend([VLLMModel, VLLMChatModel]) + VLLM_CLASSES.extend([VLLMModel, VLLMChatModel, VLLMCodeModel]) PYTORCH_CLASSES.extend( [ BaichuanPytorchChatModel, @@ -138,6 +142,7 @@ def _install(): LlamaPytorchModel, LlamaPytorchChatModel, PytorchChatModel, + PytorchCodeModel, FalconPytorchModel, Internlm2PytorchChatModel, QwenVLChatModel, @@ -177,6 +182,16 @@ def _install(): if "tools" in model_spec.model_ability: BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name) + if "code" in model_spec.model_ability and isinstance( + model_spec.code_prompt_style, CodePromptStyleV1 + ): + BUILTIN_LLM_CODE_PROMPT_STYLE[ + model_spec.model_name + ] = model_spec.code_prompt_style + + if "code" in model_spec.model_ability: + BUILTIN_LLM_MODEL_CODE_FAMILIES.add(model_spec.model_name) + modelscope_json_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "llm_family_modelscope.json" ) @@ -199,6 +214,8 @@ def _install(): BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name) if "tools" in model_spec.model_ability: BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name) + if "code" in model_spec.model_ability: + BUILTIN_LLM_MODEL_CODE_FAMILIES.add(model_spec.model_name) for llm_specs in [BUILTIN_LLM_FAMILIES, BUILTIN_MODELSCOPE_LLM_FAMILIES]: for llm_spec in llm_specs: diff --git a/xinference/model/llm/ggml/llamacpp.py b/xinference/model/llm/ggml/llamacpp.py index 3725c7fdbd..787bd1c3f8 100644 --- a/xinference/model/llm/ggml/llamacpp.py +++ b/xinference/model/llm/ggml/llamacpp.py @@ -14,12 +14,13 @@ import datetime import logging import os -from typing import Iterable, Iterator, List, Optional, Union +from typing import Iterable, Iterator, List, Mapping, Optional, Union from ....types import ( ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, + CodeGenerateMode, Completion, CompletionChunk, CreateCompletionLlamaCpp, @@ -29,7 +30,7 @@ ) from ..core import LLM from ..llm_family import LLMFamilyV1, LLMSpecV1 -from ..utils import ChatModelMixin +from ..utils import ChatModelMixin, CodeModelMixin logger = logging.getLogger(__name__) @@ -302,3 +303,56 @@ def chat( self.model_family, self.model_uid, c, tools ) return self._to_chat_completion(c) + + +class LlamaCppCodeModel(LlamaCppModel, CodeModelMixin): + def __init__( + self, + model_uid: str, + model_family: "LLMFamilyV1", + model_spec: "LLMSpecV1", + quantization: str, + model_path: str, + llamacpp_model_config: Optional[LlamaCppModelConfig] = None, + ): + super().__init__( + model_uid, + model_family, + model_spec, + quantization, + model_path, + llamacpp_model_config, + ) + + @classmethod + def match( + cls, llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str + ) -> bool: + if llm_spec.model_format not in ["ggmlv3", "ggufv2"]: + return False + if "chatglm" in llm_family.model_name: + return False + return "code" not in 
llm_family.model_ability + + def code_generate( + self, + generate_model: CodeGenerateMode, + prompt: str, + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], + generate_config: Optional[LlamaCppGenerateConfig] = None, + ): + code_prompt = self.get_code_prompt( + generate_model, + prompt, + self.model_spec.code_prompt_style, + suffix, + repo_name, + files, + ) + + if generate_config is not None and generate_config.get("stream", False): + generate_config["stream"] = False + + return self.generate(code_prompt, generate_config) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 57fdfdd03c..bc22aa6ce0 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -2274,7 +2274,8 @@ "en" ], "model_ability": [ - "generate" + "generate", + "code" ], "model_description": "Starcoder is an open-source Transformer based LLM that is trained on permissively licensed data from GitHub.", "model_specs": [ @@ -2295,7 +2296,7 @@ "code_prompt_style": { "style_name": "STARCODER", "fim_spec": { - "style": "PMS", + "style": "PSM", "prefix": "", "middle": "", "suffix": "" @@ -4425,6 +4426,128 @@ ] } }, + { + "version": 1, + "context_length": 16384, + "model_name": "deepseek-coder-base", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "generate", + "code" + ], + "model_description": "deepseek-coder-base is pre-trained on project-level code corpus by employing a window size of 16K and a extra fill-in-the-blank task, to support project-level code completion and infilling.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": "1_3", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "deepseek-ai/deepseek-coder-1.3b-base", + "model_revision": "c919139c3a9b4070729c8b2cca4847ab29ca8d94" + }, + { + "model_format": "pytorch", + "model_size_in_billions": "6_7", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "deepseek-ai/deepseek-coder-6.7b-base", + "model_revision": "ce2207a8bfef3ee92bd7dd4cc31c52cfa0046912" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 33, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "deepseek-ai/deepseek-coder-33b-base", + "model_revision": "45c85cadf3720ef3e85a492e24fd4b8c5d21d8ac" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "1_3", + "quantizations": [ + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + "model_id": "TheBloke/deepseek-coder-1.3b-base-GGUF", + "model_file_name_template": "deepseek-coder-1.3b-base.{quantization}.gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": "6_7", + "quantizations": [ + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + "model_id": "TheBloke/deepseek-coder-6.7B-base-GGUF", + "model_file_name_template": "deepseek-coder-6.7b-base.{quantization}.gguf" + }, + { + "model_format": "ggufv2", + "model_size_in_billions": 33, + "quantizations": [ + "Q2_K", + "Q3_K_L", + "Q3_K_M", + "Q3_K_S", + "Q4_0", + "Q4_K_M", + "Q4_K_S", + "Q5_0", + "Q5_K_M", + "Q5_K_S", + "Q6_K", + "Q8_0" + ], + "model_id": "TheBloke/deepseek-coder-33B-base-GGUF", + "model_file_name_template": "deepseek-coder-33b-base.{quantization}.gguf" + } + ], + "code_prompt_style": { + "style_name": "DEEPSEEK_CODER", + "fim_spec": { + "style": "PMS", + "prefix": 
"<|fim▁begin|>", + "middle": "<|fim▁hole|>", + "suffix": "<|fim▁end|>" + }, + "repo_level_spec": { + "file_type": "filename", + "file_separator": "#" + } + } + }, { "version": 1, "context_length": 4096, diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index 2ce7a12f83..9adb11d3e7 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -47,9 +47,11 @@ DEFAULT_CONTEXT_LENGTH = 2048 BUILTIN_LLM_PROMPT_STYLE: Dict[str, "PromptStyleV1"] = {} +BUILTIN_LLM_CODE_PROMPT_STYLE: Dict[str, "CodePromptStyleV1"] = {} BUILTIN_LLM_MODEL_CHAT_FAMILIES: Set[str] = set() BUILTIN_LLM_MODEL_GENERATE_FAMILIES: Set[str] = set() BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES: Set[str] = set() +BUILTIN_LLM_MODEL_CODE_FAMILIES: Set[str] = set() class GgmlLLMSpecV1(BaseModel): @@ -133,7 +135,7 @@ class LLMFamilyV1(BaseModel): context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH model_name: str model_lang: List[str] - model_ability: List[Literal["embed", "generate", "chat", "tools", "vision"]] + model_ability: List[Literal["embed", "generate", "chat", "tools", "vision", "code"]] model_description: Optional[str] # reason for not required str here: legacy registration model_family: Optional[str] diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index 3703b36704..b51a3f6733 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -15,7 +15,7 @@ import json import logging import os -from typing import Iterable, Iterator, List, Optional, Union +from typing import Iterable, Iterator, List, Mapping, Optional, Union from ....device_utils import ( get_device_preferred_dtype, @@ -26,6 +26,7 @@ ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, + CodeGenerateMode, Completion, CompletionChunk, CreateCompletionTorch, @@ -39,7 +40,7 @@ from ...utils import select_device from ..core import LLM from ..llm_family import LLMFamilyV1, LLMSpecV1 -from ..utils import ChatModelMixin +from ..utils import ChatModelMixin, CodeModelMixin logger = logging.getLogger(__name__) @@ -511,3 +512,61 @@ def chat( self.model_family, self.model_uid, c, tools ) return self._to_chat_completion(c) + + +class PytorchCodeModel(PytorchModel, CodeModelMixin): + def __init__( + self, + model_uid: str, + model_family: "LLMFamilyV1", + model_spec: "LLMSpecV1", + quantization: str, + model_path: str, + pytorch_model_config: Optional[PytorchModelConfig] = None, + peft_model: Optional[List[LoRA]] = None, + ): + super().__init__( + model_uid, + model_family, + model_spec, + quantization, + model_path, + pytorch_model_config, + peft_model, + ) + + @classmethod + def match( + cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str + ) -> bool: + if llm_spec.model_format not in ["pytorch", "gptq", "awq"]: + return False + model_family = llm_family.model_family or llm_family.model_name + if model_family in NON_DEFAULT_MODEL_LIST: + return False + if "code" not in llm_family.model_ability: + return False + return True + + def code_generate( + self, + generate_mode: CodeGenerateMode, + prompt: str, + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], + generate_config: Optional[PytorchGenerateConfig] = None, + ) -> Union[Completion, Iterator[CompletionChunk]]: + code_prompt = self.get_code_prompt( + generate_mode, + prompt, + self.model_spec.code_prompt_style, + suffix, + repo_name, + files, + ) + + if generate_config is not None and generate_config.get("stream", False): 
+ generate_config["stream"] = False + + return self.generate(code_prompt, generate_config) diff --git a/xinference/model/llm/tests/test_utils.py b/xinference/model/llm/tests/test_utils.py index add8552716..3b299d2041 100644 --- a/xinference/model/llm/tests/test_utils.py +++ b/xinference/model/llm/tests/test_utils.py @@ -13,8 +13,9 @@ # limitations under the License. from ....types import ChatCompletionMessage -from ..llm_family import PromptStyleV1 -from ..utils import ChatModelMixin +from ..llm_family import CodePromptStyleV1, FIMSpecV1, PromptStyleV1 +from ..llm_family import RepoLevelCodeCompletionSpecV1 as RepoLevelSpecV1 +from ..utils import ChatModelMixin, CodeModelMixin def test_prompt_style_add_colon_single(): @@ -463,3 +464,232 @@ def test_is_valid_model_name(): assert not is_valid_model_name("foo/bar") assert not is_valid_model_name(" ") assert not is_valid_model_name("") + + +def test_code_prompt_style_starcoder(): + code_prompt_style = CodePromptStyleV1( + style_name="STARCODER", + fim_spec=FIMSpecV1( + style="PSM", + prefix="", + middle="", + suffix="", + ), + ) + prompt = "def print_hello_world():" + expected = prompt + assert expected == CodeModelMixin.get_code_prompt( + "completion", prompt, code_prompt_style + ) + + prompt = "def print_hello_world():\n " + suffix = "\n print('Hello world!')" + expected = "def print_hello_world():\n \n print('Hello world!')" + assert expected == CodeModelMixin.get_code_prompt( + "infill", prompt, code_prompt_style, suffix + ) + + +def test_code_prompt_style_deepseek_coder(): + code_prompt_style = CodePromptStyleV1( + style_name="DEEPSEEK_CODER", + fim_spec=FIMSpecV1( + style="PMS", + prefix="<|fim▁begin|>", + middle="<|fim▁hole|>", + suffix="<|fim▁end|>", + ), + repo_level_spec=RepoLevelSpecV1(file_type="filename", file_separator="#"), + ) + + prompt = "#write a quick sort algorithm" + expected = prompt + + assert expected == CodeModelMixin.get_code_prompt( + "completion", prompt, code_prompt_style + ) + + prompt = """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +""" + suffix = """ + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""" + + expected = """<|fim▁begin|>def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +<|fim▁hole|> + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)<|fim▁end|>""" + + assert expected == CodeModelMixin.get_code_prompt( + "infill", prompt, code_prompt_style, suffix + ) + + files = { + "utils.py": """import torch +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import accuracy_score + +def load_data(): + iris = datasets.load_iris() + X = iris.data + y = iris.target + + # Standardize the data + scaler = StandardScaler() + X = scaler.fit_transform(X) + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + + # Convert numpy data to PyTorch tensors + X_train = torch.tensor(X_train, dtype=torch.float32) + X_test = torch.tensor(X_test, dtype=torch.float32) + y_train = torch.tensor(y_train, dtype=torch.int64) + y_test = torch.tensor(y_test, dtype=torch.int64) + + return X_train, X_test, y_train, y_test + +def evaluate_predictions(y_test, y_pred): + return accuracy_score(y_test, y_pred)""", + "model.py": """import torch +import 
torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +class IrisClassifier(nn.Module): + def __init__(self): + super(IrisClassifier, self).__init__() + self.fc = nn.Sequential( + nn.Linear(4, 16), + nn.ReLU(), + nn.Linear(16, 3) + ) + + def forward(self, x): + return self.fc(x) + + def train_model(self, X_train, y_train, epochs, lr, batch_size): + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(self.parameters(), lr=lr) + + # Create DataLoader for batches + dataset = TensorDataset(X_train, y_train) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + for epoch in range(epochs): + for batch_X, batch_y in dataloader: + optimizer.zero_grad() + outputs = self(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + def predict(self, X_test): + with torch.no_grad(): + outputs = self(X_test) + _, predicted = outputs.max(1) + return predicted.numpy()""", + "main.py": """from utils import load_data, evaluate_predictions +from model import IrisClassifier as Classifier + +def main(): + # Model training and evaluation +""", + } + + prompt = "" + + expected = """#utils.py +import torch +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import accuracy_score + +def load_data(): + iris = datasets.load_iris() + X = iris.data + y = iris.target + + # Standardize the data + scaler = StandardScaler() + X = scaler.fit_transform(X) + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + + # Convert numpy data to PyTorch tensors + X_train = torch.tensor(X_train, dtype=torch.float32) + X_test = torch.tensor(X_test, dtype=torch.float32) + y_train = torch.tensor(y_train, dtype=torch.int64) + y_test = torch.tensor(y_test, dtype=torch.int64) + + return X_train, X_test, y_train, y_test + +def evaluate_predictions(y_test, y_pred): + return accuracy_score(y_test, y_pred) +#model.py +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +class IrisClassifier(nn.Module): + def __init__(self): + super(IrisClassifier, self).__init__() + self.fc = nn.Sequential( + nn.Linear(4, 16), + nn.ReLU(), + nn.Linear(16, 3) + ) + + def forward(self, x): + return self.fc(x) + + def train_model(self, X_train, y_train, epochs, lr, batch_size): + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(self.parameters(), lr=lr) + + # Create DataLoader for batches + dataset = TensorDataset(X_train, y_train) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + for epoch in range(epochs): + for batch_X, batch_y in dataloader: + optimizer.zero_grad() + outputs = self(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + def predict(self, X_test): + with torch.no_grad(): + outputs = self(X_test) + _, predicted = outputs.max(1) + return predicted.numpy() +#main.py +from utils import load_data, evaluate_predictions +from model import IrisClassifier as Classifier + +def main(): + # Model training and evaluation +""" + assert expected == CodeModelMixin.get_code_prompt( + "completion", prompt, code_prompt_style, None, None, files + ) diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 9683ccb88f..fac6cd1d57 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -17,23 +17,14 @@ import os import time import uuid 
-from typing import ( - AsyncGenerator, - Dict, - Iterator, - List, - Literal, - Mapping, - Optional, - Tuple, - cast, -) +from typing import AsyncGenerator, Dict, Iterator, List, Mapping, Optional, Tuple, cast from ...types import ( SPECIAL_TOOL_PROMPT, ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, + CodeGenerateMode, Completion, CompletionChunk, ) @@ -728,14 +719,14 @@ def _tool_calls_completion(cls, model_family, model_uid, c, tools): class CodeModelMixin: @staticmethod def get_code_prompt( - generate_type: Literal["completion", "fim"], + generate_mode: CodeGenerateMode, prompt: str, code_prompt_style: CodePromptStyleV1, - suffix: Optional[str], - repo_name: Optional[str], - files: Optional[Mapping[str, str]], + suffix: Optional[str] = None, + repo_name: Optional[str] = None, + files: Optional[Mapping[str, str]] = None, ) -> str: - if generate_type == "completion": + if generate_mode == "completion": if suffix is not None: logger.warning( "Suffix is only required on generate type is infill, ignored" @@ -746,23 +737,23 @@ def get_code_prompt( spec = code_prompt_style.repo_level_spec if spec is None: - raise ValueError( - "The model does not support repository level code completion" + logger.warning( + "The model does not support repository level code completion, 'repo_name' and 'files' are ignored" ) + return prompt - ret_val = "" - spec = code_prompt_style.repo_level_spec + chunks = [] if spec.repo_name is None: if repo_name is not None: logger.warning( - "repo_name is provided but it will not be used in this model, ignored" + "'repo_name' is provided but it will not be used in this model, ignored" ) else: if repo_name is None: raise ValueError( - "The repo_name is required for repository level code completion for this model" + "The 'repo_name' is required for repository level code completion for this model" ) - ret_val = f"{spec.repo_name}{repo_name}\n" + chunks.append(f"{spec.repo_name}{repo_name}") for filepath, content in files.items(): repo_file = ( @@ -770,16 +761,20 @@ def get_code_prompt( if spec.file_type == "filepath" else CodeModelMixin._path_to_name(filepath) ) - ret_val += f"{spec.file_separator}{repo_file}\n{content}\n" - return ret_val + prompt + chunks.append(f"{spec.file_separator}{repo_file}\n{content}") + + if len(prompt.strip()) > 0: + chunks.append(prompt) + + return "\n".join(chunks) - elif generate_type == "fim": + elif generate_mode == "infill": spec = code_prompt_style.fim_spec if spec is None: - raise ValueError("This model is not support FIM mode generate") + raise ValueError("This model is not support infill mode generate") if suffix is None: - raise ValueError("suffix is required in FIM mode") + raise ValueError("suffix is required in infill mode") if files is not None and len(files) > 0: logger.warning( @@ -792,7 +787,7 @@ def get_code_prompt( return f"{spec.prefix}{prompt}{spec.middle}{suffix}{spec.suffix}" else: - raise ValueError(f"Unsupported generate type: {generate_type}") + raise ValueError(f"Unsupported generate mode: {generate_mode}") @staticmethod def _path_to_name(filepath: str) -> str: diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 2ac004cd7a..c04b155101 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -23,6 +23,7 @@ Dict, Iterable, List, + Mapping, Optional, TypedDict, Union, @@ -33,6 +34,7 @@ ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, + CodeGenerateMode, Completion, CompletionChoice, CompletionChunk, @@ -42,7 +44,7 @@ ) from .. 
import LLM, LLMFamilyV1, LLMSpecV1 from ..llm_family import CustomLLMFamilyV1 -from ..utils import ChatModelMixin +from ..utils import ChatModelMixin, CodeModelMixin logger = logging.getLogger(__name__) @@ -505,3 +507,60 @@ async def async_chat( self.model_family, self.model_uid, c, tools ) return self._to_chat_completion(c) + + +class VLLMCodeModel(VLLMModel, CodeModelMixin): + @classmethod + def match( + cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str + ) -> bool: + if XINFERENCE_DISABLE_VLLM: + return False + if llm_spec.model_format not in ["pytorch", "gptq", "awq"]: + return False + if llm_spec.model_format == "pytorch": + if quantization != "none" and not (quantization is None): + return False + if llm_spec.model_format == "awq": + # Currently, only 4-bit weight quantization is supported for AWQ, but got 8 bits. + if "4" not in quantization: + return False + if llm_spec.model_format == "gptq": + if VLLM_INSTALLED and vllm.__version__ >= "0.3.3": + if not any(q in quantization for q in ("3", "4", "8")): + return False + else: + if "4" not in quantization: + return False + if isinstance(llm_family, CustomLLMFamilyV1): + if llm_family.model_family not in VLLM_SUPPORTED_MODELS: + return False + else: + if llm_family.model_name not in VLLM_SUPPORTED_MODELS: + return False + if "code" not in llm_family.model_ability: + return False + return VLLM_INSTALLED + + async def async_code_generate( + self, + generate_mode: CodeGenerateMode, + prompt: str, + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], + generate_config: Optional[Dict] = None, + ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]: + code_prompt = self.get_code_prompt( + generate_mode, + prompt, + self.model_spec.code_prompt_style, + suffix, + repo_name, + files, + ) + + if generate_config is not None and generate_config.get("stream", False): + generate_config["stream"] = False + + return await self.async_generate(code_prompt, generate_config) diff --git a/xinference/types.py b/xinference/types.py index e9c71becc2..0e6b745d93 100644 --- a/xinference/types.py +++ b/xinference/types.py @@ -509,14 +509,10 @@ def from_dict(cls, data: Dict): ) -class CreateCodeCompletion(CreateCompletion): - mode: Literal["completion"] = "completion" - +CodeGenerateMode = Literal["completion", "infill"] -class CreateCodeInFill(CreateCompletion): - mode: Literal["infill"] = "infill" - -class CreateCodeRepoLevelCompletion(CreateCompletion): - mode: Literal["repo-level-completion"] = "repo-level-completion" - files: Mapping[str, str] +class CreateCodeCompletion(CreateCompletion): + mode: CodeGenerateMode = "completion" + repo_name: Optional[str] = none_field + files: Optional[Mapping[str, str]] = none_field From ec8c8150c4a34c76efa5f3b43984cbfbab20364f Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Fri, 10 May 2024 09:59:17 +0800 Subject: [PATCH 05/37] correct path to name function to work on windows file system. Add some unit test cases. --- xinference/model/llm/tests/test_utils.py | 45 ++++++++++++++++++++++++ xinference/model/llm/utils.py | 5 ++- 2 files changed, 49 insertions(+), 1 deletion(-) diff --git a/xinference/model/llm/tests/test_utils.py b/xinference/model/llm/tests/test_utils.py index 3b299d2041..671ca10248 100644 --- a/xinference/model/llm/tests/test_utils.py +++ b/xinference/model/llm/tests/test_utils.py @@ -11,6 +11,7 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
# See the License for the specific language governing permissions and # limitations under the License. +import pytest from ....types import ChatCompletionMessage from ..llm_family import CodePromptStyleV1, FIMSpecV1, PromptStyleV1 @@ -466,6 +467,26 @@ def test_is_valid_model_name(): assert not is_valid_model_name("") +def test_path_to_name(): + path = "/home/test/works/project/main.py" + assert "main.py" == CodeModelMixin._path_to_name(path) + + path = "/main.py" + assert "main.py" == CodeModelMixin._path_to_name(path) + + path = "main.py" + assert "main.py" == CodeModelMixin._path_to_name(path) + + path = ".main.py" + assert ".main.py" == CodeModelMixin._path_to_name(path) + + path = r"\main.py" + assert "main.py" == CodeModelMixin._path_to_name(path) + + path = r"C:\works\main.py" + assert "main.py" == CodeModelMixin._path_to_name(path) + + def test_code_prompt_style_starcoder(): code_prompt_style = CodePromptStyleV1( style_name="STARCODER", @@ -489,6 +510,17 @@ def test_code_prompt_style_starcoder(): "infill", prompt, code_prompt_style, suffix ) + suffix = None + with pytest.raises(ValueError) as exc_info: + CodeModelMixin.get_code_prompt("infill", prompt, code_prompt_style, suffix) + assert exc_info.value == ValueError("suffix is required in infill mode") + + with pytest.raises(ValueError) as exc_info: + CodeModelMixin.get_code_prompt("test", prompt, code_prompt_style) + assert exc_info.value == ValueError( + "Unsupported generate mode: test, only 'PSM' and 'PMS' are supported now" + ) + def test_code_prompt_style_deepseek_coder(): code_prompt_style = CodePromptStyleV1( @@ -693,3 +725,16 @@ def main(): assert expected == CodeModelMixin.get_code_prompt( "completion", prompt, code_prompt_style, None, None, files ) + + +def test_code_prompt_style_without_fim(): + code_prompt_style = CodePromptStyleV1( + style_name="NO_FIM_CODER", + ) + prompt = "def print_hello_world():\n " + suffix = "\n print('Hello world!')" + with pytest.raises(ValueError) as exc_info: + CodeModelMixin.get_code_prompt("infill", prompt, code_prompt_style, suffix) + assert exc_info.value == ValueError( + "This model is not support infill mode generate" + ) diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index fac6cd1d57..cb206fb006 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -787,10 +787,13 @@ def get_code_prompt( return f"{spec.prefix}{prompt}{spec.middle}{suffix}{spec.suffix}" else: - raise ValueError(f"Unsupported generate mode: {generate_mode}") + raise ValueError( + f"Unsupported generate mode: {generate_mode}, only 'PSM' and 'PMS' are supported now" + ) @staticmethod def _path_to_name(filepath: str) -> str: + filepath = filepath.replace("\\", "/") return os.path.split(filepath)[1] From 38200a385f8b692e0b252e759461ff2af662c6c8 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Fri, 10 May 2024 15:01:06 +0800 Subject: [PATCH 06/37] make the code completion process works --- xinference/api/restful_api.py | 14 +- xinference/client/restful/restful_client.py | 76 ++++++ xinference/core/model.py | 279 ++++++++++---------- xinference/model/llm/ggml/llamacpp.py | 4 +- xinference/model/llm/llm_family.py | 6 +- xinference/model/llm/pytorch/core.py | 6 +- xinference/model/llm/utils.py | 8 +- xinference/model/llm/vllm/core.py | 6 +- 8 files changed, 234 insertions(+), 165 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index c01bce70cf..6dc36ba22f 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ 
-457,7 +457,7 @@ def serve(self, logging_conf: Optional[dict] = None): ) self._router.add_api_route( - "v1/code/completions", + "/v1/code/completions", self.create_code_completion, methods=["POST"], response_model=Completion, @@ -1422,15 +1422,11 @@ async def stream_results(): async def create_code_completion(self, request: Request) -> Response: json_data = await request.json() - if "mode" not in json_data: - raise HTTPException( - status_code=400, detail="mode is required in code completion request" - ) - if json_data["mode"] not in ("completion", "infill", "repo-level-completion"): + if "mode" in json_data and json_data["mode"] not in ("completion", "infill"): raise HTTPException( status_code=400, - detail="mode must be one of 'completion', 'infill' or 'repo-level-completion", + detail="mode must be one of 'completion' or 'infill'", ) if json_data.get("stream", False): @@ -1438,7 +1434,7 @@ async def create_code_completion(self, request: Request) -> Response: body = CreateCodeCompletion.parse_obj(json_data) exclude = { - "generate_mode", + "mode", "prompt", "suffix", "repo_name", @@ -1477,7 +1473,7 @@ async def create_code_completion(self, request: Request) -> Response: try: data = await model.code_generate( - body.generate_mode, + body.mode, body.prompt, body.suffix, body.repo_name, diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index 87241ae4ff..fe2344b002 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -25,6 +25,7 @@ ChatCompletionChunk, ChatCompletionMessage, ChatglmCppGenerateConfig, + CodeGenerateMode, Completion, CompletionChunk, Embedding, @@ -429,6 +430,77 @@ def chat( return response_data +class RESTfulCodeModelHandle(RESTfulGenerateModelHandle): + def code_generate( + self, + mode: "CodeGenerateMode", + prompt: str, + suffix: Optional[str] = None, + repo_name: Optional[str] = None, + files: Optional[typing.Mapping] = None, + generate_config: Optional[ + Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"] + ] = None, + ) -> "Completion": + """ + Given code generation hint to complete the code, the model will return a response via RESTful APIs. + + Parameters + ---------- + mode: Literal["completion", "infill"] + Code Generation mode + Completion includes code fragment completion and repository level code completion + Infill is fill in middle completion, complete the code according provided prefix and suffix content. + prompt: str + The user's input, it presents prefix content in infill mode. + suffix: Optional[str] + The suffix content in infill mode. + repo_name: Optional[str] + The repository name in repository level code completion mode. + files: Optional[Mapping] + The file name/path and its content key values in repository level code completion mode + generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]] + Additional configuration for the chat generation. + "LlamaCppGenerateConfig" -> configuration for ggml model + "PytorchGenerateConfig" -> configuration for pytorch model + + Returns + ------- + "Completion" + + Raises + ------ + RuntimeError + Report the failure to generate the code from the server. Detailed information provided in error message. 
+ + """ + + url = f"{self._base_url}/v1/code/completions" + + request_body: Dict[str, Any] = { + "model": self._model_uid, + "mode": mode, + "prompt": prompt, + "suffix": suffix, + "repo_name": repo_name, + "files": files, + } + + if generate_config is not None: + for key, value in generate_config.items(): + request_body[key] = value + + response = requests.post(url, json=request_body, headers=self.auth_headers) + + if response.status_code != 200: + raise RuntimeError( + f"Failed to generate code completion, detail: {_get_error_string(response)}" + ) + + response_data = response.json() + return response_data + + class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle): def chat( self, @@ -942,6 +1014,10 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle: return RESTfulChatModelHandle( model_uid, self.base_url, auth_headers=self._headers ) + elif "code" in desc["model_ability"]: + return RESTfulCodeModelHandle( + model_uid, self.base_url, auth_headers=self._headers + ) elif "generate" in desc["model_ability"]: return RESTfulGenerateModelHandle( model_uid, self.base_url, auth_headers=self._headers diff --git a/xinference/core/model.py b/xinference/core/model.py index ed79058644..06b73c7fab 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -333,7 +333,7 @@ async def _call_wrapper(self, fn: Callable, *args, **kwargs): if inspect.iscoroutinefunction(fn): ret = await fn(*args, **kwargs) else: - ret = await asyncio.to_thread(fn, *args, **kwargs) + ret = await asyncio.to_thread(fn, *args, **kwargs) # type: ignore[attr-defined] if self._lock is not None and self._current_generator(): raise Exception("Parallel generation is not supported by ggml.") @@ -403,7 +403,7 @@ async def chat(self, prompt: str, *args, **kwargs): @xo.generator async def code_generate( self, - generate_mode: CodeGenerateMode, + mode: CodeGenerateMode, prompt: str, suffix: Optional[str], repo_name: Optional[str], @@ -417,7 +417,7 @@ async def code_generate( if hasattr(self._model, "code_generate"): response = await self._call_wrapper( self._model.code_generate, - generate_mode, + mode, prompt, suffix, repo_name, @@ -429,7 +429,7 @@ async def code_generate( if hasattr(self._model, "async_code_generate"): response = await self._call_wrapper( self._model.async_code_generate, - generate_mode, + mode, prompt, suffix, repo_name, @@ -459,148 +459,145 @@ async def code_generate( prompt_tokens, ) + @log_async(logger=logger) + @request_limit + async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs): + if hasattr(self._model, "create_embedding"): + return await self._call_wrapper( + self._model.create_embedding, input, *args, **kwargs + ) -@log_async(logger=logger) -@request_limit -async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs): - if hasattr(self._model, "create_embedding"): - return await self._call_wrapper( - self._model.create_embedding, input, *args, **kwargs + raise AttributeError( + f"Model {self._model.model_spec} is not for creating embedding." ) - raise AttributeError( - f"Model {self._model.model_spec} is not for creating embedding." 
- ) - - -@log_async(logger=logger) -@request_limit -async def rerank( - self, - documents: List[str], - query: str, - top_n: Optional[int], - max_chunks_per_doc: Optional[int], - return_documents: Optional[bool], - *args, - **kwargs, -): - if hasattr(self._model, "rerank"): - return await self._call_wrapper( - self._model.rerank, - documents, - query, - top_n, - max_chunks_per_doc, - return_documents, - *args, - **kwargs, - ) - raise AttributeError(f"Model {self._model.model_spec} is not for reranking.") - - -@log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio")) -@request_limit -async def transcriptions( - self, - audio: bytes, - language: Optional[str] = None, - prompt: Optional[str] = None, - response_format: str = "json", - temperature: float = 0, - timestamp_granularities: Optional[List[str]] = None, -): - if hasattr(self._model, "transcriptions"): - return await self._call_wrapper( - self._model.transcriptions, - audio, - language, - prompt, - response_format, - temperature, - timestamp_granularities, - ) - raise AttributeError( - f"Model {self._model.model_spec} is not for creating transcriptions." - ) - - -@log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio")) -@request_limit -async def translations( - self, - audio: bytes, - language: Optional[str] = None, - prompt: Optional[str] = None, - response_format: str = "json", - temperature: float = 0, - timestamp_granularities: Optional[List[str]] = None, -): - if hasattr(self._model, "translations"): - return await self._call_wrapper( - self._model.translations, - audio, - language, - prompt, - response_format, - temperature, - timestamp_granularities, + @log_async(logger=logger) + @request_limit + async def rerank( + self, + documents: List[str], + query: str, + top_n: Optional[int], + max_chunks_per_doc: Optional[int], + return_documents: Optional[bool], + *args, + **kwargs, + ): + if hasattr(self._model, "rerank"): + return await self._call_wrapper( + self._model.rerank, + documents, + query, + top_n, + max_chunks_per_doc, + return_documents, + *args, + **kwargs, + ) + raise AttributeError(f"Model {self._model.model_spec} is not for reranking.") + + @log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio")) + @request_limit + async def transcriptions( + self, + audio: bytes, + language: Optional[str] = None, + prompt: Optional[str] = None, + response_format: str = "json", + temperature: float = 0, + timestamp_granularities: Optional[List[str]] = None, + ): + if hasattr(self._model, "transcriptions"): + return await self._call_wrapper( + self._model.transcriptions, + audio, + language, + prompt, + response_format, + temperature, + timestamp_granularities, + ) + raise AttributeError( + f"Model {self._model.model_spec} is not for creating transcriptions." ) - raise AttributeError( - f"Model {self._model.model_spec} is not for creating translations." 
- ) - - -@log_async(logger=logger) -@request_limit -async def text_to_image( - self, - prompt: str, - n: int = 1, - size: str = "1024*1024", - response_format: str = "url", - *args, - **kwargs, -): - if hasattr(self._model, "text_to_image"): - return await self._call_wrapper( - self._model.text_to_image, - prompt, - n, - size, - response_format, - *args, - **kwargs, + + @log_async(logger=logger, args_formatter=lambda _, kwargs: kwargs.pop("audio")) + @request_limit + async def translations( + self, + audio: bytes, + language: Optional[str] = None, + prompt: Optional[str] = None, + response_format: str = "json", + temperature: float = 0, + timestamp_granularities: Optional[List[str]] = None, + ): + if hasattr(self._model, "translations"): + return await self._call_wrapper( + self._model.translations, + audio, + language, + prompt, + response_format, + temperature, + timestamp_granularities, + ) + raise AttributeError( + f"Model {self._model.model_spec} is not for creating translations." ) - raise AttributeError(f"Model {self._model.model_spec} is not for creating image.") - - -async def image_to_image( - self, - image: "PIL.Image", - prompt: str, - negative_prompt: str, - n: int = 1, - size: str = "1024*1024", - response_format: str = "url", - *args, - **kwargs, -): - if hasattr(self._model, "image_to_image"): - return await self._call_wrapper( - self._model.image_to_image, - image, - prompt, - negative_prompt, - n, - size, - response_format, - *args, - **kwargs, + + @log_async(logger=logger) + @request_limit + async def text_to_image( + self, + prompt: str, + n: int = 1, + size: str = "1024*1024", + response_format: str = "url", + *args, + **kwargs, + ): + if hasattr(self._model, "text_to_image"): + return await self._call_wrapper( + self._model.text_to_image, + prompt, + n, + size, + response_format, + *args, + **kwargs, + ) + raise AttributeError( + f"Model {self._model.model_spec} is not for creating image." ) - raise AttributeError(f"Model {self._model.model_spec} is not for creating image.") + async def image_to_image( + self, + image: "PIL.Image", + prompt: str, + negative_prompt: str, + n: int = 1, + size: str = "1024*1024", + response_format: str = "url", + *args, + **kwargs, + ): + if hasattr(self._model, "image_to_image"): + return await self._call_wrapper( + self._model.image_to_image, + image, + prompt, + negative_prompt, + n, + size, + response_format, + *args, + **kwargs, + ) + raise AttributeError( + f"Model {self._model.model_spec} is not for creating image." 
+ ) -async def record_metrics(self, name, op, kwargs): - worker_ref = await self._get_worker_ref() - await worker_ref.record_metrics(name, op, kwargs) + async def record_metrics(self, name, op, kwargs): + worker_ref = await self._get_worker_ref() + await worker_ref.record_metrics(name, op, kwargs) diff --git a/xinference/model/llm/ggml/llamacpp.py b/xinference/model/llm/ggml/llamacpp.py index 787bd1c3f8..601e2c48bc 100644 --- a/xinference/model/llm/ggml/llamacpp.py +++ b/xinference/model/llm/ggml/llamacpp.py @@ -332,7 +332,7 @@ def match( return False if "chatglm" in llm_family.model_name: return False - return "code" not in llm_family.model_ability + return "code" in llm_family.model_ability def code_generate( self, @@ -346,7 +346,7 @@ def code_generate( code_prompt = self.get_code_prompt( generate_model, prompt, - self.model_spec.code_prompt_style, + self.model_family.code_prompt_style, suffix, repo_name, files, diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index 9adb11d3e7..c6633f04d8 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -126,8 +126,8 @@ class RepoLevelCodeCompletionSpecV1(BaseModel): class CodePromptStyleV1(BaseModel): style_name: str - fim_spec: Optional[FIMSpecV1] - repo_level_spec: Optional[RepoLevelCodeCompletionSpecV1] + fim_spec: Optional["FIMSpecV1"] + repo_level_spec: Optional["RepoLevelCodeCompletionSpecV1"] class LLMFamilyV1(BaseModel): @@ -141,7 +141,7 @@ class LLMFamilyV1(BaseModel): model_family: Optional[str] model_specs: List["LLMSpecV1"] prompt_style: Optional["PromptStyleV1"] - code_prompt_style: Optional[CodePromptStyleV1] + code_prompt_style: Optional["CodePromptStyleV1"] class CustomLLMFamilyV1(LLMFamilyV1): diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index b51a3f6733..5f252a0939 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -550,7 +550,7 @@ def match( def code_generate( self, - generate_mode: CodeGenerateMode, + mode: CodeGenerateMode, prompt: str, suffix: Optional[str], repo_name: Optional[str], @@ -558,9 +558,9 @@ def code_generate( generate_config: Optional[PytorchGenerateConfig] = None, ) -> Union[Completion, Iterator[CompletionChunk]]: code_prompt = self.get_code_prompt( - generate_mode, + mode, prompt, - self.model_spec.code_prompt_style, + self.model_family.code_prompt_style, suffix, repo_name, files, diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index cb206fb006..785f95b40a 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -719,14 +719,14 @@ def _tool_calls_completion(cls, model_family, model_uid, c, tools): class CodeModelMixin: @staticmethod def get_code_prompt( - generate_mode: CodeGenerateMode, + mode: CodeGenerateMode, prompt: str, code_prompt_style: CodePromptStyleV1, suffix: Optional[str] = None, repo_name: Optional[str] = None, files: Optional[Mapping[str, str]] = None, ) -> str: - if generate_mode == "completion": + if mode == "completion": if suffix is not None: logger.warning( "Suffix is only required on generate type is infill, ignored" @@ -768,7 +768,7 @@ def get_code_prompt( return "\n".join(chunks) - elif generate_mode == "infill": + elif mode == "infill": spec = code_prompt_style.fim_spec if spec is None: raise ValueError("This model is not support infill mode generate") @@ -788,7 +788,7 @@ def get_code_prompt( else: raise ValueError( - f"Unsupported generate mode: {generate_mode}, only 'PSM' and 
'PMS' are supported now" + f"Unsupported generate mode: {mode}, only 'PSM' and 'PMS' are supported now" ) @staticmethod diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index c04b155101..303ffaa1be 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -544,7 +544,7 @@ def match( async def async_code_generate( self, - generate_mode: CodeGenerateMode, + mode: CodeGenerateMode, prompt: str, suffix: Optional[str], repo_name: Optional[str], @@ -552,9 +552,9 @@ async def async_code_generate( generate_config: Optional[Dict] = None, ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]: code_prompt = self.get_code_prompt( - generate_mode, + mode, prompt, - self.model_spec.code_prompt_style, + self.model_family.code_prompt_style, suffix, repo_name, files, From e352804ec96224b64137430fa784249070dc17fe Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Fri, 10 May 2024 17:03:19 +0800 Subject: [PATCH 07/37] add vllm inference engine support for `deepseek-coder-base` code model. --- xinference/model/llm/llm_family.json | 27 ++++++++ .../model/llm/llm_family_modelscope.json | 62 +++++++++++++++++++ xinference/model/llm/vllm/core.py | 7 ++- 3 files changed, 94 insertions(+), 2 deletions(-) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index bc22aa6ce0..7dcad72f3a 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -4532,6 +4532,33 @@ ], "model_id": "TheBloke/deepseek-coder-33B-base-GGUF", "model_file_name_template": "deepseek-coder-33b-base.{quantization}.gguf" + }, + { + "model_format": "awq", + "model_size_in_billions": "1_3", + "quantizations": [ + "Int4" + ], + "model_id": "TheBloke/deepseek-coder-1.3b-base-AWQ", + "model_revision": "ffb66f1a2a194401b4f29025edcd261d7f0a08a7" + }, + { + "model_format": "awq", + "model_size_in_billions": "6_7", + "quantizations": [ + "Int4" + ], + "model_id": "TheBloke/deepseek-coder-6.7B-base-AWQ", + "model_revision": "e3d4bdf39712665f5e9d5c05c9df6f20fe1e2d5a" + }, + { + "model_format": "awq", + "model_size_in_billions": "33", + "quantizations": [ + "Int4" + ], + "model_id": "TheBloke/deepseek-coder-33B-base-AWQ", + "model_revision": "c7edb2d5868d61a5dcf2591933a8992c8cbe3ef4" } ], "code_prompt_style": { diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index f34d2ad1b7..4c16ef914e 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -2522,6 +2522,68 @@ ] } }, + { + "version": 1, + "context_length": 16384, + "model_name": "deepseek-coder-base", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "generate", + "code" + ], + "model_description": "deepseek-coder-base is pre-trained on project-level code corpus by employing a window size of 16K and a extra fill-in-the-blank task, to support project-level code completion and infilling.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": "1_3", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "deepseek-ai/deepseek-coder-1.3b-base", + "model_hub": "modelscope" + }, + { + "model_format": "pytorch", + "model_size_in_billions": "6_7", + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "deepseek-ai/deepseek-coder-6.7b-base", + "model_hub": "modelscope" + }, + { + "model_format": "pytorch", + "model_size_in_billions": 33, + "quantizations": [ + "4-bit", + "8-bit", + 
"none" + ], + "model_id": "deepseek-ai/deepseek-coder-33b-base", + "model_hub": "modelscope" + } + ], + "code_prompt_style": { + "style_name": "DEEPSEEK_CODER", + "fim_spec": { + "style": "PMS", + "prefix": "<|fim▁begin|>", + "middle": "<|fim▁hole|>", + "suffix": "<|fim▁end|>" + }, + "repo_level_spec": { + "file_type": "filename", + "file_separator": "#" + } + } + }, { "version": 1, "context_length": 4096, diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 303ffaa1be..9f3d0d4a10 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -119,6 +119,9 @@ class VLLMGenerateConfig(TypedDict, total=False): "deepseek-chat", "deepseek-coder-instruct", ] +VLLM_SUPPORTED_CODE_MODELS = [ + "deepseek-coder-base", +] if VLLM_INSTALLED and vllm.__version__ >= "0.3.0": VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat") VLLM_SUPPORTED_CHAT_MODELS.append("codeqwen1.5-chat") @@ -533,10 +536,10 @@ def match( if "4" not in quantization: return False if isinstance(llm_family, CustomLLMFamilyV1): - if llm_family.model_family not in VLLM_SUPPORTED_MODELS: + if llm_family.model_family not in VLLM_SUPPORTED_CODE_MODELS: return False else: - if llm_family.model_name not in VLLM_SUPPORTED_MODELS: + if llm_family.model_name not in VLLM_SUPPORTED_CODE_MODELS: return False if "code" not in llm_family.model_ability: return False From 2093b658890e151d67a6fa1392793edc6ea42d5e Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Fri, 10 May 2024 17:30:26 +0800 Subject: [PATCH 08/37] Show code icon for code models in model cards. --- .../web/ui/src/scenes/launch_model/modelCard.js | 11 +++++++++++ 1 file changed, 11 insertions(+) diff --git a/xinference/web/ui/src/scenes/launch_model/modelCard.js b/xinference/web/ui/src/scenes/launch_model/modelCard.js index f191b03771..66c4ceaac6 100644 --- a/xinference/web/ui/src/scenes/launch_model/modelCard.js +++ b/xinference/web/ui/src/scenes/launch_model/modelCard.js @@ -4,6 +4,7 @@ import { ExpandLess, ExpandMore, HelpCenterOutlined, + LogoDevOutlined, RocketLaunchOutlined, UndoOutlined, } from '@mui/icons-material' @@ -473,6 +474,16 @@ const ModelCard = ({ chat model ) + } else if ( + modelData.model_ability && + modelData.model_ability.includes("code") + ) { + return ( +
+        <div className="iconRow">
+          <div className="iconItem">
+            <LogoDevOutlined className="boldIconItem" />
+            <small className="smallText">code model</small>
+          </div>
+        </div>
+ ) } else if ( modelData.model_ability && modelData.model_ability.includes('generate') From f353d35b24a30d269548d3a1a5b36c230fe46f7a Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Fri, 10 May 2024 18:21:55 +0800 Subject: [PATCH 09/37] add client test for code completions. --- xinference/core/tests/test_restful_api.py | 79 +++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py index 608b1cb9d5..98dcb0f886 100644 --- a/xinference/core/tests/test_restful_api.py +++ b/xinference/core/tests/test_restful_api.py @@ -1266,3 +1266,82 @@ def test_cluster_info(setup): assert result[1]["node_type"] == "Worker" assert result[1]["gpu_count"] == 0 assert result[1]["gpu_vram_total"] == 0 + + +def test_restful_api_for_code_completions(setup): + model_name = "deepseek-coder-base" + + endpoint, _ = setup + url = f"{endpoint}/v1/models" + + # list + response = requests.get(url) + response_data = response.json() + assert len(response_data["data"]) == 0 + + # launch + payload = { + "model_uid": "deepseek-coder-base", + "model_name": model_name, + "model_type": "LLM", + "model_engine": "llama.cpp", + "model_size_in_billions": "1_3", + "quantization": "q4_k_m", + } + + response = requests.post(url, json=payload) + response_data = response.json() + model_uid_res = response_data["model_uid"] + assert model_uid_res == "deepseek-coder-base" + + response = requests.get(url) + response_data = response.json() + assert len(response_data["data"]) == 1 + + # test embedding + url = f"{endpoint}/v1/code/completions" + payload = { + "model": "deepseek-coder-base", + "prompt": "#write a quick sort algorithm", + "max_tokens": 4096, + } + response = requests.post(url, json=payload) + coding_res = response.json() + + assert len(coding_res["choices"]) == 1 + assert "text" in coding_res["choices"][0] + assert coding_res["choices"][0]["finish_reason"] == "stop" + + # test multiple + payload = { + "model": "deepseek-coder-base", + "mode": "infill", + "prompt": """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +""", + "suffix": """ if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""", + "max_tokens": 4096, + } + response = requests.post(url, json=payload) + coding_res = response.json() + + assert len(coding_res["choices"]) == 1 + assert "text" in coding_res["choices"][0] + assert coding_res["choices"][0]["finish_reason"] == "stop" + + # delete model + url = f"{endpoint}/v1/models/deepseek-coder-base" + response = requests.delete(url) + assert response.status_code == 200 + + response = requests.get(f"{endpoint}/v1/models") + response_data = response.json() + assert len(response_data["data"]) == 0 From 0e449e0484687ddf708cfa6ee2baa49dfd7ccac5 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Fri, 10 May 2024 20:47:11 +0800 Subject: [PATCH 10/37] check whether the code_prompt_style is None. 
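
A model family whose spec does not define a code prompt style now fails fast
with a clear error instead of breaking later while rendering the prompt. A
minimal sketch of the new behaviour (illustrative only; the prompt string is a
placeholder):

    from xinference.model.llm.utils import CodeModelMixin

    # code_prompt_style is None for families that do not declare one:
    CodeModelMixin.get_code_prompt("completion", "print('hi')", None)
    # ValueError: code prompt style is not provided, the model spec is wrong.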
--- xinference/model/llm/utils.py | 7 ++++++- 1 file changed, 6 insertions(+), 1 deletion(-) diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 785f95b40a..9671982223 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -721,11 +721,16 @@ class CodeModelMixin: def get_code_prompt( mode: CodeGenerateMode, prompt: str, - code_prompt_style: CodePromptStyleV1, + code_prompt_style: Optional[CodePromptStyleV1], suffix: Optional[str] = None, repo_name: Optional[str] = None, files: Optional[Mapping[str, str]] = None, ) -> str: + if code_prompt_style is None: + raise ValueError( + "code prompt style is not provided, the model spec is wrong." + ) + if mode == "completion": if suffix is not None: logger.warning( From 4bbbb152e42560fa90721aca7d0d0a52fd8f2d88 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Fri, 10 May 2024 20:52:31 +0800 Subject: [PATCH 11/37] format the code by prettier. --- xinference/web/ui/src/scenes/launch_model/modelCard.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xinference/web/ui/src/scenes/launch_model/modelCard.js b/xinference/web/ui/src/scenes/launch_model/modelCard.js index 66c4ceaac6..2a754d3fa2 100644 --- a/xinference/web/ui/src/scenes/launch_model/modelCard.js +++ b/xinference/web/ui/src/scenes/launch_model/modelCard.js @@ -476,11 +476,11 @@ const ModelCard = ({ ) } else if ( modelData.model_ability && - modelData.model_ability.includes("code") + modelData.model_ability.includes('code') ) { return (
- + code model
) From 62c8dd9fc3d0b4d1394353448a84388e0edcc21d Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Sat, 11 May 2024 16:17:59 +0800 Subject: [PATCH 12/37] add the function to generate prompt for code completion. --- xinference/api/restful_api.py | 50 +++++++++++ xinference/client/restful/restful_client.py | 59 ++++++++++++ xinference/core/model.py | 27 ++++++ xinference/core/tests/test_restful_api.py | 99 ++++++++++++++++++++- xinference/model/llm/ggml/llamacpp.py | 1 - xinference/model/llm/pytorch/core.py | 1 - xinference/model/llm/tests/test_utils.py | 16 ++-- xinference/model/llm/utils.py | 20 ++++- xinference/model/llm/vllm/core.py | 1 - 9 files changed, 259 insertions(+), 15 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 6dc36ba22f..86bfff25b8 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -468,6 +468,17 @@ def serve(self, logging_conf: Optional[dict] = None): ), ) + self._router.add_api_route( + "/v1/code/prompt", + self.get_code_prompt, + methods=["POST"], + dependencies=( + [Security(self._auth_service, scopes=["models:read"])] + if self.is_authenticated() + else None + ), + ) + # for custom models self._router.add_api_route( "/v1/model_registrations/{model_type}", @@ -1487,6 +1498,45 @@ async def create_code_completion(self, request: Request) -> Response: self.handle_request_limit_error(e) raise HTTPException(status_code=500, detail=str(e)) + async def get_code_prompt(self, request: Request) -> Response: + json_data = await request.json() + + if "mode" in json_data and json_data["mode"] not in ("completion", "infill"): + raise HTTPException( + status_code=400, + detail="mode must be one of 'completion' or 'infill'", + ) + + body = CreateCodeCompletion.parse_obj(json_data) + + model_uid = body.model + + try: + model = await (await self._get_supervisor_ref()).get_model(model_uid) + except ValueError as ve: + logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) + raise HTTPException(status_code=500, detail=str(e)) + + try: + code_prompt = await model.get_code_prompt( + body.mode, + body.prompt, + body.suffix, + body.repo_name, + body.files, + ) + return Response(content=code_prompt, media_type="application/json") + except Exception as e: + logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) + self.handle_request_limit_error(e) + raise HTTPException(status_code=500, detail=str(e)) + async def query_engines_by_model_name(self, model_name: str) -> JSONResponse: try: content = await ( diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index fe2344b002..4919ea34b7 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -500,6 +500,65 @@ def code_generate( response_data = response.json() return response_data + def get_code_prompt( + self, + mode: "CodeGenerateMode", + prompt: str, + suffix: Optional[str] = None, + repo_name: Optional[str] = None, + files: Optional[typing.Mapping] = None, + ) -> str: + """ + Given code generating prompt which can be used to complete the code, the model will return a response via + RESTful APIs. 
+ + Parameters + ---------- + mode: Literal["completion", "infill"] + Code Generation mode + Completion includes code fragment completion and repository level code completion + Infill is fill in middle completion, complete the code according provided prefix and suffix content. + prompt: str + The user's input, it presents prefix content in infill mode. + suffix: Optional[str] + The suffix content in infill mode. + repo_name: Optional[str] + The repository name in repository level code completion mode. + files: Optional[Mapping] + The file name/path and its content key values in repository level code completion mode + + Returns + ------- + str: generated prompt for code generating + + Raises + ------ + RuntimeError + Report the failure to generate the code prompt from the server. + Detailed information provided in error message. + + """ + + url = f"{self._base_url}/v1/code/prompt" + + request_body: Dict[str, Any] = { + "model": self._model_uid, + "mode": mode, + "prompt": prompt, + "suffix": suffix, + "repo_name": repo_name, + "files": files, + } + + response = requests.post(url, json=request_body, headers=self.auth_headers) + + if response.status_code != 200: + raise RuntimeError( + f"Failed to generate code prompt generating, detail: {_get_error_string(response)}" + ) + + return response.text + class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle): def chat( diff --git a/xinference/core/model.py b/xinference/core/model.py index 06b73c7fab..c68565d0da 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -459,6 +459,33 @@ async def code_generate( prompt_tokens, ) + @log_async(logger=logger) + @request_limit + @xo.generator + async def get_code_prompt( + self, + mode: CodeGenerateMode, + prompt: str, + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], + ): + from ..model.llm.utils import CodeModelMixin + + if isinstance(self._model, CodeModelMixin): + return await self._call_wrapper( + self._model.get_code_prompt, + mode, + prompt, + suffix, + repo_name, + files, + ) + else: + raise ValueError( + f"Model {self._model.model_family.model_name} does not support code generating" + ) + @log_async(logger=logger) @request_limit async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs): diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py index 98dcb0f886..56ce5d7984 100644 --- a/xinference/core/tests/test_restful_api.py +++ b/xinference/core/tests/test_restful_api.py @@ -488,7 +488,10 @@ def test_restful_api_for_tool_calls(setup, model_format, quantization): payload = { "model": model_uid_res, "messages": [ - {"role": "system", "content": "你是一个有用的助手。不要对要函数调用的值做出假设。"}, + { + "role": "system", + "content": "你是一个有用的助手。不要对要函数调用的值做出假设。", + }, {"role": "user", "content": "上海现在的天气怎么样?"}, ], "temperature": 0.7, @@ -1268,6 +1271,97 @@ def test_cluster_info(setup): assert result[1]["gpu_vram_total"] == 0 +def test_restfule_api_for_code_prompt(setup): + model_name = "deepseek-coder-base" + + endpoint, _ = setup + url = f"{endpoint}/v1/models" + + # list + response = requests.get(url) + response_data = response.json() + assert len(response_data["data"]) == 0 + + # launch + payload = { + "model_uid": "deepseek-coder-base", + "model_name": model_name, + "model_type": "LLM", + "model_engine": "llama.cpp", + "model_size_in_billions": "1_3", + "quantization": "q4_k_m", + } + + response = requests.post(url, json=payload) + response_data = response.json() + model_uid_res = 
response_data["model_uid"] + assert model_uid_res == "deepseek-coder-base" + + response = requests.get(url) + response_data = response.json() + assert len(response_data["data"]) == 1 + + # test embedding + url = f"{endpoint}/v1/code/prompt" + payload = { + "model": "deepseek-coder-base", + "prompt": "#write a quick sort algorithm", + } + response = requests.post(url, json=payload) + coding_res = response.json() + + assert "prompt" in coding_res + + assert "#write a quick sort algorithm" == coding_res["prompt"] + + # test multiple + payload = { + "model": "deepseek-coder-base", + "mode": "infill", + "prompt": """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +""", + "suffix": """ + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""", + } + response = requests.post(url, json=payload) + coding_res = response.json() + + assert "prompt" in coding_res + assert ( + coding_res["prompt"] + == """<|fim▁begin|>def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +<|fim▁hole|> + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)<|fim▁end|>""" + ) + + # delete model + url = f"{endpoint}/v1/models/deepseek-coder-base" + response = requests.delete(url) + assert response.status_code == 200 + + response = requests.get(f"{endpoint}/v1/models") + response_data = response.json() + assert len(response_data["data"]) == 0 + + def test_restful_api_for_code_completions(setup): model_name = "deepseek-coder-base" @@ -1323,7 +1417,8 @@ def test_restful_api_for_code_completions(setup): left = [] right = [] """, - "suffix": """ if arr[i] < pivot: + "suffix": """ + if arr[i] < pivot: left.append(arr[i]) else: right.append(arr[i]) diff --git a/xinference/model/llm/ggml/llamacpp.py b/xinference/model/llm/ggml/llamacpp.py index 601e2c48bc..218f082f25 100644 --- a/xinference/model/llm/ggml/llamacpp.py +++ b/xinference/model/llm/ggml/llamacpp.py @@ -346,7 +346,6 @@ def code_generate( code_prompt = self.get_code_prompt( generate_model, prompt, - self.model_family.code_prompt_style, suffix, repo_name, files, diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index 5f252a0939..726f651773 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -560,7 +560,6 @@ def code_generate( code_prompt = self.get_code_prompt( mode, prompt, - self.model_family.code_prompt_style, suffix, repo_name, files, diff --git a/xinference/model/llm/tests/test_utils.py b/xinference/model/llm/tests/test_utils.py index 671ca10248..cc912a67fa 100644 --- a/xinference/model/llm/tests/test_utils.py +++ b/xinference/model/llm/tests/test_utils.py @@ -499,24 +499,24 @@ def test_code_prompt_style_starcoder(): ) prompt = "def print_hello_world():" expected = prompt - assert expected == CodeModelMixin.get_code_prompt( + assert expected == CodeModelMixin._get_code_prompt( "completion", prompt, code_prompt_style ) prompt = "def print_hello_world():\n " suffix = "\n print('Hello world!')" expected = "def print_hello_world():\n \n print('Hello world!')" - assert expected == CodeModelMixin.get_code_prompt( + assert expected == CodeModelMixin._get_code_prompt( "infill", prompt, code_prompt_style, suffix ) suffix = None with pytest.raises(ValueError) as exc_info: - CodeModelMixin.get_code_prompt("infill", prompt, code_prompt_style, suffix) + 
CodeModelMixin._get_code_prompt("infill", prompt, code_prompt_style, suffix) assert exc_info.value == ValueError("suffix is required in infill mode") with pytest.raises(ValueError) as exc_info: - CodeModelMixin.get_code_prompt("test", prompt, code_prompt_style) + CodeModelMixin._get_code_prompt("test", prompt, code_prompt_style) assert exc_info.value == ValueError( "Unsupported generate mode: test, only 'PSM' and 'PMS' are supported now" ) @@ -537,7 +537,7 @@ def test_code_prompt_style_deepseek_coder(): prompt = "#write a quick sort algorithm" expected = prompt - assert expected == CodeModelMixin.get_code_prompt( + assert expected == CodeModelMixin._get_code_prompt( "completion", prompt, code_prompt_style ) @@ -568,7 +568,7 @@ def test_code_prompt_style_deepseek_coder(): right.append(arr[i]) return quick_sort(left) + [pivot] + quick_sort(right)<|fim▁end|>""" - assert expected == CodeModelMixin.get_code_prompt( + assert expected == CodeModelMixin._get_code_prompt( "infill", prompt, code_prompt_style, suffix ) @@ -722,7 +722,7 @@ def predict(self, X_test): def main(): # Model training and evaluation """ - assert expected == CodeModelMixin.get_code_prompt( + assert expected == CodeModelMixin._get_code_prompt( "completion", prompt, code_prompt_style, None, None, files ) @@ -734,7 +734,7 @@ def test_code_prompt_style_without_fim(): prompt = "def print_hello_world():\n " suffix = "\n print('Hello world!')" with pytest.raises(ValueError) as exc_info: - CodeModelMixin.get_code_prompt("infill", prompt, code_prompt_style, suffix) + CodeModelMixin._get_code_prompt("infill", prompt, code_prompt_style, suffix) assert exc_info.value == ValueError( "This model is not support infill mode generate" ) diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 9671982223..5bd3b61dc7 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -28,6 +28,7 @@ Completion, CompletionChunk, ) +from .core import LLM from .llm_family import ( CodePromptStyleV1, GgmlLLMSpecV1, @@ -717,11 +718,26 @@ def _tool_calls_completion(cls, model_family, model_uid, c, tools): class CodeModelMixin: - @staticmethod def get_code_prompt( + self: "LLM", + mode: CodeGenerateMode, + prompt: str, + suffix: Optional[str] = None, + repo_name: Optional[str] = None, + files: Optional[Mapping[str, str]] = None, + ): + code_prompt_style = self.model_family.code_prompt_style + return { + "prompt": CodeModelMixin._get_code_prompt( + mode, prompt, code_prompt_style, suffix, repo_name, files + ) + } + + @staticmethod + def _get_code_prompt( mode: CodeGenerateMode, prompt: str, - code_prompt_style: Optional[CodePromptStyleV1], + code_prompt_style: Optional["CodePromptStyleV1"], suffix: Optional[str] = None, repo_name: Optional[str] = None, files: Optional[Mapping[str, str]] = None, diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 9f3d0d4a10..b9a56f1198 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -557,7 +557,6 @@ async def async_code_generate( code_prompt = self.get_code_prompt( mode, prompt, - self.model_family.code_prompt_style, suffix, repo_name, files, From d38830322a534a346baa0ff2e40d2ba12128ec66 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Sat, 11 May 2024 17:01:38 +0800 Subject: [PATCH 13/37] fix the bug that cannot get generated prompt correctly. 
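
get_code_prompt() now returns a mapping of the form {"prompt": "<rendered text>"}
rather than a bare string, so the llama.cpp, PyTorch and vLLM backends must
unwrap it before handing the text to the engine. A rough sketch of the call
site (excerpt-style, not runnable on its own):

    code_prompt = self.get_code_prompt(mode, prompt, suffix, repo_name, files)
    # -> {"prompt": "<|fim▁begin|>..."}   a dict, which the engines cannot consume
    code_prompt = self.get_code_prompt(mode, prompt, suffix, repo_name, files)["prompt"]
    # -> "<|fim▁begin|>..."               the rendered prompt text the engines expect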
--- xinference/model/llm/ggml/llamacpp.py | 2 +- xinference/model/llm/pytorch/core.py | 2 +- xinference/model/llm/vllm/core.py | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xinference/model/llm/ggml/llamacpp.py b/xinference/model/llm/ggml/llamacpp.py index 218f082f25..126aae9fa4 100644 --- a/xinference/model/llm/ggml/llamacpp.py +++ b/xinference/model/llm/ggml/llamacpp.py @@ -349,7 +349,7 @@ def code_generate( suffix, repo_name, files, - ) + )["prompt"] if generate_config is not None and generate_config.get("stream", False): generate_config["stream"] = False diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index 726f651773..2032c6b7b5 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -563,7 +563,7 @@ def code_generate( suffix, repo_name, files, - ) + )["prompt"] if generate_config is not None and generate_config.get("stream", False): generate_config["stream"] = False diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index b9a56f1198..5e33579ec3 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -560,7 +560,7 @@ async def async_code_generate( suffix, repo_name, files, - ) + )["prompt"] if generate_config is not None and generate_config.get("stream", False): generate_config["stream"] = False From 51390bd5cc636cfd792cde3cefafc492a17707f4 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Sat, 11 May 2024 18:13:13 +0800 Subject: [PATCH 14/37] fix the bug that cannot get generated prompt correctly. --- xinference/client/restful/restful_client.py | 4 ++-- xinference/model/llm/utils.py | 4 ++-- 2 files changed, 4 insertions(+), 4 deletions(-) diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index 4919ea34b7..aedb76b3e8 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -529,7 +529,7 @@ def get_code_prompt( Returns ------- - str: generated prompt for code generating + {"prompt": "generated prompt"} Raises ------ @@ -557,7 +557,7 @@ def get_code_prompt( f"Failed to generate code prompt generating, detail: {_get_error_string(response)}" ) - return response.text + return response.json() class RESTfulChatglmCppChatModelHandle(RESTfulModelHandle): diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 5bd3b61dc7..317ff502be 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -719,14 +719,14 @@ def _tool_calls_completion(cls, model_family, model_uid, c, tools): class CodeModelMixin: def get_code_prompt( - self: "LLM", + self, mode: CodeGenerateMode, prompt: str, suffix: Optional[str] = None, repo_name: Optional[str] = None, files: Optional[Mapping[str, str]] = None, ): - code_prompt_style = self.model_family.code_prompt_style + code_prompt_style = cast(LLM, self).model_family.code_prompt_style return { "prompt": CodeModelMixin._get_code_prompt( mode, prompt, code_prompt_style, suffix, repo_name, files From ec28dec9f898228f89ae3cfc42fa4768c0c3f54a Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Mon, 13 May 2024 09:53:11 +0800 Subject: [PATCH 15/37] adjust the test result to make unit test pass. 
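
The expected JSON in test_serialize_llm_family_v1 gains a trailing
"code_prompt_style": null because the new optional field is serialized even
when unset. A toy model showing the same behaviour (illustrative, assuming the
pydantic v1-style serialization used by the project's _compat layer):

    from typing import Optional
    from pydantic import BaseModel

    class Demo(BaseModel):
        prompt_style: Optional[str]
        code_prompt_style: Optional[str]

    print(Demo(prompt_style="x").json())
    # {"prompt_style": "x", "code_prompt_style": null}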
--- xinference/model/llm/tests/test_llm_family.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/model/llm/tests/test_llm_family.py b/xinference/model/llm/tests/test_llm_family.py index ddd0274321..20c40ee9ed 100644 --- a/xinference/model/llm/tests/test_llm_family.py +++ b/xinference/model/llm/tests/test_llm_family.py @@ -160,7 +160,7 @@ def test_serialize_llm_family_v1(): prompt_style=prompt_style, ) - expected = """{"version": 1, "context_length": 2048, "model_name": "TestModel", "model_lang": ["en"], "model_ability": ["embed", "generate"], "model_description": null, "model_family": null, "model_specs": [{"model_format": "ggmlv3", "model_hub": "huggingface", "model_size_in_billions": 2, "quantizations": ["q4_0", "q4_1"], "quantization_parts": {"q4_2": ["a", "b"]}, "model_id": "example/TestModel", "model_revision": "123", "model_file_name_template": "TestModel.{quantization}.ggmlv3.bin", "model_file_name_split_template": "TestModel.{quantization}.ggmlv3.bin.{part}", "model_uri": null}, {"model_format": "pytorch", "model_hub": "huggingface", "model_size_in_billions": 3, "quantizations": ["int8", "int4", "none"], "model_id": "example/TestModel", "model_revision": "456", "model_uri": null}], "prompt_style": {"style_name": "ADD_COLON_SINGLE", "system_prompt": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", "roles": ["user", "assistant"], "intra_message_sep": "\\n### ", "inter_message_sep": "\\n### ", "stop": null, "stop_token_ids": null}}""" + expected = """{"version": 1, "context_length": 2048, "model_name": "TestModel", "model_lang": ["en"], "model_ability": ["embed", "generate"], "model_description": null, "model_family": null, "model_specs": [{"model_format": "ggmlv3", "model_hub": "huggingface", "model_size_in_billions": 2, "quantizations": ["q4_0", "q4_1"], "quantization_parts": {"q4_2": ["a", "b"]}, "model_id": "example/TestModel", "model_revision": "123", "model_file_name_template": "TestModel.{quantization}.ggmlv3.bin", "model_file_name_split_template": "TestModel.{quantization}.ggmlv3.bin.{part}", "model_uri": null}, {"model_format": "pytorch", "model_hub": "huggingface", "model_size_in_billions": 3, "quantizations": ["int8", "int4", "none"], "model_id": "example/TestModel", "model_revision": "456", "model_uri": null}], "prompt_style": {"style_name": "ADD_COLON_SINGLE", "system_prompt": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", "roles": ["user", "assistant"], "intra_message_sep": "\\n### ", "inter_message_sep": "\\n### ", "stop": null, "stop_token_ids": null}, "code_prompt_style": null}""" assert json.loads(llm_family.json()) == json.loads(expected) llm_family_context_length = LLMFamilyV1( From 10d90b0d3447f40617d4a6f4bfec44471b1349bc Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Tue, 14 May 2024 14:06:11 +0800 Subject: [PATCH 16/37] basically finished code generating web client. Need to add repo_name support. 
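
The running-models page now looks up each LLM's builtin code prompt style to
decide whether the Gradio page should expose infill and repository-level
inputs. A rough Python equivalent of the new JS logic (illustrative only; the
endpoint and model name are placeholders, and it assumes the
/v1/models/code_prompts route that backs list_builtin_code_prompts):

    from xinference.client import RESTfulClient

    client = RESTfulClient("http://127.0.0.1:9997")
    code_prompts = client.list_builtin_code_prompts()   # model_name -> code prompt style
    style = code_prompts.get("deepseek-coder-base", {})
    infill_supported = "fim_spec" in style
    repo_level_supported = "repo_level_spec" in style
    # Both flags are forwarded in the Gradio-build request so the web UI can
    # show or hide the corresponding controls.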
--- xinference/api/restful_api.py | 4 + xinference/client/restful/restful_client.py | 30 ++ xinference/core/chat_interface.py | 400 +++++++++++++++++- .../web/ui/src/scenes/running_models/index.js | 129 +++--- 4 files changed, 513 insertions(+), 50 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 510a7be87c..134d3097f9 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -138,6 +138,8 @@ class BuildGradioInterfaceRequest(BaseModel): model_ability: List[str] model_description: str model_lang: List[str] + infill_supported: Optional[bool] + repo_level_supported: Optional[bool] class BuildGradioImageInterfaceRequest(BaseModel): @@ -860,6 +862,8 @@ async def build_gradio_interface( model_description=body.model_description, model_lang=body.model_lang, access_token=access_token, + infill_supported=body.infill_supported, + repo_level_supported=body.repo_level_supported, ).build() gr.mount_gradio_app(self._app, interface, f"/{model_uid}") except ValueError as ve: diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index aedb76b3e8..b31def8d35 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -1287,3 +1287,33 @@ def query_engine_by_model_name(self, model_name: str): response_data = response.json() return response_data + + def list_builtin_prompts(self): + """ + Get the builtin prompts + :return: List[Dict[str, Any]] + The builtin prompts + """ + url = f"{self.base_url}/v1/models/prompts" + response = requests.get(url, headers=self._headers) + if response.status_code != 200: + raise RuntimeError( + f"Failed to get builtin prompts, details: {_get_error_string(response)}" + ) + response_data = response.json() + return response_data + + def list_builtin_code_prompts(self): + """ + Get the builtin code prompts + :return: List[Dict[str, Any]] + The builtin code prompts + """ + url = f"{self.base_url}/v1/models/code_prompts" + response = requests.get(url, headers=self._headers) + if response.status_code != 200: + raise RuntimeError( + f"Failed to get builtin code prompts, details: {_get_error_string(response)}" + ) + response_data = response.json() + return response_data diff --git a/xinference/core/chat_interface.py b/xinference/core/chat_interface.py index 0bdc8b7fb3..f82eaf56b9 100644 --- a/xinference/core/chat_interface.py +++ b/xinference/core/chat_interface.py @@ -20,12 +20,13 @@ import gradio as gr import PIL.Image -from gradio.components import Markdown, Textbox +from gradio.components import Code, Dropdown, File, Markdown, Textbox from gradio.layouts import Accordion, Column, Row from ..client.restful.restful_client import ( RESTfulChatglmCppChatModelHandle, RESTfulChatModelHandle, + RESTfulCodeModelHandle, RESTfulGenerateModelHandle, ) from ..types import ChatCompletionMessage @@ -33,6 +34,30 @@ logger = logging.getLogger(__name__) +def compare_history(current, hist): + if current["mode"] != hist["mode"]: + return False + + if current["prompt"] != hist["prompt"]: + return False + + if current["suffix"] != hist["suffix"]: + return False + + if current["files"] != current["files"]: + return False + + return True + + +EMPTY = { + "mode": "Code Completion", + "prompt": "", + "suffix": "", + "files": None, +} + + class GradioInterface: def __init__( self, @@ -48,6 +73,8 @@ def __init__( model_description: str, model_lang: List[str], access_token: Optional[str], + infill_supported: Optional[bool], + repo_level_supported: 
Optional[bool], ): self.endpoint = endpoint self.model_uid = model_uid @@ -63,12 +90,16 @@ def __init__( self._access_token = ( access_token.replace("Bearer ", "") if access_token is not None else None ) + self.infill_supported = infill_supported + self.repo_level_supported = repo_level_supported def build(self) -> "gr.Blocks": if "vision" in self.model_ability: interface = self.build_chat_vl_interface() elif "chat" in self.model_ability: interface = self.build_chat_interface() + elif "code" in self.model_ability: + interface = self.build_code_generate_interface() else: interface = self.build_generate_interface() @@ -305,6 +336,373 @@ def update_button(text): return chat_vl_interface + def build_code_generate_interface( + self, + ): + def undo(g_mode, text, g_suffix, g_files, hist): + current = { + "mode": g_mode, + "prompt": text, + "suffix": g_suffix, + "files": g_files, + } + + if len(hist) == 0: + return { + generate_mode: "Code Completion", + prompt: "", + suffix: "", + files: None, + history: [current], + } + if compare_history(current, hist[-1]): + hist = hist[:-1] + + req = hist[-1] if len(hist) > 0 else EMPTY + + return { + generate_mode: req["mode"], + prompt: req["prompt"], + suffix: req["suffix"], + files: g_files, + history: hist, + } + + def clear(g_mode, text, g_suffix, g_files, hist): + current = { + "mode": g_mode, + "prompt": text, + "suffix": g_suffix, + "files": g_files, + } + if len(hist) == 0 or ( + len(hist) > 0 and not compare_history(current, hist[-1]) + ): + hist.append(current) + hist.append(EMPTY) + return { + generate_mode: "Code Completion", + prompt: "", + suffix: "", + files: None, + history: hist, + } + + def complete(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): + from ..client import RESTfulClient + + client = RESTfulClient(self.endpoint) + client._set_token(self._access_token) + + model = client.get_model(self.model_uid) + assert isinstance(model, RESTfulCodeModelHandle) + + repo_files = ( + {k: open(k, mode="r", encoding="utf8").read() for k in g_files} + if g_files + else None + ) + current = { + "mode": g_mode, + "prompt": text, + "suffix": g_suffix, + "files": g_files, + } + if len(hist) == 0 or ( + len(hist) > 0 and not compare_history(current, hist[-1]) + ): + hist.append(current) + + response_content = text + + if g_mode == "Code Completion": + if self.repo_level_supported: + resp = model.code_generate( + "completion", + prompt=text, + files=repo_files, + generate_config={ + "max_tokens": max_tokens, + "temperature": temperature, + }, + ) + else: + resp = model.code_generate( + mode="completion", + prompt=text, + generate_config={ + "max_tokens": max_tokens, + "temperature": temperature, + }, + ) + else: + resp = model.code_generate( + mode="infill", + prompt=text, + suffix=g_suffix, + generate_config={ + "max_tokens": max_tokens, + "temperature": temperature, + }, + ) + assert isinstance(resp, dict) + choice = resp["choices"][0] + + response_content += choice["text"] + + current = { + "mode": g_mode, + "prompt": response_content, + "suffix": g_suffix, + "files": g_files, + } + + hist.append(current) + return { + prompt: response_content, + history: hist, + } + + def retry(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): + from ..client import RESTfulClient + + client = RESTfulClient(self.endpoint) + client._set_token(self._access_token) + + model = client.get_model(self.model_uid) + assert isinstance(model, RESTfulCodeModelHandle) + + current = { + "mode": g_mode, + "prompt": text, + "suffix": g_suffix, + 
"files": g_files, + } + + if len(hist) == 0 or ( + len(hist) > 0 and not compare_history(current, hist[-1]) + ): + hist.append(current) + + req = hist[-2] if len(hist) > 1 else EMPTY + + response_content = req["prompt"] + + repo_files = ( + {k: open(k, mode="r", encoding="utf8").read() for k in req["files"]} + if req["files"] + else None + ) + + resp = model.code_generate( + mode="completion" if req["mode"] == "Code Completion" else "infill", + prompt=req["prompt"], + suffix=req["suffix"], + files=repo_files, + generate_config={ + "max_tokens": max_tokens, + "temperature": temperature, + }, + ) + assert isinstance(resp, dict) + choice = resp["choices"][0] + response_content += choice["text"] + + req["prompt"] = response_content + + hist.append(req) + return { + generate_mode: req["mode"], + prompt: response_content, + suffix: req["suffix"], + files: req["files"], + history: hist, + } + + def mode_change(generate_mode): + if generate_mode == "Code Completion": + return { + suffix: Code( + container=True, + show_label=True, + label="Suffix", + lines=21, + visible=False, + ), + files: File( + container=False, + show_label=False, + label="Files", + file_count="multiple", + visible=self.repo_level_supported, + ), + } + else: + return { + suffix: Code( + container=True, + show_label=True, + label="Suffix", + lines=21, + interactive=True, + visible=self.infill_supported, + ), + files: File( + container=False, + show_label=False, + label="Files", + file_count="multiple", + visible=False, + ), + } + + with gr.Blocks( + title=f"🚀 Xinference Code Generate Bot : {self.model_name} 🚀", + css=""" + .center{ + display: flex; + justify-content: center; + align-items: center; + padding: 0px; + color: #9ea4b0 !important; + } + """, + analytics_enabled=False, + ) as code_generate_interface: + modes = ( + ["Code Completion", "Code Infill Completion"] + if self.infill_supported + else ["Code Completion"] + ) + + history = gr.State([]) + + Markdown( + f""" +

+                <h1 style='text-align: center; margin-bottom: 1rem'>🚀 Xinference Code Generate Bot : {self.model_name} 🚀</h1>
+                """
+            )
+            Markdown(
+                f"""
+                <div class="center">
+                Model ID: {self.model_uid}
+                </div>
+                <div class="center">
+                Model Size: {self.model_size_in_billions} Billion Parameters
+                </div>
+                <div class="center">
+                Model Format: {self.model_format}
+                </div>
+                <div class="center">
+                Model Quantization: {self.quantization}
+                </div>
+                <div class="center">
+                Support Infill Code Completion: {self.infill_supported}
+                </div>
+                <div class="center">
+                Support Repository Level Code Completion: {self.repo_level_supported}
+                </div>
+ """ + ) + + with Column(variant="panel"): + generate_mode = Dropdown( + container=True, + show_label=True, + label="Code Generate Mode", + choices=modes, + value=modes[0], + ) + + prompt = Code( + container=True, + show_label=True, + label="Prompt", + lines=21, + interactive=True, + ) + + suffix = Code( + container=True, + show_label=True, + label="Suffix", + lines=21, + visible=False, + ) + + files = File( + container=True, + show_label=True, + label="Repository Files", + file_count="multiple", + interactive=True, + visible=self.repo_level_supported, + ) + + with Row(): + btn_generate = gr.Button("Generate", variant="primary") + with Row(): + btn_undo = gr.Button("↩️ Undo") + btn_retry = gr.Button("🔄 Retry") + btn_clear = gr.Button("🗑️ Clear") + with Accordion("Additional Inputs", open=False): + length = gr.Slider( + minimum=1, + maximum=self.context_length, + value=1024, + step=1, + label="Max Tokens", + ) + temperature = gr.Slider( + minimum=0, maximum=2, value=1, step=0.01, label="Temperature" + ) + + generate_mode.change( + fn=mode_change, inputs=[generate_mode], outputs=[suffix, files] + ) + + btn_generate.click( + fn=complete, + inputs=[ + generate_mode, + prompt, + suffix, + files, + history, + length, + temperature, + ], + outputs=[prompt, history], + ) + + btn_undo.click( + fn=undo, + inputs=[generate_mode, prompt, suffix, files, history], + outputs=[generate_mode, prompt, suffix, files, history], + ) + + btn_retry.click( + fn=retry, + inputs=[ + generate_mode, + prompt, + suffix, + files, + history, + length, + temperature, + ], + outputs=[generate_mode, prompt, suffix, files, history], + ) + + btn_clear.click( + fn=clear, + inputs=[generate_mode, prompt, suffix, files, history], + outputs=[generate_mode, prompt, suffix, files, history], + ) + + return code_generate_interface + def build_generate_interface( self, ): diff --git a/xinference/web/ui/src/scenes/running_models/index.js b/xinference/web/ui/src/scenes/running_models/index.js index 1e0da75bb3..4fe5272e4b 100644 --- a/xinference/web/ui/src/scenes/running_models/index.js +++ b/xinference/web/ui/src/scenes/running_models/index.js @@ -14,7 +14,7 @@ import Title from '../../components/Title' const RunningModels = () => { const [tabValue, setTabValue] = React.useState( - sessionStorage.getItem('runningModelType') + sessionStorage.getItem('runningModelType'), ) const [llmData, setLlmData] = useState([]) const [embeddingModelData, setEmbeddingModelData] = useState([]) @@ -34,6 +34,63 @@ const RunningModels = () => { sessionStorage.setItem('runningModelType', newValue) } + function get_models(code_prompts) { + fetcher(`${endPoint}/v1/models`, { + method: 'GET', + }) + .then((response) => { + if (!response.ok) { + response.json().then((errorData) => { + setErrorMsg( + `Login failed: ${response.status} - ${ + errorData.detail || 'Unknown error' + }`, + ) + }) + } else { + response.json().then((response) => { + const newLlmData = [] + const newEmbeddingModelData = [] + const newImageModelData = [] + const newAudioModelData = [] + const newRerankModelData = [] + response.data.forEach((model) => { + let newValue = { + ...model, + id: model.id, + url: model.id, + } + if (newValue.model_type === 'LLM') { + if (model.model_name in code_prompts) { + newValue['infill_supported'] = 'fim_spec' in code_prompts[model.model_name] + newValue['repo_level_supported'] = 'repo_level_spec' in code_prompts[model.model_name] + } + newLlmData.push(newValue) + } else if (newValue.model_type === 'embedding') { + newEmbeddingModelData.push(newValue) + } 
else if (newValue.model_type === 'audio') { + newAudioModelData.push(newValue) + } else if (newValue.model_type === 'image') { + newImageModelData.push(newValue) + } else if (newValue.model_type === 'rerank') { + newRerankModelData.push(newValue) + } + }) + setLlmData(newLlmData) + setEmbeddingModelData(newEmbeddingModelData) + setAudioModelData(newAudioModelData) + setImageModelData(newImageModelData) + setRerankModelData(newRerankModelData) + setIsUpdatingModel(false) + }) + } + }) + .catch((error) => { + console.error('Error:', error) + setIsUpdatingModel(false) + }) + } + const update = (isCallingApi) => { if (cookie.token === '' || cookie.token === undefined) { return @@ -58,52 +115,24 @@ const RunningModels = () => { ]) } else { setIsUpdatingModel(true) - fetcher(`${endPoint}/v1/models`, { + + fetcher(`${endPoint}/v1/models/code_prompts`, { method: 'GET', + }).then((response) => { + if (!response.ok) { + response.json().then((errorData) => { + setErrorMsg( + `Login failed: ${response.status} - ${ + errorData.detail || 'Unknown error' + }`, + ) + }) + } else { + response.json().then((code_prompts) => { + get_models(code_prompts) + }) + } }) - .then((response) => { - if (!response.ok) { - response.json().then((errorData) => { - setErrorMsg( - `Login failed: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - }) - } else { - response.json().then((response) => { - const newLlmData = [] - const newEmbeddingModelData = [] - const newImageModelData = [] - const newAudioModelData = [] - const newRerankModelData = [] - response.data.forEach((model) => { - let newValue = { - ...model, - id: model.id, - url: model.id, - } - if (newValue.model_type === 'LLM') { - newLlmData.push(newValue) - } else if (newValue.model_type === 'embedding') { - newEmbeddingModelData.push(newValue) - } else if (newValue.model_type === 'audio') { - newAudioModelData.push(newValue) - } else if (newValue.model_type === 'image') { - newImageModelData.push(newValue) - } else if (newValue.model_type === 'rerank') { - newRerankModelData.push(newValue) - } - }) - setLlmData(newLlmData) - setEmbeddingModelData(newEmbeddingModelData) - setAudioModelData(newAudioModelData) - setImageModelData(newImageModelData) - setRerankModelData(newRerankModelData) - setIsUpdatingModel(false) - }) - } - }) .catch((error) => { console.error('Error:', error) setIsUpdatingModel(false) @@ -218,11 +247,13 @@ const RunningModels = () => { model_ability: row.model_ability, model_description: row.model_description, model_lang: row.model_lang, + infill_supported: row.infill_supported, + repo_level_supported: row.repo_level_supported, }), }) .then((response) => response.json()) .then(() => - window.open(openUrl, '_blank', 'noopener noreferrer') + window.open(openUrl, '_blank', 'noopener noreferrer'), ) .finally(() => setIsCallingApi(false)) } else if (response.ok) { @@ -233,7 +264,7 @@ const RunningModels = () => { } else { // Other HTTP errors console.error( - `Unexpected response status: ${response.status}` + `Unexpected response status: ${response.status}`, ) setIsCallingApi(false) } @@ -501,7 +532,7 @@ const RunningModels = () => { }) .then((response) => response.json()) .then(() => - window.open(openUrl, '_blank', 'noopener noreferrer') + window.open(openUrl, '_blank', 'noopener noreferrer'), ) .finally(() => setIsCallingApi(false)) } else if (response.ok) { @@ -512,7 +543,7 @@ const RunningModels = () => { } else { // Other HTTP errors console.error( - `Unexpected response status: ${response.status}` + `Unexpected response 
status: ${response.status}`, ) setIsCallingApi(false) } From 62154548e812e1aae9e35c36d250e44711a51e03 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Tue, 14 May 2024 17:49:50 +0800 Subject: [PATCH 17/37] added the repo_name and file_path for prompt file support, and add the definition for codeqwen1.5 and make it supported by vllm engine. --- xinference/api/restful_api.py | 2 + xinference/client/restful/restful_client.py | 8 + xinference/core/chat_interface.py | 65 +- xinference/core/model.py | 5 + xinference/model/llm/__init__.py | 2 + xinference/model/llm/ggml/llamacpp.py | 2 + xinference/model/llm/lang_utils.py | 1127 +++++++++++++++++ xinference/model/llm/llm_family.json | 68 +- .../model/llm/llm_family_modelscope.json | 60 +- xinference/model/llm/pytorch/core.py | 2 + xinference/model/llm/tests/test_utils.py | 238 +++- xinference/model/llm/utils.py | 39 +- xinference/model/llm/vllm/core.py | 3 + xinference/types.py | 1 + 14 files changed, 1591 insertions(+), 31 deletions(-) create mode 100644 xinference/model/llm/lang_utils.py diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 134d3097f9..ac64bd3eba 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -1456,6 +1456,7 @@ async def create_code_completion(self, request: Request) -> Response: exclude = { "mode", "prompt", + "file_path", "suffix", "repo_name", "files", @@ -1495,6 +1496,7 @@ async def create_code_completion(self, request: Request) -> Response: data = await model.code_generate( body.mode, body.prompt, + body.file_path, body.suffix, body.repo_name, body.files, diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index b31def8d35..40bc768df2 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -435,6 +435,7 @@ def code_generate( self, mode: "CodeGenerateMode", prompt: str, + file_path: Optional[str] = None, suffix: Optional[str] = None, repo_name: Optional[str] = None, files: Optional[typing.Mapping] = None, @@ -453,6 +454,8 @@ def code_generate( Infill is fill in middle completion, complete the code according provided prefix and suffix content. prompt: str The user's input, it presents prefix content in infill mode. + file_path: Optional[str] + The file path for prompt content file. suffix: Optional[str] The suffix content in infill mode. repo_name: Optional[str] @@ -481,6 +484,7 @@ def code_generate( "model": self._model_uid, "mode": mode, "prompt": prompt, + "file_path": file_path, "suffix": suffix, "repo_name": repo_name, "files": files, @@ -504,6 +508,7 @@ def get_code_prompt( self, mode: "CodeGenerateMode", prompt: str, + file_path: Optional[str] = None, suffix: Optional[str] = None, repo_name: Optional[str] = None, files: Optional[typing.Mapping] = None, @@ -520,6 +525,8 @@ def get_code_prompt( Infill is fill in middle completion, complete the code according provided prefix and suffix content. prompt: str The user's input, it presents prefix content in infill mode. + file_path: Optional[str] + The file path for prompt suffix: Optional[str] The suffix content in infill mode. 
repo_name: Optional[str] @@ -545,6 +552,7 @@ def get_code_prompt( "model": self._model_uid, "mode": mode, "prompt": prompt, + "file_path": file_path, "suffix": suffix, "repo_name": repo_name, "files": files, diff --git a/xinference/core/chat_interface.py b/xinference/core/chat_interface.py index f82eaf56b9..ae4ca5abd0 100644 --- a/xinference/core/chat_interface.py +++ b/xinference/core/chat_interface.py @@ -41,6 +41,9 @@ def compare_history(current, hist): if current["prompt"] != hist["prompt"]: return False + if current["file_path"] != hist["file_path"]: + return False + if current["suffix"] != hist["suffix"]: return False @@ -53,6 +56,7 @@ def compare_history(current, hist): EMPTY = { "mode": "Code Completion", "prompt": "", + "file_path": "", "suffix": "", "files": None, } @@ -339,10 +343,11 @@ def update_button(text): def build_code_generate_interface( self, ): - def undo(g_mode, text, g_suffix, g_files, hist): + def undo(g_mode, text, g_file_path, g_suffix, g_files, hist): current = { "mode": g_mode, "prompt": text, + "file_path": g_file_path, "suffix": g_suffix, "files": g_files, } @@ -351,6 +356,7 @@ def undo(g_mode, text, g_suffix, g_files, hist): return { generate_mode: "Code Completion", prompt: "", + file_path: "", suffix: "", files: None, history: [current], @@ -363,15 +369,17 @@ def undo(g_mode, text, g_suffix, g_files, hist): return { generate_mode: req["mode"], prompt: req["prompt"], + file_path: req["file_path"], suffix: req["suffix"], files: g_files, history: hist, } - def clear(g_mode, text, g_suffix, g_files, hist): + def clear(g_mode, text, g_file_path, g_suffix, g_files, hist): current = { "mode": g_mode, "prompt": text, + "file_path": g_file_path, "suffix": g_suffix, "files": g_files, } @@ -383,12 +391,15 @@ def clear(g_mode, text, g_suffix, g_files, hist): return { generate_mode: "Code Completion", prompt: "", + file_path: "", suffix: "", files: None, history: hist, } - def complete(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): + def complete( + g_mode, text, g_file_path, g_suffix, g_files, hist, max_tokens, temperature + ): from ..client import RESTfulClient client = RESTfulClient(self.endpoint) @@ -402,9 +413,11 @@ def complete(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): if g_files else None ) + current = { "mode": g_mode, "prompt": text, + "file_path": g_file_path, "suffix": g_suffix, "files": g_files, } @@ -420,6 +433,7 @@ def complete(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): resp = model.code_generate( "completion", prompt=text, + file_path=g_file_path, files=repo_files, generate_config={ "max_tokens": max_tokens, @@ -430,6 +444,7 @@ def complete(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): resp = model.code_generate( mode="completion", prompt=text, + file_path=g_file_path, generate_config={ "max_tokens": max_tokens, "temperature": temperature, @@ -453,6 +468,7 @@ def complete(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): current = { "mode": g_mode, "prompt": response_content, + "file_path": g_file_path, "suffix": g_suffix, "files": g_files, } @@ -463,7 +479,9 @@ def complete(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): history: hist, } - def retry(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): + def retry( + g_mode, text, g_suffix, g_file_path, g_files, hist, max_tokens, temperature + ): from ..client import RESTfulClient client = RESTfulClient(self.endpoint) @@ -475,6 +493,7 @@ def retry(g_mode, text, g_suffix, g_files, 
hist, max_tokens, temperature): current = { "mode": g_mode, "prompt": text, + "file_path": g_file_path, "suffix": g_suffix, "files": g_files, } @@ -497,6 +516,7 @@ def retry(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): resp = model.code_generate( mode="completion" if req["mode"] == "Code Completion" else "infill", prompt=req["prompt"], + file_path=req["file_path"], suffix=req["suffix"], files=repo_files, generate_config={ @@ -514,6 +534,7 @@ def retry(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): return { generate_mode: req["mode"], prompt: response_content, + file_path: req["file_path"], suffix: req["suffix"], files: req["files"], history: hist, @@ -522,6 +543,12 @@ def retry(g_mode, text, g_suffix, g_files, hist, max_tokens, temperature): def mode_change(generate_mode): if generate_mode == "Code Completion": return { + file_path: Textbox( + container=True, + show_label=True, + label="Prompt file path", + interactive=self.repo_level_supported, + ), suffix: Code( container=True, show_label=True, @@ -539,6 +566,13 @@ def mode_change(generate_mode): } else: return { + file_path: Textbox( + container=True, + show_label=True, + label="Prompt file path", + interactive=True, + visible=False, + ), suffix: Code( container=True, show_label=True, @@ -630,6 +664,13 @@ def mode_change(generate_mode): visible=False, ) + file_path = Textbox( + container=True, + show_label=True, + label="Prompt file path", + interactive=True, + ) + files = File( container=True, show_label=True, @@ -658,7 +699,9 @@ def mode_change(generate_mode): ) generate_mode.change( - fn=mode_change, inputs=[generate_mode], outputs=[suffix, files] + fn=mode_change, + inputs=[generate_mode], + outputs=[file_path, suffix, files], ) btn_generate.click( @@ -666,6 +709,7 @@ def mode_change(generate_mode): inputs=[ generate_mode, prompt, + file_path, suffix, files, history, @@ -677,8 +721,8 @@ def mode_change(generate_mode): btn_undo.click( fn=undo, - inputs=[generate_mode, prompt, suffix, files, history], - outputs=[generate_mode, prompt, suffix, files, history], + inputs=[generate_mode, prompt, file_path, suffix, files, history], + outputs=[generate_mode, prompt, file_path, suffix, files, history], ) btn_retry.click( @@ -686,19 +730,20 @@ def mode_change(generate_mode): inputs=[ generate_mode, prompt, + file_path, suffix, files, history, length, temperature, ], - outputs=[generate_mode, prompt, suffix, files, history], + outputs=[generate_mode, prompt, file_path, suffix, files, history], ) btn_clear.click( fn=clear, - inputs=[generate_mode, prompt, suffix, files, history], - outputs=[generate_mode, prompt, suffix, files, history], + inputs=[generate_mode, prompt, file_path, suffix, files, history], + outputs=[generate_mode, prompt, file_path, suffix, files, history], ) return code_generate_interface diff --git a/xinference/core/model.py b/xinference/core/model.py index c68565d0da..e3994d092a 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -405,6 +405,7 @@ async def code_generate( self, mode: CodeGenerateMode, prompt: str, + file_path: Optional[str], suffix: Optional[str], repo_name: Optional[str], files: Optional[Mapping[str, str]], @@ -419,6 +420,7 @@ async def code_generate( self._model.code_generate, mode, prompt, + file_path, suffix, repo_name, files, @@ -431,6 +433,7 @@ async def code_generate( self._model.async_code_generate, mode, prompt, + file_path, suffix, repo_name, files, @@ -466,6 +469,7 @@ async def get_code_prompt( self, mode: CodeGenerateMode, prompt: str, + 
file_path: Optional[str], suffix: Optional[str], repo_name: Optional[str], files: Optional[Mapping[str, str]], @@ -477,6 +481,7 @@ async def get_code_prompt( self._model.get_code_prompt, mode, prompt, + file_path, suffix, repo_name, files, diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index c797191680..f344dee475 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -41,11 +41,13 @@ VLLM_CLASSES, CodePromptStyleV1, CustomLLMFamilyV1, + FIMSpecV1, GgmlLLMSpecV1, LLMFamilyV1, LLMSpecV1, PromptStyleV1, PytorchLLMSpecV1, + RepoLevelCodeCompletionSpecV1, get_cache_status, get_user_defined_llm_families, match_llm, diff --git a/xinference/model/llm/ggml/llamacpp.py b/xinference/model/llm/ggml/llamacpp.py index 126aae9fa4..acfbcdcba1 100644 --- a/xinference/model/llm/ggml/llamacpp.py +++ b/xinference/model/llm/ggml/llamacpp.py @@ -338,6 +338,7 @@ def code_generate( self, generate_model: CodeGenerateMode, prompt: str, + file_path: Optional[str], suffix: Optional[str], repo_name: Optional[str], files: Optional[Mapping[str, str]], @@ -346,6 +347,7 @@ def code_generate( code_prompt = self.get_code_prompt( generate_model, prompt, + file_path, suffix, repo_name, files, diff --git a/xinference/model/llm/lang_utils.py b/xinference/model/llm/lang_utils.py new file mode 100644 index 0000000000..60bbbb3ea7 --- /dev/null +++ b/xinference/model/llm/lang_utils.py @@ -0,0 +1,1127 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os.path + +from . 
import RepoLevelCodeCompletionSpecV1 + +logger = logging.getLogger(__name__) + +# NOTICE: these code referenced by aixcoder + +# https://github.com/aixcoder-plugin/aiXcoder-7B/blob/main/hf_mini/utils.py + +LANGUAGE_WRAPPER = { + "c": "// ", + "c++": "// ", + "cpp": "// ", + "c#": "// ", + "csharp": "// ", + "c-sharp": "// ", + "css": "/* */", + "cuda": "// ", + "dart": "// ", + "lua": "// ", + "objectivec": "// ", + "objective-c": "// ", + "objective-c++": "// ", + "python": "# ", + "perl": "# ", + "prolog": "% ", + "swift": "// ", + "lisp": "; ", + "java": "// ", + "scala": "// ", + "tex": "% ", + "vue": "", + "markdown": "", + "html": "", + "php": "// ", + "js": "// ", + "javascript": "// ", + "typescript": "// ", + "go": "// ", + "shell": "# ", + "rust": "// ", + "sql": "-- ", + "kotlin": "// ", + "vb": "' ", + "ruby": "# ", + "pascal": "// ", + "r": "# ", + "fortran": "!", + "lean": "-- ", + "matlab": "% ", + "delphi": "{}", + "scheme": "; ", + "basic": "' ", + "assembly": "; ", + "groovy": "// ", + "abap": "* ", + "gdscript": "# ", + "haskell": "-- ", + "julia": "# ", + "elixir": "# ", + "excel": "' ", + "clojure": "; ", + "actionscript": "// ", + "solidity": "// ", + "powershell": "# ", + "erlang": "% ", + "cobol": "// ", + "alloy": "/* */", + "awk": "// ", + "thrift": "/* */", + "sparql": "# ", + "augeas": "// ", + "cmake": "# ", + "f-sharp": "// ", + "stan": "// ", + "isabelle": "(**)", + "dockerfile": "# ", + "rmarkdown": "# ", + "literate-agda": "-- ", + "tcl": "// ", + "glsl": "// ", + "antlr": "// ", + "verilog": "// ", + "racket": "; ", + "standard-ml": "(**)", + "elm": "-- ", + "yaml": "# ", + "smalltalk": "'' ", + "ocaml": "(**)", + "idris": "-- ", + "visual-basic": "' ", + "protocol-buffer": "// ", + "bluespec": "// ", + "applescript": "-- ", + "makefile": "# ", + "tcsh": "# ", + "maple": "# ", + "systemverilog": "// ", + "literate-coffeescript": "# ", + "vhdl": "-- ", + "restructuredtext": ".. 
", + "sas": "* ", + "literate-haskell": "> ", + "java-server-pages": "// ", + "coffeescript": "# ", + "emacs-lisp": "; ", + "mathematica": "// ", + "xslt": "", + "zig": "// ", + "common-lisp": "; ", + "stata": "* ", + "agda": "-- ", + "ada": "-- ", + "jsx": "// ", + "tsx": "// ", +} + +EXT2LANG = { + ".abap": "abap", + ".ash": "ags script", + ".ampl": "ampl", + ".g4": "antlr", + ".apib": "api blueprint", + ".apl": "apl", + ".dyalog": "apl", + ".asp": "asp", + ".asax": "asp", + ".ascx": "asp", + ".ashx": "asp", + ".asmx": "asp", + ".aspx": "asp", + ".axd": "asp", + ".dats": "ats", + ".hats": "ats", + ".sats": "ats", + ".as": "actionscript", + ".adb": "ada", + ".ada": "ada", + ".ads": "ada", + ".agda": "agda", + ".apacheconf": "apacheconf", + ".vhost": "apacheconf", + ".applescript": "applescript", + ".scpt": "applescript", + ".arc": "arc", + ".ino": "arduino", + ".asciidoc": "asciidoc", + ".adoc": "asciidoc", + ".aj": "aspectj", + ".asm": "assembly", + ".a51": "assembly", + ".nasm": "assembly", + ".aug": "augeas", + ".ahk": "autohotkey", + ".ahkl": "autohotkey", + ".au3": "autoit", + ".awk": "awk", + ".auk": "awk", + ".gawk": "awk", + ".mawk": "awk", + ".nawk": "awk", + ".bat": "batchfile", + ".cmd": "batchfile", + ".befunge": "befunge", + ".bison": "bison", + ".bb": "bitbake", + ".decls": "blitzbasic", + ".bmx": "blitzmax", + ".bsv": "bluespec", + ".boo": "boo", + ".bf": "brainfuck", + ".brs": "brightscript", + ".bro": "bro", + ".c": "c", + ".cats": "c", + ".h": "c++", + ".idc": "c", + ".w": "c", + ".cs": "c#", + ".cake": "c#", + ".cshtml": "c#", + ".csx": "c#", + ".cpp": "c++", + ".c++": "c++", + ".cc": "c++", + ".cp": "c++", + ".cxx": "c++", + ".h++": "c++", + ".hh": "c++", + ".hpp": "c++", + ".hxx": "c++", + ".inl": "c++", + ".ipp": "c++", + ".tcc": "c++", + ".tpp": "c++", + ".C": "c++", + ".H": "c++", + ".c-objdump": "c-objdump", + ".chs": "c2hs haskell", + ".clp": "clips", + ".cmake": "cmake", + ".cmake.in": "cmake", + ".cob": "cobol", + ".cbl": "cobol", + ".ccp": "cobol", + ".cobol": "cobol", + ".cpy": "cobol", + ".css": "css", + ".csv": "csv", + ".capnp": "cap'n proto", + ".mss": "cartocss", + ".ceylon": "ceylon", + ".chpl": "chapel", + ".ck": "chuck", + ".cirru": "cirru", + ".clw": "clarion", + ".icl": "clean", + ".dcl": "clean", + ".click": "click", + ".clj": "clojure", + ".boot": "clojure", + ".cl2": "clojure", + ".cljc": "clojure", + ".cljs": "clojure", + ".cljs.hl": "clojure", + ".cljscm": "clojure", + ".cljx": "clojure", + ".hic": "clojure", + ".coffee": "coffeescript", + "._coffee": "coffeescript", + ".cjsx": "coffeescript", + ".cson": "coffeescript", + ".iced": "coffeescript", + ".cfm": "coldfusion", + ".cfml": "coldfusion", + ".cfc": "coldfusion cfc", + ".lisp": "common lisp", + ".asd": "common lisp", + ".lsp": "common lisp", + ".ny": "common lisp", + ".podsl": "common lisp", + ".sexp": "common lisp", + ".cps": "component pascal", + ".coq": "coq", + ".cppobjdump": "cpp-objdump", + ".c++-objdump": "cpp-objdump", + ".c++objdump": "cpp-objdump", + ".cpp-objdump": "cpp-objdump", + ".cxx-objdump": "cpp-objdump", + ".creole": "creole", + ".cr": "crystal", + ".csd": "csound", + ".feature": "cucumber", + ".cu": "cuda", + ".cuh": "cuda", + ".cy": "cycript", + ".pyx": "cython", + ".pxd": "cython", + ".pxi": "cython", + ".di": "d", + ".d-objdump": "d-objdump", + ".com": "digital command language", + ".dm": "dm", + ".zone": "dns zone", + ".arpa": "dns zone", + ".darcspatch": "darcs patch", + ".dpatch": "darcs patch", + ".dart": "dart", + ".diff": "diff", + ".patch": "diff", + 
".dockerfile": "dockerfile", + "Dockerfile": "dockerfile", + ".djs": "dogescript", + ".dylan": "dylan", + ".dyl": "dylan", + ".intr": "dylan", + ".lid": "dylan", + ".E": "e", + ".ecl": "ecl", + ".eclxml": "ecl", + ".sch": "eagle", + ".brd": "eagle", + ".epj": "ecere projects", + ".e": "eiffel", + ".ex": "elixir", + ".exs": "elixir", + ".elm": "elm", + ".el": "emacs lisp", + ".emacs": "emacs lisp", + ".emacs.desktop": "emacs lisp", + ".em": "emberscript", + ".emberscript": "emberscript", + ".erl": "erlang", + ".escript": "erlang", + ".hrl": "erlang", + ".xrl": "erlang", + ".yrl": "erlang", + ".fs": "f#", + ".fsi": "f#", + ".fsx": "f#", + ".flux": "flux", + ".f90": "fortran", + ".f": "fortran", + ".f03": "fortran", + ".f08": "fortran", + ".f77": "fortran", + ".f95": "fortran", + ".for": "fortran", + ".fpp": "fortran", + ".factor": "factor", + ".fy": "fancy", + ".fancypack": "fancy", + ".fan": "fantom", + ".eam.fs": "formatted", + ".fth": "forth", + ".4th": "forth", + ".forth": "forth", + ".frt": "forth", + ".ftl": "freemarker", + ".g": "g-code", + ".gco": "g-code", + ".gcode": "g-code", + ".gms": "gams", + ".gap": "gap", + ".gi": "gap", + ".s": "gas", + ".gd": "gdscript", + ".glsl": "glsl", + ".fp": "glsl", + ".frag": "glsl", + ".frg": "glsl", + ".fsh": "glsl", + ".fshader": "glsl", + ".geo": "glsl", + ".geom": "glsl", + ".glslv": "glsl", + ".gshader": "glsl", + ".shader": "glsl", + ".vert": "glsl", + ".vrx": "glsl", + ".vsh": "glsl", + ".vshader": "glsl", + ".kid": "genshi", + ".ebuild": "gentoo ebuild", + ".eclass": "gentoo eclass", + ".po": "gettext catalog", + ".pot": "gettext catalog", + ".glf": "glyph", + ".gp": "gnuplot", + ".gnu": "gnuplot", + ".gnuplot": "gnuplot", + ".plot": "gnuplot", + ".plt": "gnuplot", + ".go": "go", + ".golo": "golo", + ".gst": "gosu", + ".gsx": "gosu", + ".vark": "gosu", + ".grace": "grace", + ".gradle": "gradle", + ".gf": "grammatical framework", + ".graphql": "graphql", + ".dot": "graphviz (dot)", + ".gv": "graphviz (dot)", + ".man": "groff", + ".1": "groff", + ".1in": "groff", + ".1m": "groff", + ".1x": "groff", + ".2": "groff", + ".3": "groff", + ".3in": "groff", + ".3m": "groff", + ".3qt": "groff", + ".3x": "groff", + ".4": "groff", + ".5": "groff", + ".6": "groff", + ".7": "groff", + ".8": "groff", + ".9": "groff", + ".me": "groff", + ".rno": "groff", + ".roff": "groff", + ".groovy": "groovy", + ".grt": "groovy", + ".gtpl": "groovy", + ".gvy": "groovy", + ".gsp": "groovy server pages", + ".hcl": "hcl", + ".tf": "hcl", + ".hlsl": "hlsl", + ".fxh": "hlsl", + ".hlsli": "hlsl", + ".html": "html", + ".htm": "html", + ".html.hl": "html", + ".xht": "html", + ".xhtml": "html", + ".mustache": "html+django", + ".jinja": "html+django", + ".eex": "html+eex", + ".erb": "html+erb", + ".erb.deface": "html+erb", + ".phtml": "html+php", + ".http": "http", + ".haml": "haml", + ".haml.deface": "haml", + ".handlebars": "handlebars", + ".hbs": "handlebars", + ".hb": "harbour", + ".hs": "haskell", + ".hsc": "haskell", + ".hx": "haxe", + ".hxsl": "haxe", + ".hy": "hy", + ".dlm": "idl", + ".ipf": "igor pro", + ".ini": "ini", + ".cfg": "ini", + ".prefs": "ini", + ".properties": "ini", + ".irclog": "irc log", + ".weechatlog": "irc log", + ".idr": "idris", + ".lidr": "idris", + ".ni": "inform 7", + ".i7x": "inform 7", + ".iss": "inno setup", + ".io": "io", + ".ik": "ioke", + ".thy": "isabelle", + ".ijs": "j", + ".flex": "jflex", + ".jflex": "jflex", + ".json": "json", + ".geojson": "json", + ".lock": "json", + ".topojson": "json", + ".json5": "json5", + ".jsonld": "jsonld", + 
".jq": "jsoniq", + ".jsx": "jsx", + ".jade": "jade", + ".j": "jasmin", + ".java": "java", + ".jsp": "java server pages", + ".js": "javascript", + "._js": "javascript", + ".bones": "javascript", + ".es6": "javascript", + ".jake": "javascript", + ".jsb": "javascript", + ".jscad": "javascript", + ".jsfl": "javascript", + ".jsm": "javascript", + ".jss": "javascript", + ".njs": "javascript", + ".pac": "javascript", + ".sjs": "javascript", + ".ssjs": "javascript", + ".xsjs": "javascript", + ".xsjslib": "javascript", + ".jl": "julia", + ".ipynb": "jupyter notebook", + ".krl": "krl", + ".kicad_pcb": "kicad", + ".kit": "kit", + ".kt": "kotlin", + ".ktm": "kotlin", + ".kts": "kotlin", + ".lfe": "lfe", + ".ll": "llvm", + ".lol": "lolcode", + ".lsl": "lsl", + ".lslp": "lsl", + ".lvproj": "labview", + ".lasso": "lasso", + ".las": "lasso", + ".lasso8": "lasso", + ".lasso9": "lasso", + ".ldml": "lasso", + ".latte": "latte", + ".lean": "lean", + ".hlean": "lean", + ".less": "less", + ".lex": "lex", + ".ly": "lilypond", + ".ily": "lilypond", + ".ld": "linker script", + ".lds": "linker script", + ".liquid": "liquid", + ".lagda": "literate agda", + ".litcoffee": "literate coffeescript", + ".lhs": "literate haskell", + ".ls": "livescript", + "._ls": "livescript", + ".xm": "logos", + ".x": "logos", + ".xi": "logos", + ".lgt": "logtalk", + ".logtalk": "logtalk", + ".lookml": "lookml", + ".lua": "lua", + ".nse": "lua", + ".pd_lua": "lua", + ".rbxs": "lua", + ".wlua": "lua", + ".mumps": "m", + ".m4": "m4", + ".mcr": "maxscript", + ".mtml": "mtml", + ".muf": "muf", + ".mak": "makefile", + ".mk": "makefile", + ".mkfile": "makefile", + "Makefile": "makefile", + ".mako": "mako", + ".mao": "mako", + ".mpl": "maple", + ".md": "markdown", + ".markdown": "markdown", + ".mkd": "markdown", + ".mkdn": "markdown", + ".mkdown": "markdown", + ".ron": "markdown", + ".mask": "mask", + ".mathematica": "mathematica", + ".cdf": "mathematica", + ".ma": "mathematica", + ".mt": "mathematica", + ".nb": "mathematica", + ".nbp": "mathematica", + ".wl": "mathematica", + ".wlt": "mathematica", + ".matlab": "matlab", + ".maxpat": "max", + ".maxhelp": "max", + ".maxproj": "max", + ".mxt": "max", + ".pat": "max", + ".mediawiki": "mediawiki", + ".wiki": "mediawiki", + ".metal": "metal", + ".minid": "minid", + ".druby": "mirah", + ".duby": "mirah", + ".mir": "mirah", + ".mirah": "mirah", + ".mo": "modelica", + ".mms": "module management system", + ".mmk": "module management system", + ".monkey": "monkey", + ".moon": "moonscript", + ".myt": "myghty", + ".nsi": "nsis", + ".nsh": "nsis", + ".axs": "netlinx", + ".axi": "netlinx", + ".axs.erb": "netlinx+erb", + ".axi.erb": "netlinx+erb", + ".nlogo": "netlogo", + ".nginxconf": "nginx", + ".nim": "nimrod", + ".nimrod": "nimrod", + ".ninja": "ninja", + ".nit": "nit", + ".nix": "nix", + ".nu": "nu", + ".numpy": "numpy", + ".numpyw": "numpy", + ".numsc": "numpy", + ".ml": "ocaml", + ".eliom": "ocaml", + ".eliomi": "ocaml", + ".ml4": "ocaml", + ".mli": "ocaml", + ".mll": "ocaml", + ".mly": "ocaml", + ".objdump": "objdump", + ".mm": "objective-c++", + ".sj": "objective-j", + ".oct": "octave", + ".omgrofl": "omgrofl", + ".opa": "opa", + ".opal": "opal", + ".cl": "opencl", + ".opencl": "opencl", + ".p": "openedge abl", + ".scad": "openscad", + ".org": "org", + ".ox": "ox", + ".oxh": "ox", + ".oxo": "ox", + ".oxygene": "oxygene", + ".oz": "oz", + ".pwn": "pawn", + ".php": "php", + ".aw": "php", + ".ctp": "php", + ".php3": "php", + ".php4": "php", + ".php5": "php", + ".phps": "php", + ".phpt": "php", + ".pov": 
"pov-ray sdl", + ".pan": "pan", + ".psc": "papyrus", + ".parrot": "parrot", + ".pasm": "parrot assembly", + ".pir": "parrot internal representation", + ".pas": "pascal", + ".dfm": "pascal", + ".dpr": "pascal", + ".lpr": "pascal", + ".pl": "perl", + ".al": "perl", + ".perl": "perl", + ".ph": "perl", + ".plx": "perl", + ".pm": "perl", + ".psgi": "perl", + ".t": "perl", + ".6pl": "perl6", + ".6pm": "perl6", + ".nqp": "perl6", + ".p6": "perl6", + ".p6l": "perl6", + ".p6m": "perl6", + ".pl6": "perl6", + ".pm6": "perl6", + ".pkl": "pickle", + ".pig": "piglatin", + ".pike": "pike", + ".pmod": "pike", + ".pod": "pod", + ".pogo": "pogoscript", + ".pony": "pony", + ".ps": "postscript", + ".eps": "postscript", + ".ps1": "powershell", + ".psd1": "powershell", + ".psm1": "powershell", + ".pde": "processing", + ".prolog": "prolog", + ".yap": "prolog", + ".spin": "propeller spin", + ".proto": "protocol buffer", + ".pub": "public key", + ".pd": "pure data", + ".pb": "purebasic", + ".pbi": "purebasic", + ".purs": "purescript", + ".py": "python", + ".bzl": "python", + ".gyp": "python", + ".lmi": "python", + ".pyde": "python", + ".pyp": "python", + ".pyt": "python", + ".pyw": "python", + ".tac": "python", + ".wsgi": "python", + ".xpy": "python", + ".pytb": "python traceback", + ".qml": "qml", + ".qbs": "qml", + ".pri": "qmake", + ".r": "r", + ".rd": "r", + ".rsx": "r", + ".raml": "raml", + ".rdoc": "rdoc", + ".rbbas": "realbasic", + ".rbfrm": "realbasic", + ".rbmnu": "realbasic", + ".rbres": "realbasic", + ".rbtbar": "realbasic", + ".rbuistate": "realbasic", + ".rhtml": "rhtml", + ".rmd": "rmarkdown", + ".rkt": "racket", + ".rktd": "racket", + ".rktl": "racket", + ".scrbl": "racket", + ".rl": "ragel in ruby host", + ".raw": "raw token data", + ".reb": "rebol", + ".r2": "rebol", + ".r3": "rebol", + ".rebol": "rebol", + ".red": "red", + ".reds": "red", + ".cw": "redcode", + ".rpy": "ren'py", + ".rsh": "renderscript", + ".robot": "robotframework", + ".rb": "ruby", + ".builder": "ruby", + ".gemspec": "ruby", + ".god": "ruby", + ".irbrc": "ruby", + ".jbuilder": "ruby", + ".mspec": "ruby", + ".podspec": "ruby", + ".rabl": "ruby", + ".rake": "ruby", + ".rbuild": "ruby", + ".rbw": "ruby", + ".rbx": "ruby", + ".ru": "ruby", + ".ruby": "ruby", + ".thor": "ruby", + ".watchr": "ruby", + ".rs": "rust", + ".rs.in": "rust", + ".sas": "sas", + ".scss": "scss", + ".smt2": "smt", + ".smt": "smt", + ".sparql": "sparql", + ".rq": "sparql", + ".sqf": "sqf", + ".hqf": "sqf", + ".pls": "sql", + ".pck": "sql", + ".pkb": "sql", + ".pks": "sql", + ".plb": "sql", + ".plsql": "sql", + ".sql": "sql", + ".cql": "sql", + ".ddl": "sql", + ".prc": "sql", + ".tab": "sql", + ".udf": "sql", + ".viw": "sql", + ".db2": "sql", + ".ston": "ston", + ".svg": "svg", + ".sage": "sage", + ".sagews": "sage", + ".sls": "saltstack", + ".sass": "sass", + ".scala": "scala", + ".sbt": "scala", + ".scaml": "scaml", + ".scm": "scheme", + ".sld": "scheme", + ".sps": "scheme", + ".ss": "scheme", + ".sci": "scilab", + ".sce": "scilab", + ".self": "self", + ".sh": "shell", + ".bash": "shell", + ".bats": "shell", + ".command": "shell", + ".ksh": "shell", + ".sh.in": "shell", + ".tmux": "shell", + ".tool": "shell", + ".zsh": "shell", + ".sh-session": "shellsession", + ".shen": "shen", + ".sl": "slash", + ".slim": "slim", + ".smali": "smali", + ".st": "smalltalk", + ".tpl": "smarty", + ".sol": "solidity", + ".sp": "sourcepawn", + ".sma": "sourcepawn", + ".nut": "squirrel", + ".stan": "stan", + ".ML": "standard ml", + ".fun": "standard ml", + ".sig": "standard ml", + 
".sml": "standard ml", + ".do": "stata", + ".ado": "stata", + ".doh": "stata", + ".ihlp": "stata", + ".matah": "stata", + ".sthlp": "stata", + ".styl": "stylus", + ".scd": "supercollider", + ".swift": "swift", + ".sv": "systemverilog", + ".svh": "systemverilog", + ".vh": "systemverilog", + ".toml": "toml", + ".txl": "txl", + ".tcl": "tcl", + ".adp": "tcl", + ".tm": "tcl", + ".tcsh": "tcsh", + ".csh": "tcsh", + ".tex": "tex", + ".aux": "tex", + ".bbx": "tex", + ".bib": "tex", + ".cbx": "tex", + ".dtx": "tex", + ".ins": "tex", + ".lbx": "tex", + ".ltx": "tex", + ".mkii": "tex", + ".mkiv": "tex", + ".mkvi": "tex", + ".sty": "tex", + ".toc": "tex", + ".tea": "tea", + ".txt": "text", + ".no": "text", + ".textile": "textile", + ".thrift": "thrift", + ".tu": "turing", + ".ttl": "turtle", + ".twig": "twig", + ".ts": "typescript", + ".tsx": "tsx", + ".upc": "unified parallel c", + ".anim": "unity3d asset", + ".asset": "unity3d asset", + ".mat": "unity3d asset", + ".meta": "unity3d asset", + ".prefab": "unity3d asset", + ".unity": "unity3d asset", + ".uno": "uno", + ".uc": "unrealscript", + ".ur": "urweb", + ".urs": "urweb", + ".vcl": "vcl", + ".vhdl": "vhdl", + ".vhd": "vhdl", + ".vhf": "vhdl", + ".vhi": "vhdl", + ".vho": "vhdl", + ".vhs": "vhdl", + ".vht": "vhdl", + ".vhw": "vhdl", + ".vala": "vala", + ".vapi": "vala", + ".veo": "verilog", + ".vim": "viml", + ".vb": "visual basic", + ".bas": "visual basic", + ".frm": "visual basic", + ".frx": "visual basic", + ".vba": "visual basic", + ".vbhtml": "visual basic", + ".vbs": "visual basic", + ".volt": "volt", + ".vue": "vue", + ".owl": "web ontology language", + ".wat": "webassembly", + ".webidl": "webidl", + ".x10": "x10", + ".xc": "xc", + ".xml": "xml", + ".ant": "xml", + ".axml": "xml", + ".ccxml": "xml", + ".clixml": "xml", + ".cproject": "xml", + ".csl": "xml", + ".csproj": "xml", + ".ct": "xml", + ".dita": "xml", + ".ditamap": "xml", + ".ditaval": "xml", + ".dll.config": "xml", + ".dotsettings": "xml", + ".filters": "xml", + ".fsproj": "xml", + ".fxml": "xml", + ".glade": "xml", + ".grxml": "xml", + ".iml": "xml", + ".ivy": "xml", + ".jelly": "xml", + ".jsproj": "xml", + ".kml": "xml", + ".launch": "xml", + ".mdpolicy": "xml", + ".mxml": "xml", + ".nproj": "xml", + ".nuspec": "xml", + ".odd": "xml", + ".osm": "xml", + ".plist": "xml", + ".props": "xml", + ".ps1xml": "xml", + ".psc1": "xml", + ".pt": "xml", + ".rdf": "xml", + ".rss": "xml", + ".scxml": "xml", + ".srdf": "xml", + ".storyboard": "xml", + ".stTheme": "xml", + ".sublime-snippet": "xml", + ".targets": "xml", + ".tmCommand": "xml", + ".tml": "xml", + ".tmLanguage": "xml", + ".tmPreferences": "xml", + ".tmSnippet": "xml", + ".tmTheme": "xml", + ".ui": "xml", + ".urdf": "xml", + ".ux": "xml", + ".vbproj": "xml", + ".vcxproj": "xml", + ".vssettings": "xml", + ".vxml": "xml", + ".wsdl": "xml", + ".wsf": "xml", + ".wxi": "xml", + ".wxl": "xml", + ".wxs": "xml", + ".x3d": "xml", + ".xacro": "xml", + ".xaml": "xml", + ".xib": "xml", + ".xlf": "xml", + ".xliff": "xml", + ".xmi": "xml", + ".xml.dist": "xml", + ".xproj": "xml", + ".xsd": "xml", + ".xul": "xml", + ".zcml": "xml", + ".xsp-config": "xpages", + ".xsp.metadata": "xpages", + ".xpl": "xproc", + ".xproc": "xproc", + ".xquery": "xquery", + ".xq": "xquery", + ".xql": "xquery", + ".xqm": "xquery", + ".xqy": "xquery", + ".xs": "xs", + ".xslt": "xslt", + ".xsl": "xslt", + ".xojo_code": "xojo", + ".xojo_menu": "xojo", + ".xojo_report": "xojo", + ".xojo_script": "xojo", + ".xojo_toolbar": "xojo", + ".xojo_window": "xojo", + ".xtend": "xtend", 
+ ".yml": "yaml", + ".reek": "yaml", + ".rviz": "yaml", + ".sublime-syntax": "yaml", + ".syntax": "yaml", + ".yaml": "yaml", + ".yaml-tmlanguage": "yaml", + ".yang": "yang", + ".y": "yacc", + ".yacc": "yacc", + ".yy": "yacc", + ".zep": "zephir", + ".zig": "zig", + ".zimpl": "zimpl", + ".zmpl": "zimpl", + ".zpl": "zimpl", + ".desktop": "desktop", + ".desktop.in": "desktop", + ".ec": "ec", + ".eh": "ec", + ".fish": "fish", + ".mu": "mupad", + ".nc": "nesc", + ".ooc": "ooc", + ".rst": "restructuredtext", + ".rest": "restructuredtext", + ".rest.txt": "restructuredtext", + ".rst.txt": "restructuredtext", + ".wisp": "wisp", + ".prg": "xbase", + ".prw": "xbase", +} + + +LANGUAGE_TAG = { + "c": "// the code file is written by C", + "c++": "// the code file is written by C++", + "cpp": "// the code file is written by C++", + "c#": "// the code file is written by C#", + "csharp": "// the code file is written by C#", + "c-sharp": "// the code file is written by C#", + "css": "/* the code file is written by CSS */", + "cuda": "// the code file is written by Cuda", + "dart": "// the code file is written by Dart", + "lua": "// the code file is written by Lua", + "objectivec": "// the code file is written by Objective-C", + "objective-c": "// the code file is written by Objective-C", + "objective-c++": "// the code file is written by Objective-C++", + "python": "# the code file is written by Python", + "perl": "# the code file is written by Perl", + "prolog": "% the code file is written by Prolog", + "swift": "// the code file is written by swift", + "lisp": "; the code file is written by Lisp", + "java": "// the code file is written by Java", + "scala": "// the code file is written by Scala", + "tex": "% the code file is written by TeX", + "vue": "", + "markdown": "", + "html": "", + "php": "// the code file is written by PHP", + "js": "// the code file is written by JavaScript", + "javascript": "// the code file is written by JavaScript", + "typescript": "// the code file is written by TypeScript", + "go": "// the code file is written by Go", + "shell": "# the code file is written by Shell", + "rust": "// the code file is written by Rust", + "sql": "-- the code file is written by SQL", + "kotlin": "// the code file is written by Kotlin", + "vb": "' the code file is written by Visual Basic", + "ruby": "# the code file is written by Ruby", + "pascal": "// the code file is written by Pascal", + "r": "# the code file is written by R", + "fortran": "!the code file is written by Fortran", + "lean": "-- the code file is written by Lean", + "matlab": "% the code file is written by Matlab", + "delphi": "{the code file is written by Delphi}", + "scheme": "; the code file is written by Scheme", + "basic": "' the code file is written by Basic", + "assembly": "; the code file is written by Assembly", + "groovy": "// the code file is written by Groovy", + "abap": "* the code file is written by Abap", + "gdscript": "# the code file is written by GDScript", + "haskell": "-- the code file is written by Haskell", + "julia": "# the code file is written by Julia", + "elixir": "# the code file is written by Elixir", + "excel": "' the code file is written by Excel", + "clojure": "; the code file is written by Clojure", + "actionscript": "// the code file is written by ActionScript", + "solidity": "// the code file is written by Solidity", + "powershell": "# the code file is written by PowerShell", + "erlang": "% the code file is written by Erlang", + "cobol": "// the code file is written by Cobol", + "alloy": "/* the code 
file is written by Alloy */", + "awk": "// the code file is written by AWK", + "thrift": "/* the code file is written by Thrift */", + "sparql": "# the code file is written by SPARQL", + "augeas": "// the code file is written by Augeas", + "cmake": "# the code file is written by CMake", + "f-sharp": "// the code file is written by F#", + "stan": "// the code file is written by Stan", + "isabelle": "(*the code file is written by Isabelle*)", + "dockerfile": "# the code file is written by Dockerfile", + "rmarkdown": "# the code file is written by RMarkdown", + "literate-agda": "-- the code file is written by Literate Agda", + "tcl": "// the code file is written by Augeas", + "glsl": "// the code file is written by GLSL", + "antlr": "// the code file is written by ANTLR", + "verilog": "// the code file is written by Verilog", + "racket": "; the code file is written by Racket", + "standard-ml": "(*the code file is written byStandard ML*)", + "elm": "-- the code file is written by Elm", + "yaml": "# the code file is written by YAML", + "smalltalk": "'' the code file is written by Smalltalk", + "ocaml": "(*the code file is written by OCaml*)", + "idris": "-- the code file is written by Idris", + "visual-basic": "' the code file is written by Visual Basic", + "protocol-buffer": "// the code file is written by Protocol Buffer", + "bluespec": "// the code file is written by Bluespec", + "applescript": "-- the code file is written by AppleScript", + "makefile": "# the code file is written by Makefile", + "tcsh": "# the code file is written by TCSH", + "maple": "# the code file is written by Maple", + "systemverilog": "// the code file is written by SystemVerilog", + "literate-coffeescript": "# the code file is written by Literate CoffeeScript", + "vhdl": "-- the code file is written by VHDL", + "restructuredtext": ".. 
the code file is written by reStructuredText", + "sas": "* the code file is written by SAS", + "literate-haskell": "> the code file is written by Literate Haskell", + "java-server-pages": "// the code file is written by Java Server Pages", + "coffeescript": "# the code file is written by CoffeeScript", + "emacs-lisp": "; the code file is written by Emacs Lisp", + "mathematica": "// the code file is written by Mathematica", + "xslt": "", + "zig": "// the code file is written by Zig", + "common-lisp": "; the code file is written by Common Lisp", + "stata": "* the code file is written by Stata", + "agda": "-- the code file is written by Agda", + "ada": "-- the code file is written by Ada", + "jsx": "// the code file is written by JSX", + "tsx": "// the code file is written by TypeScript with JSX", +} + + +def get_file_separator( + repo_level_spec: RepoLevelCodeCompletionSpecV1, filename: str +) -> str: + if repo_level_spec.file_separator == "": + lang = EXT2LANG.get(os.path.splitext(filename)[-1].lower(), None) + sep = "// " + if lang is None: + logger.warning( + f"Unsupported file extension for {filename}, use '//' as file separator as fallback" + ) + else: + sep = LANGUAGE_WRAPPER[lang] + + return sep.replace("", filename) + else: + return repo_level_spec.file_separator + filename diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 7dcad72f3a..f30759f758 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -2266,6 +2266,72 @@ ] } }, + { + "version": 1, + "context_length": 65536, + "model_name": "codeqwen1.5", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "code" + ], + "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": 7, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "Qwen/CodeQwen1.5-7B" + }, + { + "model_format": "awq", + "model_size_in_billions": 7, + "quantizations": [ + "Int4" + ], + "model_id": "Qwen/CodeQwen1.5-7B-AWQ" + } + ], + "prompt_style": { + "style_name": "QWEN", + "system_prompt": "You are a helpful assistant.", + "roles": [ + "user", + "assistant" + ], + "intra_message_sep": "\n", + "stop_token_ids": [ + 2, + 3, + 4 + ], + "stop": [ + "<|endoftext|>", + "<|im_start|>", + "<|im_end|>" + ] + }, + "code_prompt_style": { + "style_name": "CODEQWEN", + "fim_spec": { + "style": "PSM", + "prefix": "", + "middle": "", + "suffix": "" + }, + "repo_level_spec": { + "repo_name": "", + "file_type": "filepath", + "file_separator": "" + } + } + }, { "version": 1, "context_length": 8192, @@ -4571,7 +4637,7 @@ }, "repo_level_spec": { "file_type": "filename", - "file_separator": "#" + "file_separator": "" } } }, diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index 4c16ef914e..ca507d9fcb 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -2365,6 +2365,64 @@ ] } }, + { + "version": 1, + "context_length": 65536, + "model_name": "codeqwen1.5", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "code" + ], + "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. 
It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": 7, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "qwen/CodeQwen1.5-7B" + } + ], + "prompt_style": { + "style_name": "QWEN", + "system_prompt": "You are a helpful assistant.", + "roles": [ + "user", + "assistant" + ], + "intra_message_sep": "\n", + "stop_token_ids": [ + 2, + 3, + 4 + ], + "stop": [ + "<|endoftext|>", + "<|im_start|>", + "<|im_end|>" + ] + }, + "code_prompt_style": { + "style_name": "CODEQWEN", + "fim_spec": { + "style": "PSM", + "prefix": "", + "middle": "", + "suffix": "" + }, + "repo_level_spec": { + "repo_name": "", + "file_type": "filepath", + "file_separator": "" + } + } + }, { "version": 1, "context_length": 4096, @@ -2580,7 +2638,7 @@ }, "repo_level_spec": { "file_type": "filename", - "file_separator": "#" + "file_separator": "" } } }, diff --git a/xinference/model/llm/pytorch/core.py b/xinference/model/llm/pytorch/core.py index 2032c6b7b5..73815e43a6 100644 --- a/xinference/model/llm/pytorch/core.py +++ b/xinference/model/llm/pytorch/core.py @@ -552,6 +552,7 @@ def code_generate( self, mode: CodeGenerateMode, prompt: str, + file_path: Optional[str], suffix: Optional[str], repo_name: Optional[str], files: Optional[Mapping[str, str]], @@ -560,6 +561,7 @@ def code_generate( code_prompt = self.get_code_prompt( mode, prompt, + file_path, suffix, repo_name, files, diff --git a/xinference/model/llm/tests/test_utils.py b/xinference/model/llm/tests/test_utils.py index cc912a67fa..2058fea165 100644 --- a/xinference/model/llm/tests/test_utils.py +++ b/xinference/model/llm/tests/test_utils.py @@ -507,12 +507,14 @@ def test_code_prompt_style_starcoder(): suffix = "\n print('Hello world!')" expected = "def print_hello_world():\n \n print('Hello world!')" assert expected == CodeModelMixin._get_code_prompt( - "infill", prompt, code_prompt_style, suffix + "infill", prompt, code_prompt_style, None, suffix ) suffix = None with pytest.raises(ValueError) as exc_info: - CodeModelMixin._get_code_prompt("infill", prompt, code_prompt_style, suffix) + CodeModelMixin._get_code_prompt( + "infill", prompt, code_prompt_style, None, suffix + ) assert exc_info.value == ValueError("suffix is required in infill mode") with pytest.raises(ValueError) as exc_info: @@ -531,7 +533,9 @@ def test_code_prompt_style_deepseek_coder(): middle="<|fim▁hole|>", suffix="<|fim▁end|>", ), - repo_level_spec=RepoLevelSpecV1(file_type="filename", file_separator="#"), + repo_level_spec=RepoLevelSpecV1( + file_type="filename", file_separator="" + ), ) prompt = "#write a quick sort algorithm" @@ -569,7 +573,7 @@ def test_code_prompt_style_deepseek_coder(): return quick_sort(left) + [pivot] + quick_sort(right)<|fim▁end|>""" assert expected == CodeModelMixin._get_code_prompt( - "infill", prompt, code_prompt_style, suffix + "infill", prompt, code_prompt_style, None, suffix ) files = { @@ -638,17 +642,225 @@ def predict(self, X_test): outputs = self(X_test) _, predicted = outputs.max(1) return predicted.numpy()""", - "main.py": """from utils import load_data, evaluate_predictions + } + + prompt = """from utils import load_data, evaluate_predictions +from model import IrisClassifier as Classifier + +def main(): + # Model training and evaluation +""" + file_path = "/home/test/works/proj01/main.py" + + expected = """# utils.py +import torch +from sklearn import datasets +from sklearn.model_selection import train_test_split 
+from sklearn.preprocessing import StandardScaler +from sklearn.metrics import accuracy_score + +def load_data(): + iris = datasets.load_iris() + X = iris.data + y = iris.target + + # Standardize the data + scaler = StandardScaler() + X = scaler.fit_transform(X) + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + + # Convert numpy data to PyTorch tensors + X_train = torch.tensor(X_train, dtype=torch.float32) + X_test = torch.tensor(X_test, dtype=torch.float32) + y_train = torch.tensor(y_train, dtype=torch.int64) + y_test = torch.tensor(y_test, dtype=torch.int64) + + return X_train, X_test, y_train, y_test + +def evaluate_predictions(y_test, y_pred): + return accuracy_score(y_test, y_pred) +# model.py +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +class IrisClassifier(nn.Module): + def __init__(self): + super(IrisClassifier, self).__init__() + self.fc = nn.Sequential( + nn.Linear(4, 16), + nn.ReLU(), + nn.Linear(16, 3) + ) + + def forward(self, x): + return self.fc(x) + + def train_model(self, X_train, y_train, epochs, lr, batch_size): + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(self.parameters(), lr=lr) + + # Create DataLoader for batches + dataset = TensorDataset(X_train, y_train) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + for epoch in range(epochs): + for batch_X, batch_y in dataloader: + optimizer.zero_grad() + outputs = self(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + def predict(self, X_test): + with torch.no_grad(): + outputs = self(X_test) + _, predicted = outputs.max(1) + return predicted.numpy() +# main.py +from utils import load_data, evaluate_predictions from model import IrisClassifier as Classifier def main(): # Model training and evaluation -""", +""" + assert expected == CodeModelMixin._get_code_prompt( + "completion", prompt, code_prompt_style, file_path, None, None, files + ) + + +def test_code_prompt_style_codeqwen(): + code_prompt_style = CodePromptStyleV1( + style_name="CODEQWEN", + fim_spec=FIMSpecV1( + style="PSM", + prefix="", + middle="", + suffix="", + ), + repo_level_spec=RepoLevelSpecV1( + repo_name="", file_type="filepath", file_separator="" + ), + ) + + prompt = "#write a quick sort algorithm" + expected = prompt + + assert expected == CodeModelMixin._get_code_prompt( + "completion", prompt, code_prompt_style + ) + + prompt = """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +""" + suffix = """ + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""" + + expected = """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] + + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""" + + assert expected == CodeModelMixin._get_code_prompt( + "infill", prompt, code_prompt_style, None, suffix + ) + + files = { + "/home/test/works/proj01/utils.py": """import torch +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import accuracy_score + +def load_data(): + iris = datasets.load_iris() + X = iris.data + y = iris.target + + # Standardize the data + scaler = StandardScaler() + X = scaler.fit_transform(X) + + X_train, 
X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + + # Convert numpy data to PyTorch tensors + X_train = torch.tensor(X_train, dtype=torch.float32) + X_test = torch.tensor(X_test, dtype=torch.float32) + y_train = torch.tensor(y_train, dtype=torch.int64) + y_test = torch.tensor(y_test, dtype=torch.int64) + + return X_train, X_test, y_train, y_test + +def evaluate_predictions(y_test, y_pred): + return accuracy_score(y_test, y_pred)""", + "/home/test/works/proj01/model.py": """import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +class IrisClassifier(nn.Module): + def __init__(self): + super(IrisClassifier, self).__init__() + self.fc = nn.Sequential( + nn.Linear(4, 16), + nn.ReLU(), + nn.Linear(16, 3) + ) + + def forward(self, x): + return self.fc(x) + + def train_model(self, X_train, y_train, epochs, lr, batch_size): + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(self.parameters(), lr=lr) + + # Create DataLoader for batches + dataset = TensorDataset(X_train, y_train) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + for epoch in range(epochs): + for batch_X, batch_y in dataloader: + optimizer.zero_grad() + outputs = self(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + def predict(self, X_test): + with torch.no_grad(): + outputs = self(X_test) + _, predicted = outputs.max(1) + return predicted.numpy()""", } - prompt = "" + prompt = """from utils import load_data, evaluate_predictions +from model import IrisClassifier as Classifier + +def main(): + # Model training and evaluation +""" + file_path = "/home/test/works/proj01/main.py" - expected = """#utils.py + expected = """project01 +/home/test/works/proj01/utils.py import torch from sklearn import datasets from sklearn.model_selection import train_test_split @@ -676,7 +888,7 @@ def load_data(): def evaluate_predictions(y_test, y_pred): return accuracy_score(y_test, y_pred) -#model.py +/home/test/works/proj01/model.py import torch import torch.nn as nn import torch.optim as optim @@ -715,7 +927,7 @@ def predict(self, X_test): outputs = self(X_test) _, predicted = outputs.max(1) return predicted.numpy() -#main.py +/home/test/works/proj01/main.py from utils import load_data, evaluate_predictions from model import IrisClassifier as Classifier @@ -723,7 +935,7 @@ def main(): # Model training and evaluation """ assert expected == CodeModelMixin._get_code_prompt( - "completion", prompt, code_prompt_style, None, None, files + "completion", prompt, code_prompt_style, file_path, None, "project01", files ) @@ -734,7 +946,9 @@ def test_code_prompt_style_without_fim(): prompt = "def print_hello_world():\n " suffix = "\n print('Hello world!')" with pytest.raises(ValueError) as exc_info: - CodeModelMixin._get_code_prompt("infill", prompt, code_prompt_style, suffix) + CodeModelMixin._get_code_prompt( + "infill", prompt, code_prompt_style, None, suffix + ) assert exc_info.value == ValueError( "This model is not support infill mode generate" ) diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 317ff502be..58e82e669a 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -29,6 +29,7 @@ CompletionChunk, ) from .core import LLM +from .lang_utils import get_file_separator from .llm_family import ( CodePromptStyleV1, GgmlLLMSpecV1, @@ -722,6 +723,7 @@ def get_code_prompt( self, mode: CodeGenerateMode, prompt: str, + file_path: 
Optional[str] = None, suffix: Optional[str] = None, repo_name: Optional[str] = None, files: Optional[Mapping[str, str]] = None, @@ -729,7 +731,7 @@ def get_code_prompt( code_prompt_style = cast(LLM, self).model_family.code_prompt_style return { "prompt": CodeModelMixin._get_code_prompt( - mode, prompt, code_prompt_style, suffix, repo_name, files + mode, prompt, code_prompt_style, file_path, suffix, repo_name, files ) } @@ -738,6 +740,7 @@ def _get_code_prompt( mode: CodeGenerateMode, prompt: str, code_prompt_style: Optional["CodePromptStyleV1"], + file_path: Optional[str] = None, suffix: Optional[str] = None, repo_name: Optional[str] = None, files: Optional[Mapping[str, str]] = None, @@ -753,10 +756,25 @@ def _get_code_prompt( "Suffix is only required on generate type is infill, ignored" ) + spec = code_prompt_style.repo_level_spec + if files is None or len(files) == 0: - return prompt + if file_path is None or len(file_path.strip()) == 0: + return prompt + else: + if spec is None: + logger.warning( + "repository level file separator not defined, but file_path provided, ignored" + ) + return prompt + else: + repo_file = ( + file_path + if spec.file_type == "filepath" + else CodeModelMixin._path_to_name(file_path) + ) + return get_file_separator(spec, repo_file) - spec = code_prompt_style.repo_level_spec if spec is None: logger.warning( "The model does not support repository level code completion, 'repo_name' and 'files' are ignored" @@ -782,10 +800,17 @@ def _get_code_prompt( if spec.file_type == "filepath" else CodeModelMixin._path_to_name(filepath) ) - chunks.append(f"{spec.file_separator}{repo_file}\n{content}") + chunks.append(get_file_separator(spec, repo_file)) + chunks.append(content) - if len(prompt.strip()) > 0: - chunks.append(prompt) + if file_path is not None and len(file_path.strip()) > 0: + repo_file = ( + file_path + if spec.file_type == "filepath" + else CodeModelMixin._path_to_name(file_path) + ) + chunks.append(get_file_separator(spec, repo_file)) + chunks.append(prompt) return "\n".join(chunks) @@ -809,7 +834,7 @@ def _get_code_prompt( else: raise ValueError( - f"Unsupported generate mode: {mode}, only 'PSM' and 'PMS' are supported now" + f"Unsupported generate mode: {mode}, only 'completion' and 'infill' are supported now" ) @staticmethod diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 289e71dd4c..291b6d6dff 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -122,6 +122,7 @@ class VLLMGenerateConfig(TypedDict, total=False): ] VLLM_SUPPORTED_CODE_MODELS = [ "deepseek-coder-base", + "codeqwen1.5", ] if VLLM_INSTALLED and vllm.__version__ >= "0.3.0": VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat") @@ -551,6 +552,7 @@ async def async_code_generate( self, mode: CodeGenerateMode, prompt: str, + file_path: Optional[str], suffix: Optional[str], repo_name: Optional[str], files: Optional[Mapping[str, str]], @@ -559,6 +561,7 @@ async def async_code_generate( code_prompt = self.get_code_prompt( mode, prompt, + file_path, suffix, repo_name, files, diff --git a/xinference/types.py b/xinference/types.py index 0e6b745d93..0ff177a248 100644 --- a/xinference/types.py +++ b/xinference/types.py @@ -513,6 +513,7 @@ def from_dict(cls, data: Dict): class CreateCodeCompletion(CreateCompletion): + file_path: Optional[str] = none_field mode: CodeGenerateMode = "completion" repo_name: Optional[str] = none_field files: Optional[Mapping[str, str]] = none_field From 514e862b0323576988e46a64e2a0d40fc7fa3b8a Mon Sep 17 
00:00:00 2001 From: Shi Hui Date: Tue, 14 May 2024 18:00:47 +0800 Subject: [PATCH 18/37] ignore codespell for lang_utils.py, since many of its file-extension entries look like misspellings. --- setup.cfg | 2 +- xinference/model/llm/lang_utils.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.cfg b/setup.cfg index fa05d9a0ba..a63e415a13 100644 --- a/setup.cfg +++ b/setup.cfg @@ -248,7 +248,7 @@ exclude = [codespell] ignore-words-list = hist,rcall,fpr,ser,nd,inout,ot,Ba,ba,asend,hart,coo,splitted,datas,fro -skip = .idea,.git,./build,./docs/build,node_modules,static,generated,*.po,*.ts,*.json,*.c,*.cpp,*.cfg,thirdparty +skip = .idea,.git,./build,./docs/build,node_modules,static,generated,*.po,*.ts,*.json,*.c,*.cpp,*.cfg,thirdparty,xinference/model/llm/lang_utils.py [isort] profile = black diff --git a/xinference/model/llm/lang_utils.py b/xinference/model/llm/lang_utils.py index 60bbbb3ea7..2007601bd9 100644 --- a/xinference/model/llm/lang_utils.py +++ b/xinference/model/llm/lang_utils.py @@ -153,6 +153,7 @@ ".ada": "ada", ".ads": "ada", ".agda": "agda", + ".als": "alloy", ".apacheconf": "apacheconf", ".vhost": "apacheconf", ".applescript": "applescript", @@ -710,6 +711,7 @@ ".rpy": "ren'py", ".rsh": "renderscript", ".robot": "robotframework", + ".rg": "rouge", ".rb": "ruby", ".builder": "ruby", ".gemspec": "ruby", @@ -796,6 +798,7 @@ ".ado": "stata", ".doh": "stata", ".ihlp": "stata", + ".mata": "stata", ".matah": "stata", ".sthlp": "stata", ".styl": "stylus", @@ -986,6 +989,7 @@ ".desktop.in": "desktop", ".ec": "ec", ".eh": "ec", + ".edn": "edn", ".fish": "fish", ".mu": "mupad", ".nc": "nesc", From 2a091aa9c098e5b01ae886068a42612342f764ec Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Tue, 14 May 2024 18:06:36 +0800 Subject: [PATCH 19/37] format UI code with Prettier. 
--- .../web/ui/src/scenes/running_models/index.js | 47 ++++++++++--------- 1 file changed, 25 insertions(+), 22 deletions(-) diff --git a/xinference/web/ui/src/scenes/running_models/index.js b/xinference/web/ui/src/scenes/running_models/index.js index 4fe5272e4b..7da17fbcf0 100644 --- a/xinference/web/ui/src/scenes/running_models/index.js +++ b/xinference/web/ui/src/scenes/running_models/index.js @@ -14,7 +14,7 @@ import Title from '../../components/Title' const RunningModels = () => { const [tabValue, setTabValue] = React.useState( - sessionStorage.getItem('runningModelType'), + sessionStorage.getItem('runningModelType') ) const [llmData, setLlmData] = useState([]) const [embeddingModelData, setEmbeddingModelData] = useState([]) @@ -44,7 +44,7 @@ const RunningModels = () => { setErrorMsg( `Login failed: ${response.status} - ${ errorData.detail || 'Unknown error' - }`, + }` ) }) } else { @@ -62,8 +62,10 @@ const RunningModels = () => { } if (newValue.model_type === 'LLM') { if (model.model_name in code_prompts) { - newValue['infill_supported'] = 'fim_spec' in code_prompts[model.model_name] - newValue['repo_level_supported'] = 'repo_level_spec' in code_prompts[model.model_name] + newValue['infill_supported'] = + 'fim_spec' in code_prompts[model.model_name] + newValue['repo_level_supported'] = + 'repo_level_spec' in code_prompts[model.model_name] } newLlmData.push(newValue) } else if (newValue.model_type === 'embedding') { @@ -118,21 +120,22 @@ const RunningModels = () => { fetcher(`${endPoint}/v1/models/code_prompts`, { method: 'GET', - }).then((response) => { - if (!response.ok) { - response.json().then((errorData) => { - setErrorMsg( - `Login failed: ${response.status} - ${ - errorData.detail || 'Unknown error' - }`, - ) - }) - } else { - response.json().then((code_prompts) => { - get_models(code_prompts) - }) - } }) + .then((response) => { + if (!response.ok) { + response.json().then((errorData) => { + setErrorMsg( + `Login failed: ${response.status} - ${ + errorData.detail || 'Unknown error' + }` + ) + }) + } else { + response.json().then((code_prompts) => { + get_models(code_prompts) + }) + } + }) .catch((error) => { console.error('Error:', error) setIsUpdatingModel(false) @@ -253,7 +256,7 @@ const RunningModels = () => { }) .then((response) => response.json()) .then(() => - window.open(openUrl, '_blank', 'noopener noreferrer'), + window.open(openUrl, '_blank', 'noopener noreferrer') ) .finally(() => setIsCallingApi(false)) } else if (response.ok) { @@ -264,7 +267,7 @@ const RunningModels = () => { } else { // Other HTTP errors console.error( - `Unexpected response status: ${response.status}`, + `Unexpected response status: ${response.status}` ) setIsCallingApi(false) } @@ -532,7 +535,7 @@ const RunningModels = () => { }) .then((response) => response.json()) .then(() => - window.open(openUrl, '_blank', 'noopener noreferrer'), + window.open(openUrl, '_blank', 'noopener noreferrer') ) .finally(() => setIsCallingApi(false)) } else if (response.ok) { @@ -543,7 +546,7 @@ const RunningModels = () => { } else { // Other HTTP errors console.error( - `Unexpected response status: ${response.status}`, + `Unexpected response status: ${response.status}` ) setIsCallingApi(false) } From c0925e5202ede6c53d5a9cd05ceeaffd67087149 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Tue, 14 May 2024 23:12:13 +0800 Subject: [PATCH 20/37] fix the get_code_prompt missing parameter. 
--- xinference/api/restful_api.py | 1 + xinference/core/tests/test_restful_api.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index ac64bd3eba..35af760237 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -1537,6 +1537,7 @@ async def get_code_prompt(self, request: Request) -> Response: code_prompt = await model.get_code_prompt( body.mode, body.prompt, + body.file_path, body.suffix, body.repo_name, body.files, diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py index 56ce5d7984..9f7dadc17e 100644 --- a/xinference/core/tests/test_restful_api.py +++ b/xinference/core/tests/test_restful_api.py @@ -1271,7 +1271,7 @@ def test_cluster_info(setup): assert result[1]["gpu_vram_total"] == 0 -def test_restfule_api_for_code_prompt(setup): +def test_restful_api_for_code_prompt(setup): model_name = "deepseek-coder-base" endpoint, _ = setup From f7d02fbd3cf443cf0f1d8cdb9241cc56b163dcd0 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Mon, 20 May 2024 11:39:10 +0800 Subject: [PATCH 21/37] adapt to langchain 0.2.x, which has breaking changes langchain-community need to be installed separately. --- setup.cfg | 1 + 1 file changed, 1 insertion(+) diff --git a/setup.cfg b/setup.cfg index fa05d9a0ba..437b9e59b4 100644 --- a/setup.cfg +++ b/setup.cfg @@ -77,6 +77,7 @@ dev = openai>1 opencv-contrib-python langchain + langchain-community orjson sphinx-tabs sphinx-design From 12dc09a2fbe47b3ebe9e86a5873bdc0d23ccb6b3 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Fri, 24 May 2024 10:18:32 +0800 Subject: [PATCH 22/37] add base suffix for codeqwen1.5 to diff with the official generate model --- xinference/model/llm/llm_family.json | 2 +- xinference/model/llm/llm_family_modelscope.json | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index dd42bfdd54..6f74e84a79 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -2294,7 +2294,7 @@ { "version": 1, "context_length": 65536, - "model_name": "codeqwen1.5", + "model_name": "codeqwen1.5-base", "model_lang": [ "en", "zh" diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index ac768d4002..5ffdfd0f56 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -2530,7 +2530,7 @@ { "version": 1, "context_length": 65536, - "model_name": "codeqwen1.5", + "model_name": "codeqwen1.5-base", "model_lang": [ "en", "zh" From 4dbd3bbcb68e874e66e95cafbaa14fd58cf4e8e4 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Fri, 24 May 2024 10:20:26 +0800 Subject: [PATCH 23/37] add vllm support for codeqwen1.5-base --- xinference/model/llm/vllm/core.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 60f9b86325..ed9b97735e 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -129,7 +129,7 @@ class VLLMGenerateConfig(TypedDict, total=False): ] VLLM_SUPPORTED_CODE_MODELS = [ "deepseek-coder-base", - "codeqwen1.5", + "codeqwen1.5-base", ] if VLLM_INSTALLED and vllm.__version__ >= "0.3.0": VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat") From 02f7536d924acb4e724601ed975b84dd0d60d86d Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Tue, 28 May 2024 14:04:12 +0800 Subject: 
[PATCH 24/37] merge llm_family definition. --- xinference/model/llm/llm_family.json | 213 ++---------------- .../model/llm/llm_family_modelscope.json | 159 ++++--------- xinference/model/llm/vllm/core.py | 4 +- 3 files changed, 62 insertions(+), 314 deletions(-) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 6f74e84a79..9fe0a11e22 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -2198,31 +2198,6 @@ ] } }, - { - "version": 1, - "context_length": 65536, - "model_name": "codeqwen1.5", - "model_lang": [ - "en", - "zh" - ], - "model_ability": [ - "generate" - ], - "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.", - "model_specs": [ - { - "model_format": "pytorch", - "model_size_in_billions": 7, - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - "model_id": "Qwen/CodeQwen1.5-7B" - } - ] - }, { "version": 1, "context_length": 65536, @@ -2294,7 +2269,7 @@ { "version": 1, "context_length": 65536, - "model_name": "codeqwen1.5-base", + "model_name": "codeqwen1.5", "model_lang": [ "en", "zh" @@ -2323,25 +2298,6 @@ "model_id": "Qwen/CodeQwen1.5-7B-AWQ" } ], - "prompt_style": { - "style_name": "QWEN", - "system_prompt": "You are a helpful assistant.", - "roles": [ - "user", - "assistant" - ], - "intra_message_sep": "\n", - "stop_token_ids": [ - 2, - 3, - 4 - ], - "stop": [ - "<|endoftext|>", - "<|im_start|>", - "<|im_end|>" - ] - }, "code_prompt_style": { "style_name": "CODEQWEN", "fim_spec": { @@ -4612,7 +4568,8 @@ "zh" ], "model_ability": [ - "generate" + "generate", + "code" ], "model_description": "Deepseek Coder is composed of a series of code language models, each trained from scratch on 2T tokens, with a composition of 87% code and 13% natural language in both English and Chinese. 
", "model_specs": [ @@ -4793,7 +4750,20 @@ "model_id": "TheBloke/deepseek-coder-33B-base-AWQ", "model_revision": "c7edb2d5868d61a5dcf2591933a8992c8cbe3ef4" } - ] + ], + "code_prompt_style": { + "style_name": "DEEPSEEK_CODER", + "fim_spec": { + "style": "PMS", + "prefix": "<|fim▁begin|>", + "middle": "<|fim▁hole|>", + "suffix": "<|fim▁end|>" + }, + "repo_level_spec": { + "file_type": "filename", + "file_separator": "" + } + } }, { "version": 1, @@ -4999,155 +4969,6 @@ ] } }, - { - "version": 1, - "context_length": 16384, - "model_name": "deepseek-coder-base", - "model_lang": [ - "en", - "zh" - ], - "model_ability": [ - "generate", - "code" - ], - "model_description": "deepseek-coder-base is pre-trained on project-level code corpus by employing a window size of 16K and a extra fill-in-the-blank task, to support project-level code completion and infilling.", - "model_specs": [ - { - "model_format": "pytorch", - "model_size_in_billions": "1_3", - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - "model_id": "deepseek-ai/deepseek-coder-1.3b-base", - "model_revision": "c919139c3a9b4070729c8b2cca4847ab29ca8d94" - }, - { - "model_format": "pytorch", - "model_size_in_billions": "6_7", - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - "model_id": "deepseek-ai/deepseek-coder-6.7b-base", - "model_revision": "ce2207a8bfef3ee92bd7dd4cc31c52cfa0046912" - }, - { - "model_format": "pytorch", - "model_size_in_billions": 33, - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - "model_id": "deepseek-ai/deepseek-coder-33b-base", - "model_revision": "45c85cadf3720ef3e85a492e24fd4b8c5d21d8ac" - }, - { - "model_format": "ggufv2", - "model_size_in_billions": "1_3", - "quantizations": [ - "Q2_K", - "Q3_K_L", - "Q3_K_M", - "Q3_K_S", - "Q4_0", - "Q4_K_M", - "Q4_K_S", - "Q5_0", - "Q5_K_M", - "Q5_K_S", - "Q6_K", - "Q8_0" - ], - "model_id": "TheBloke/deepseek-coder-1.3b-base-GGUF", - "model_file_name_template": "deepseek-coder-1.3b-base.{quantization}.gguf" - }, - { - "model_format": "ggufv2", - "model_size_in_billions": "6_7", - "quantizations": [ - "Q2_K", - "Q3_K_L", - "Q3_K_M", - "Q3_K_S", - "Q4_0", - "Q4_K_M", - "Q4_K_S", - "Q5_0", - "Q5_K_M", - "Q5_K_S", - "Q6_K", - "Q8_0" - ], - "model_id": "TheBloke/deepseek-coder-6.7B-base-GGUF", - "model_file_name_template": "deepseek-coder-6.7b-base.{quantization}.gguf" - }, - { - "model_format": "ggufv2", - "model_size_in_billions": 33, - "quantizations": [ - "Q2_K", - "Q3_K_L", - "Q3_K_M", - "Q3_K_S", - "Q4_0", - "Q4_K_M", - "Q4_K_S", - "Q5_0", - "Q5_K_M", - "Q5_K_S", - "Q6_K", - "Q8_0" - ], - "model_id": "TheBloke/deepseek-coder-33B-base-GGUF", - "model_file_name_template": "deepseek-coder-33b-base.{quantization}.gguf" - }, - { - "model_format": "awq", - "model_size_in_billions": "1_3", - "quantizations": [ - "Int4" - ], - "model_id": "TheBloke/deepseek-coder-1.3b-base-AWQ", - "model_revision": "ffb66f1a2a194401b4f29025edcd261d7f0a08a7" - }, - { - "model_format": "awq", - "model_size_in_billions": "6_7", - "quantizations": [ - "Int4" - ], - "model_id": "TheBloke/deepseek-coder-6.7B-base-AWQ", - "model_revision": "e3d4bdf39712665f5e9d5c05c9df6f20fe1e2d5a" - }, - { - "model_format": "awq", - "model_size_in_billions": "33", - "quantizations": [ - "Int4" - ], - "model_id": "TheBloke/deepseek-coder-33B-base-AWQ", - "model_revision": "c7edb2d5868d61a5dcf2591933a8992c8cbe3ef4" - } - ], - "code_prompt_style": { - "style_name": "DEEPSEEK_CODER", - "fim_spec": { - "style": "PMS", - "prefix": "<|fim▁begin|>", - "middle": "<|fim▁hole|>", - "suffix": "<|fim▁end|>" 
- }, - "repo_level_spec": { - "file_type": "filename", - "file_separator": "" - } - } - }, { "version": 1, "context_length": 4096, diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index 5ffdfd0f56..ad3f923052 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -2430,32 +2430,6 @@ ] } }, - { - "version": 1, - "context_length": 65536, - "model_name": "codeqwen1.5", - "model_lang": [ - "en", - "zh" - ], - "model_ability": [ - "generate" - ], - "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.", - "model_specs": [ - { - "model_format": "pytorch", - "model_size_in_billions": 7, - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - "model_id": "qwen/CodeQwen1.5-7B", - "model_hub": "modelscope" - } - ] - }, { "version": 1, "context_length": 65536, @@ -2530,12 +2504,13 @@ { "version": 1, "context_length": 65536, - "model_name": "codeqwen1.5-base", + "model_name": "codeqwen1.5", "model_lang": [ "en", "zh" ], "model_ability": [ + "generate", "code" ], "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.", @@ -2728,7 +2703,8 @@ "zh" ], "model_ability": [ - "generate" + "generate", + "code" ], "model_description": "Deepseek Coder is composed of a series of code language models, each trained from scratch on 2T tokens, with a composition of 87% code and 13% natural language in both English and Chinese.", "model_specs": [ @@ -2765,7 +2741,20 @@ "model_id": "deepseek-ai/deepseek-coder-33b-base", "model_hub": "modelscope" } - ] + ], + "code_prompt_style": { + "style_name": "DEEPSEEK_CODER", + "fim_spec": { + "style": "PMS", + "prefix": "<|fim▁begin|>", + "middle": "<|fim▁hole|>", + "suffix": "<|fim▁end|>" + }, + "repo_level_spec": { + "file_type": "filename", + "file_separator": "" + } + } }, { "version": 1, @@ -2827,68 +2816,6 @@ ] } }, - { - "version": 1, - "context_length": 16384, - "model_name": "deepseek-coder-base", - "model_lang": [ - "en", - "zh" - ], - "model_ability": [ - "generate", - "code" - ], - "model_description": "deepseek-coder-base is pre-trained on project-level code corpus by employing a window size of 16K and a extra fill-in-the-blank task, to support project-level code completion and infilling.", - "model_specs": [ - { - "model_format": "pytorch", - "model_size_in_billions": "1_3", - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - "model_id": "deepseek-ai/deepseek-coder-1.3b-base", - "model_hub": "modelscope" - }, - { - "model_format": "pytorch", - "model_size_in_billions": "6_7", - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - "model_id": "deepseek-ai/deepseek-coder-6.7b-base", - "model_hub": "modelscope" - }, - { - "model_format": "pytorch", - "model_size_in_billions": 33, - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - "model_id": "deepseek-ai/deepseek-coder-33b-base", - "model_hub": "modelscope" - } - ], - "code_prompt_style": { - "style_name": "DEEPSEEK_CODER", - "fim_spec": { - "style": "PMS", - "prefix": "<|fim▁begin|>", - "middle": "<|fim▁hole|>", - "suffix": "<|fim▁end|>" - }, - "repo_level_spec": { - "file_type": "filename", - "file_separator": "" - } - } - }, { "version": 1, "context_length": 4096, @@ -3224,44 +3151,44 @@ } }, { - "version":1, - "context_length":2048, 
- "model_name":"OmniLMM", - "model_lang":[ + "version": 1, + "context_length": 2048, + "model_name": "OmniLMM", + "model_lang": [ "en", "zh" ], - "model_ability":[ + "model_ability": [ "chat", "vision" ], - "model_description":"mniLMM is a family of open-source large multimodal models (LMMs) adept at vision & language modeling.", - "model_specs":[ + "model_description": "mniLMM is a family of open-source large multimodal models (LMMs) adept at vision & language modeling.", + "model_specs": [ { - "model_format":"pytorch", - "model_size_in_billions":3, - "quantizations":[ + "model_format": "pytorch", + "model_size_in_billions": 3, + "quantizations": [ "none" ], - "model_id":"OpenBMB/MiniCPM-V", - "model_hub":"modelscope", - "model_revision":"master" + "model_id": "OpenBMB/MiniCPM-V", + "model_hub": "modelscope", + "model_revision": "master" }, { - "model_format":"pytorch", - "model_size_in_billions":12, - "quantizations":[ + "model_format": "pytorch", + "model_size_in_billions": 12, + "quantizations": [ "none" ], - "model_id":"OpenBMB/OmniLMM-12B", - "model_hub":"modelscope", - "model_revision":"master" + "model_id": "OpenBMB/OmniLMM-12B", + "model_hub": "modelscope", + "model_revision": "master" } ], - "prompt_style":{ - "style_name":"OmniLMM", - "system_prompt":"The role of first msg should be user", - "roles":[ + "prompt_style": { + "style_name": "OmniLMM", + "system_prompt": "The role of first msg should be user", + "roles": [ "user", "assistant" ] @@ -3726,7 +3653,7 @@ ], "intra_message_sep": "\n", "inter_message_sep": "<|end|>\n", - "stop_token_ids":[ + "stop_token_ids": [ 32000, 32007 ], @@ -3770,7 +3697,7 @@ ], "intra_message_sep": "\n", "inter_message_sep": "<|end|>\n", - "stop_token_ids":[ + "stop_token_ids": [ 32000, 32007 ], diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index ed9b97735e..de9b55c32d 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -128,8 +128,8 @@ class VLLMGenerateConfig(TypedDict, total=False): "deepseek-coder-instruct", ] VLLM_SUPPORTED_CODE_MODELS = [ - "deepseek-coder-base", - "codeqwen1.5-base", + "deepseek-coder", + "codeqwen1.5", ] if VLLM_INSTALLED and vllm.__version__ >= "0.3.0": VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat") From ecbf89358e849144b9196824a968677bf7daa62b Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Tue, 28 May 2024 15:43:42 +0800 Subject: [PATCH 25/37] modified the model names to use latest model name in definition. 
--- xinference/core/tests/test_restful_api.py | 20 ++++++++++---------- 1 file changed, 10 insertions(+), 10 deletions(-) diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py index 9f7dadc17e..da8635a14c 100644 --- a/xinference/core/tests/test_restful_api.py +++ b/xinference/core/tests/test_restful_api.py @@ -1272,7 +1272,7 @@ def test_cluster_info(setup): def test_restful_api_for_code_prompt(setup): - model_name = "deepseek-coder-base" + model_name = "deepseek-coder" endpoint, _ = setup url = f"{endpoint}/v1/models" @@ -1284,7 +1284,7 @@ def test_restful_api_for_code_prompt(setup): # launch payload = { - "model_uid": "deepseek-coder-base", + "model_uid": "deepseek-coder", "model_name": model_name, "model_type": "LLM", "model_engine": "llama.cpp", @@ -1295,7 +1295,7 @@ def test_restful_api_for_code_prompt(setup): response = requests.post(url, json=payload) response_data = response.json() model_uid_res = response_data["model_uid"] - assert model_uid_res == "deepseek-coder-base" + assert model_uid_res == "deepseek-coder" response = requests.get(url) response_data = response.json() @@ -1316,7 +1316,7 @@ def test_restful_api_for_code_prompt(setup): # test multiple payload = { - "model": "deepseek-coder-base", + "model": "deepseek-coder", "mode": "infill", "prompt": """def quick_sort(arr): if len(arr) <= 1: @@ -1353,7 +1353,7 @@ def test_restful_api_for_code_prompt(setup): ) # delete model - url = f"{endpoint}/v1/models/deepseek-coder-base" + url = f"{endpoint}/v1/models/deepseek-coder" response = requests.delete(url) assert response.status_code == 200 @@ -1363,7 +1363,7 @@ def test_restful_api_for_code_prompt(setup): def test_restful_api_for_code_completions(setup): - model_name = "deepseek-coder-base" + model_name = "deepseek-coder" endpoint, _ = setup url = f"{endpoint}/v1/models" @@ -1375,7 +1375,7 @@ def test_restful_api_for_code_completions(setup): # launch payload = { - "model_uid": "deepseek-coder-base", + "model_uid": "deepseek-coder", "model_name": model_name, "model_type": "LLM", "model_engine": "llama.cpp", @@ -1386,7 +1386,7 @@ def test_restful_api_for_code_completions(setup): response = requests.post(url, json=payload) response_data = response.json() model_uid_res = response_data["model_uid"] - assert model_uid_res == "deepseek-coder-base" + assert model_uid_res == "deepseek-coder" response = requests.get(url) response_data = response.json() @@ -1408,7 +1408,7 @@ def test_restful_api_for_code_completions(setup): # test multiple payload = { - "model": "deepseek-coder-base", + "model": "deepseek-coder", "mode": "infill", "prompt": """def quick_sort(arr): if len(arr) <= 1: @@ -1433,7 +1433,7 @@ def test_restful_api_for_code_completions(setup): assert coding_res["choices"][0]["finish_reason"] == "stop" # delete model - url = f"{endpoint}/v1/models/deepseek-coder-base" + url = f"{endpoint}/v1/models/deepseek-coder" response = requests.delete(url) assert response.status_code == 200 From 71843c1a358e5ee479de4170bafd709cb6e74aef Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Tue, 28 May 2024 16:33:27 +0800 Subject: [PATCH 26/37] modified the model names to use latest model name in definition. 
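[Editor's note] The infill tests above send a prompt/suffix pair for deepseek-coder, whose merged code_prompt_style declares a PMS-ordered fim_spec with the <|fim▁begin|>/<|fim▁hole|>/<|fim▁end|> sentinels. A minimal sketch of how such a fill-in-the-middle prompt can be assembled from that spec; build_fim_prompt is an illustrative helper, not the project's own function.

```python
# Sentinels copied from the deepseek-coder fim_spec merged in patch 24.
FIM_SPEC = {
    "style": "PMS",
    "prefix": "<|fim▁begin|>",
    "middle": "<|fim▁hole|>",
    "suffix": "<|fim▁end|>",
}


def build_fim_prompt(code_before: str, code_after: str, spec: dict = FIM_SPEC) -> str:
    # PMS ordering: prefix sentinel, code before the hole, middle sentinel,
    # code after the hole, suffix sentinel.
    if spec["style"] == "PMS":
        return f"{spec['prefix']}{code_before}{spec['middle']}{code_after}{spec['suffix']}"
    # PSM ordering (used by some other code models): prefix, before, suffix, after, middle.
    return f"{spec['prefix']}{code_before}{spec['suffix']}{code_after}{spec['middle']}"


print(
    build_fim_prompt(
        "def quick_sort(arr):\n    if len(arr) <= 1:\n        return arr\n    ",
        "\n    return arr",
    )
)
```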
--- xinference/core/tests/test_restful_api.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py index da8635a14c..56a105d389 100644 --- a/xinference/core/tests/test_restful_api.py +++ b/xinference/core/tests/test_restful_api.py @@ -1304,7 +1304,7 @@ def test_restful_api_for_code_prompt(setup): # test embedding url = f"{endpoint}/v1/code/prompt" payload = { - "model": "deepseek-coder-base", + "model": "deepseek-coder", "prompt": "#write a quick sort algorithm", } response = requests.post(url, json=payload) @@ -1395,7 +1395,7 @@ def test_restful_api_for_code_completions(setup): # test embedding url = f"{endpoint}/v1/code/completions" payload = { - "model": "deepseek-coder-base", + "model": "deepseek-coder", "prompt": "#write a quick sort algorithm", "max_tokens": 4096, } From d8e025957b90eac48348fb072c72c62dbebdf7d9 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Thu, 13 Jun 2024 15:15:25 +0800 Subject: [PATCH 27/37] add model_hub for model definition in llm_family_modelscope.json. --- xinference/model/llm/llm_family_modelscope.json | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index 18455f366b..323100390b 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -2780,7 +2780,8 @@ "8-bit", "none" ], - "model_id": "qwen/CodeQwen1.5-7B" + "model_id": "qwen/CodeQwen1.5-7B", + "model_hub": "modelscope" } ], "prompt_style": { From 06ead5b911b3c56a5ae60358eabe5e7552af80e1 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Wed, 10 Jul 2024 17:13:15 +0800 Subject: [PATCH 28/37] add the missing import module --- xinference/model/llm/__init__.py | 1 + 1 file changed, 1 insertion(+) diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index 1bfb0df0b8..f2075993c7 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -50,6 +50,7 @@ MLXLLMSpecV1, PromptStyleV1, PytorchLLMSpecV1, + RepoLevelCodeCompletionSpecV1, get_cache_status, get_user_defined_llm_families, match_llm, From 1901603deeb70dfea6cca2f55cc2cb11a7550699 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Wed, 10 Jul 2024 17:17:35 +0800 Subject: [PATCH 29/37] format the frontend code. 
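[Editor's note] Beyond the completion and infill payloads exercised by the tests, the prompt handler also receives repository-level fields (repo_name, files) as seen in the get_code_prompt call earlier in the series. A hedged sketch of such a request against /v1/code/prompt; the payload shape, file contents, and endpoint address are assumptions for illustration, not a documented contract.

```python
import requests

ENDPOINT = "http://127.0.0.1:9997"  # hypothetical local Xinference deployment

# Assumed payload shape: "files" maps paths to file contents and is folded
# into the prompt server-side according to the model's repo_level_spec.
payload = {
    "model": "deepseek-coder",
    "mode": "completion",
    "prompt": "# implement a stable merge sort here\n",
    "files": {
        "src/sorting.py": "def quick_sort(arr):\n    ...\n",
        "src/__init__.py": "from .sorting import quick_sort\n",
    },
}

resp = requests.post(f"{ENDPOINT}/v1/code/prompt", json=payload)
resp.raise_for_status()
print(resp.json())
```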
--- xinference/web/ui/src/scenes/running_models/index.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/web/ui/src/scenes/running_models/index.js b/xinference/web/ui/src/scenes/running_models/index.js index 74ef74cc55..ab106faa85 100644 --- a/xinference/web/ui/src/scenes/running_models/index.js +++ b/xinference/web/ui/src/scenes/running_models/index.js @@ -91,7 +91,7 @@ const RunningModels = () => { console.error('Error:', error) setIsUpdatingModel(false) if (error.response.status !== 403 && error.response.status !== 401) { - setErrorMsg(error.message) + setErrorMsg(error.message) } }) } From b71e3b6b5cbd3bb6748781dba0397180ef3bbde8 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Wed, 10 Jul 2024 17:53:13 +0800 Subject: [PATCH 30/37] fix the wrong usage of fetchWrapper --- .../web/ui/src/scenes/running_models/index.js | 96 +++++++------------ 1 file changed, 36 insertions(+), 60 deletions(-) diff --git a/xinference/web/ui/src/scenes/running_models/index.js b/xinference/web/ui/src/scenes/running_models/index.js index ab106faa85..7cb8318d44 100644 --- a/xinference/web/ui/src/scenes/running_models/index.js +++ b/xinference/web/ui/src/scenes/running_models/index.js @@ -39,53 +39,41 @@ const RunningModels = () => { fetchWrapper .get('/v1/models') .then((response) => { - if (!response.ok) { - response.json().then((errorData) => { - setErrorMsg( - `Login failed: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - }) - } else { - response.json().then((response) => { - const newLlmData = [] - const newEmbeddingModelData = [] - const newImageModelData = [] - const newAudioModelData = [] - const newRerankModelData = [] - response.data.forEach((model) => { - let newValue = { - ...model, - id: model.id, - url: model.id, - } - if (newValue.model_type === 'LLM') { - if (model.model_name in code_prompts) { - newValue['infill_supported'] = - 'fim_spec' in code_prompts[model.model_name] - newValue['repo_level_supported'] = - 'repo_level_spec' in code_prompts[model.model_name] - } - newLlmData.push(newValue) - } else if (newValue.model_type === 'embedding') { - newEmbeddingModelData.push(newValue) - } else if (newValue.model_type === 'audio') { - newAudioModelData.push(newValue) - } else if (newValue.model_type === 'image') { - newImageModelData.push(newValue) - } else if (newValue.model_type === 'rerank') { - newRerankModelData.push(newValue) - } - }) - setLlmData(newLlmData) - setEmbeddingModelData(newEmbeddingModelData) - setAudioModelData(newAudioModelData) - setImageModelData(newImageModelData) - setRerankModelData(newRerankModelData) - setIsUpdatingModel(false) - }) - } + const newLlmData = [] + const newEmbeddingModelData = [] + const newImageModelData = [] + const newAudioModelData = [] + const newRerankModelData = [] + response.data.forEach((model) => { + let newValue = { + ...model, + id: model.id, + url: model.id, + } + if (newValue.model_type === 'LLM') { + if (model.model_name in code_prompts) { + newValue['infill_supported'] = + 'fim_spec' in code_prompts[model.model_name] + newValue['repo_level_supported'] = + 'repo_level_spec' in code_prompts[model.model_name] + } + newLlmData.push(newValue) + } else if (newValue.model_type === 'embedding') { + newEmbeddingModelData.push(newValue) + } else if (newValue.model_type === 'audio') { + newAudioModelData.push(newValue) + } else if (newValue.model_type === 'image') { + newImageModelData.push(newValue) + } else if (newValue.model_type === 'rerank') { + newRerankModelData.push(newValue) + } + }) + 
setLlmData(newLlmData) + setEmbeddingModelData(newEmbeddingModelData) + setAudioModelData(newAudioModelData) + setImageModelData(newImageModelData) + setRerankModelData(newRerankModelData) + setIsUpdatingModel(false) }) .catch((error) => { console.error('Error:', error) @@ -125,19 +113,7 @@ const RunningModels = () => { fetchWrapper .get('/v1/models/code_prompts') .then((response) => { - if (!response.ok) { - response.json().then((errorData) => { - setErrorMsg( - `Login failed: ${response.status} - ${ - errorData.detail || 'Unknown error' - }` - ) - }) - } else { - response.json().then((code_prompts) => { - get_models(code_prompts) - }) - } + get_models(response.data) }) .catch((error) => { console.error('Error:', error) From 3869b3064e305b38eb34079f974e99f1b102e094 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Wed, 10 Jul 2024 18:04:54 +0800 Subject: [PATCH 31/37] fix wrong code_prompts get logic --- xinference/web/ui/src/scenes/running_models/index.js | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/xinference/web/ui/src/scenes/running_models/index.js b/xinference/web/ui/src/scenes/running_models/index.js index 7cb8318d44..161431195b 100644 --- a/xinference/web/ui/src/scenes/running_models/index.js +++ b/xinference/web/ui/src/scenes/running_models/index.js @@ -112,8 +112,8 @@ const RunningModels = () => { fetchWrapper .get('/v1/models/code_prompts') - .then((response) => { - get_models(response.data) + .then((code_prompts) => { + get_models(code_prompts) }) .catch((error) => { console.error('Error:', error) From efa6c4721983d81ee1a1f29f14154840c29d23bb Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Thu, 18 Jul 2024 11:48:05 +0800 Subject: [PATCH 32/37] format the frontend code --- xinference/web/ui/src/scenes/running_models/index.js | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/web/ui/src/scenes/running_models/index.js b/xinference/web/ui/src/scenes/running_models/index.js index cb47cc8704..53fef83345 100644 --- a/xinference/web/ui/src/scenes/running_models/index.js +++ b/xinference/web/ui/src/scenes/running_models/index.js @@ -69,7 +69,7 @@ const RunningModels = () => { } else if (newValue.model_type === 'rerank') { newRerankModelData.push(newValue) } else if (newValue.model_type === 'flexible') { - newFlexibleModelData.push(newValue) + newFlexibleModelData.push(newValue) } }) setLlmData(newLlmData) From 7ae7c77eba3d281938ce1857023b7d21993a6bd6 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Thu, 18 Jul 2024 14:04:48 +0800 Subject: [PATCH 33/37] to use the right call wrapper method. 
--- xinference/api/restful_api.py | 2 +- xinference/core/model.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 7e30c6c263..b05db3e939 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -1821,7 +1821,7 @@ async def get_code_prompt(self, request: Request) -> Response: detail="mode must be one of 'completion' or 'infill'", ) - body = CreateCodeCompletion.parse_obj(json_data) + body = CreateCodeCompletion.model_validate_json(json_data) model_uid = body.model diff --git a/xinference/core/model.py b/xinference/core/model.py index 984001a28a..2f7dccbd86 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -590,7 +590,7 @@ async def code_generate( ) return response if hasattr(self._model, "async_code_generate"): - response = await self._call_wrapper( + response = await self._call_wrapper_json( self._model.async_code_generate, mode, prompt, @@ -638,7 +638,7 @@ async def get_code_prompt( from ..model.llm.utils import CodeModelMixin if isinstance(self._model, CodeModelMixin): - return await self._call_wrapper( + return await self._call_wrapper_json( self._model.get_code_prompt, mode, prompt, From 533fcfc0efbf6d13e82bf698fe8344796858bfad Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Thu, 18 Jul 2024 14:42:32 +0800 Subject: [PATCH 34/37] to use the right call wrapper method, again --- xinference/core/model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/core/model.py b/xinference/core/model.py index 2f7dccbd86..0152d7f516 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -577,7 +577,7 @@ async def code_generate( response = None try: if hasattr(self._model, "code_generate"): - response = await self._call_wrapper( + response = await self._call_wrapper_json( self._model.code_generate, mode, prompt, From aa724ac22bfd6aabc80bd7e43e2d93eaca45b538 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Thu, 18 Jul 2024 16:02:57 +0800 Subject: [PATCH 35/37] reversed code to use parse_obj --- xinference/api/restful_api.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index b05db3e939..7e30c6c263 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -1821,7 +1821,7 @@ async def get_code_prompt(self, request: Request) -> Response: detail="mode must be one of 'completion' or 'infill'", ) - body = CreateCodeCompletion.model_validate_json(json_data) + body = CreateCodeCompletion.parse_obj(json_data) model_uid = body.model From 5a3131269eb4fb852f69434b5357708c58472bff Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Wed, 31 Jul 2024 14:15:33 +0800 Subject: [PATCH 36/37] remove vllm disable setting check. 
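[Editor's note] Patches 33 and 35 switch the request parsing from parse_obj to model_validate_json and back, presumably because the handler already holds a decoded dict. For reference, in Pydantic the two entry points take different inputs, as the stand-in model below shows (CodeRequest is not the project's CreateCodeCompletion; Pydantic 2.x is assumed for model_validate_json).

```python
from pydantic import BaseModel  # assumes Pydantic 2.x


class CodeRequest(BaseModel):  # stand-in model for illustration only
    model: str
    mode: str = "completion"


data = {"model": "deepseek-coder", "mode": "infill"}

# parse_obj (v1 API, kept as a deprecated alias in v2) and model_validate (v2)
# both accept an already-decoded dict such as the one request.json() returns.
req = CodeRequest.parse_obj(data)

# model_validate_json expects a raw JSON *string*; passing it the decoded
# dict instead raises a validation error.
req_from_str = CodeRequest.model_validate_json(
    '{"model": "deepseek-coder", "mode": "infill"}'
)

print(req == req_from_str)  # True
```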
--- xinference/model/llm/vllm/core.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index e79cf32e1b..7bb6a41cda 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -618,8 +618,6 @@ class VLLMCodeModel(VLLMModel, CodeModelMixin): def match( cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str ) -> bool: - if XINFERENCE_DISABLE_VLLM: - return False if llm_spec.model_format not in ["pytorch", "gptq", "awq"]: return False if llm_spec.model_format == "pytorch": From c6948cac198744706c197e272ea0f5b15372ab80 Mon Sep 17 00:00:00 2001 From: Shi Hui Date: Sat, 24 Aug 2024 16:05:25 +0800 Subject: [PATCH 37/37] change the format of starcoder from gglmv3 to ggufv2 --- xinference/model/llm/llm_family.json | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json index 0d4ca73869..3c6fdd37f1 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -2616,17 +2616,13 @@ "model_description": "Starcoder is an open-source Transformer based LLM that is trained on permissively licensed data from GitHub.", "model_specs": [ { - "model_format": "ggmlv3", + "model_format": "ggufv2", "model_size_in_billions": 16, "quantizations": [ - "q4_0", - "q4_1", - "q5_0", - "q5_1", - "q8_0" + "q5_k_m" ], - "model_id": "TheBloke/starcoder-GGML", - "model_file_name_template": "starcoder.ggmlv3.{quantization}.bin" + "model_id": "osukhoroslov-hw/starcoder-Q5_K_M-GGUF", + "model_file_name_template": "starcoder-{quantization}.gguf" } ], "code_prompt_style": {