diff --git a/setup.cfg b/setup.cfg index 7b83211c84..91e2256adb 100644 --- a/setup.cfg +++ b/setup.cfg @@ -292,7 +292,7 @@ exclude = [codespell] ignore-words-list = hist,rcall,fpr,ser,nd,inout,ot,Ba,ba,asend,hart,coo,splitted,datas,fro -skip = .idea,.git,./build,./docs/build,node_modules,static,generated,*.po,*.ts,*.json,*.c,*.cpp,*.cfg,thirdparty +skip = .idea,.git,./build,./docs/build,node_modules,static,generated,*.po,*.ts,*.json,*.c,*.cpp,*.cfg,thirdparty,xinference/model/llm/lang_utils.py [isort] profile = black diff --git a/setup.py b/setup.py index fe7c02301a..56f19a1759 100644 --- a/setup.py +++ b/setup.py @@ -73,6 +73,7 @@ class CustomDevelop(ExtraCommandMixin, develop): class CustomSDist(ExtraCommandMixin, sdist): pass + class BuildWeb(Command): """build_web command""" diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py index 9c553c4c5d..acb206d650 100644 --- a/xinference/api/restful_api.py +++ b/xinference/api/restful_api.py @@ -62,6 +62,7 @@ ChatCompletionMessage, Completion, CreateChatCompletion, + CreateCodeCompletion, CreateCompletion, ImageList, PeftModelConfig, @@ -158,6 +159,8 @@ class BuildGradioInterfaceRequest(BaseModel): model_ability: List[str] model_description: str model_lang: List[str] + infill_supported: Optional[bool] + repo_level_supported: Optional[bool] class BuildGradioImageInterfaceRequest(BaseModel): @@ -258,6 +261,9 @@ async def internal_exception_handler(request: Request, exc: Exception): self._router.add_api_route( "/v1/models/prompts", self._get_builtin_prompts, methods=["GET"] ) + self._router.add_api_route( + "/v1/models/code_prompts", self._get_builtin_code_prompts, methods=["GET"] + ) self._router.add_api_route( "/v1/models/families", self._get_builtin_families, methods=["GET"] ) @@ -554,6 +560,29 @@ async def internal_exception_handler(request: Request, exc: Exception): ), ) + self._router.add_api_route( + "/v1/code/completions", + self.create_code_completion, + methods=["POST"], + response_model=Completion, + dependencies=( + [Security(self._auth_service, scopes=["models:read"])] + if self.is_authenticated() + else None + ), + ) + + self._router.add_api_route( + "/v1/code/prompt", + self.get_code_prompt, + methods=["POST"], + dependencies=( + [Security(self._auth_service, scopes=["models:read"])] + if self.is_authenticated() + else None + ), + ) + # for custom models self._router.add_api_route( "/v1/model_registrations/{model_type}", @@ -743,6 +772,18 @@ async def _get_builtin_prompts(self) -> JSONResponse: logger.error(e, exc_info=True) raise HTTPException(status_code=500, detail=str(e)) + async def _get_builtin_code_prompts(self) -> JSONResponse: + """ + For internal usage + :return: + """ + try: + data = await (await self._get_supervisor_ref()).get_builtin_code_prompts() + return JSONResponse(content=data) + except Exception as e: + logger.error(e, exc_info=True) + raise HTTPException(status_code=500, detail=str(e)) + async def _get_builtin_families(self) -> JSONResponse: """ For internal usage @@ -1003,6 +1044,8 @@ async def build_gradio_interface( model_description=body.model_description, model_lang=body.model_lang, access_token=access_token, + infill_supported=body.infill_supported, + repo_level_supported=body.repo_level_supported, ).build() gr.mount_gradio_app(self._app, interface, f"/{model_uid}") except ValueError as ve: @@ -1763,6 +1806,115 @@ async def stream_results(): self.handle_request_limit_error(e) raise HTTPException(status_code=500, detail=str(e)) + async def create_code_completion(self, request: Request) 
-> Response: + json_data = await request.json() + + if "mode" in json_data and json_data["mode"] not in ("completion", "infill"): + raise HTTPException( + status_code=400, + detail="mode must be one of 'completion' or 'infill'", + ) + + if json_data.get("stream", False): + json_data["stream"] = False + + body = CreateCodeCompletion.parse_obj(json_data) + exclude = { + "mode", + "prompt", + "file_path", + "suffix", + "repo_name", + "files", + "model", + "n", + "messages", + "logit_bias", + "logit_bias_type", + "user", + } + + kwargs = body.dict(exclude_unset=True, exclude=exclude) + + # TODO: Decide if this default value override is necessary #1061 + if body.max_tokens is None: + kwargs["max_tokens"] = max_tokens_field.default + + if body.logit_bias is not None: + raise HTTPException(status_code=501, detail="Not implemented") + + model_uid = body.model + + try: + model = await (await self._get_supervisor_ref()).get_model(model_uid) + except ValueError as ve: + logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) + raise HTTPException(status_code=500, detail=str(e)) + + assert not body.stream + + try: + data = await model.code_generate( + body.mode, + body.prompt, + body.file_path, + body.suffix, + body.repo_name, + body.files, + kwargs, + ) + return Response(content=data, media_type="application/json") + except Exception as e: + logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) + self.handle_request_limit_error(e) + raise HTTPException(status_code=500, detail=str(e)) + + async def get_code_prompt(self, request: Request) -> Response: + json_data = await request.json() + + if "mode" in json_data and json_data["mode"] not in ("completion", "infill"): + raise HTTPException( + status_code=400, + detail="mode must be one of 'completion' or 'infill'", + ) + + body = CreateCodeCompletion.parse_obj(json_data) + + model_uid = body.model + + try: + model = await (await self._get_supervisor_ref()).get_model(model_uid) + except ValueError as ve: + logger.error(str(ve), exc_info=True) + await self._report_error_event(model_uid, str(ve)) + raise HTTPException(status_code=400, detail=str(ve)) + except Exception as e: + logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) + raise HTTPException(status_code=500, detail=str(e)) + + try: + code_prompt = await model.get_code_prompt( + body.mode, + body.prompt, + body.file_path, + body.suffix, + body.repo_name, + body.files, + ) + return Response(content=code_prompt, media_type="application/json") + except Exception as e: + logger.error(e, exc_info=True) + await self._report_error_event(model_uid, str(e)) + self.handle_request_limit_error(e) + raise HTTPException(status_code=500, detail=str(e)) + async def query_engines_by_model_name(self, model_name: str) -> JSONResponse: try: content = await ( diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py index 679f65d296..2d6f2a24d7 100644 --- a/xinference/client/restful/restful_client.py +++ b/xinference/client/restful/restful_client.py @@ -25,6 +25,7 @@ ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, + CodeGenerateMode, Completion, CompletionChunk, Embedding, @@ -552,6 +553,144 @@ def chat( return response_data +class RESTfulCodeModelHandle(RESTfulGenerateModelHandle): + def code_generate( + self, 
+ mode: "CodeGenerateMode", + prompt: str, + file_path: Optional[str] = None, + suffix: Optional[str] = None, + repo_name: Optional[str] = None, + files: Optional[typing.Mapping] = None, + generate_config: Optional[ + Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"] + ] = None, + ) -> "Completion": + """ + Given code generation hint to complete the code, the model will return a response via RESTful APIs. + + Parameters + ---------- + mode: Literal["completion", "infill"] + Code Generation mode + Completion includes code fragment completion and repository level code completion + Infill is fill in middle completion, complete the code according provided prefix and suffix content. + prompt: str + The user's input, it presents prefix content in infill mode. + file_path: Optional[str] + The file path for prompt content file. + suffix: Optional[str] + The suffix content in infill mode. + repo_name: Optional[str] + The repository name in repository level code completion mode. + files: Optional[Mapping] + The file name/path and its content key values in repository level code completion mode + generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]] + Additional configuration for the chat generation. + "LlamaCppGenerateConfig" -> configuration for ggml model + "PytorchGenerateConfig" -> configuration for pytorch model + + Returns + ------- + "Completion" + + Raises + ------ + RuntimeError + Report the failure to generate the code from the server. Detailed information provided in error message. + + """ + + url = f"{self._base_url}/v1/code/completions" + + request_body: Dict[str, Any] = { + "model": self._model_uid, + "mode": mode, + "prompt": prompt, + "file_path": file_path, + "suffix": suffix, + "repo_name": repo_name, + "files": files, + } + + if generate_config is not None: + for key, value in generate_config.items(): + request_body[key] = value + + response = requests.post(url, json=request_body, headers=self.auth_headers) + + if response.status_code != 200: + raise RuntimeError( + f"Failed to generate code completion, detail: {_get_error_string(response)}" + ) + + response_data = response.json() + return response_data + + def get_code_prompt( + self, + mode: "CodeGenerateMode", + prompt: str, + file_path: Optional[str] = None, + suffix: Optional[str] = None, + repo_name: Optional[str] = None, + files: Optional[typing.Mapping] = None, + ) -> str: + """ + Given code generating prompt which can be used to complete the code, the model will return a response via + RESTful APIs. + + Parameters + ---------- + mode: Literal["completion", "infill"] + Code Generation mode + Completion includes code fragment completion and repository level code completion + Infill is fill in middle completion, complete the code according provided prefix and suffix content. + prompt: str + The user's input, it presents prefix content in infill mode. + file_path: Optional[str] + The file path for prompt + suffix: Optional[str] + The suffix content in infill mode. + repo_name: Optional[str] + The repository name in repository level code completion mode. + files: Optional[Mapping] + The file name/path and its content key values in repository level code completion mode + + Returns + ------- + {"prompt": "generated prompt"} + + Raises + ------ + RuntimeError + Report the failure to generate the code prompt from the server. + Detailed information provided in error message. 
+ + """ + + url = f"{self._base_url}/v1/code/prompt" + + request_body: Dict[str, Any] = { + "model": self._model_uid, + "mode": mode, + "prompt": prompt, + "file_path": file_path, + "suffix": suffix, + "repo_name": repo_name, + "files": files, + } + + response = requests.post(url, json=request_body, headers=self.auth_headers) + + if response.status_code != 200: + raise RuntimeError( + f"Failed to generate code prompt generating, detail: {_get_error_string(response)}" + ) + + return response.json() + + class RESTfulAudioModelHandle(RESTfulModelHandle): def transcriptions( self, @@ -1032,6 +1171,10 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle: return RESTfulChatModelHandle( model_uid, self.base_url, auth_headers=self._headers ) + elif "code" in desc["model_ability"]: + return RESTfulCodeModelHandle( + model_uid, self.base_url, auth_headers=self._headers + ) elif "generate" in desc["model_ability"]: return RESTfulGenerateModelHandle( model_uid, self.base_url, auth_headers=self._headers @@ -1384,6 +1527,36 @@ def abort_request(self, model_uid: str, request_id: str): response_data = response.json() return response_data + def list_builtin_prompts(self): + """ + Get the builtin prompts + :return: List[Dict[str, Any]] + The builtin prompts + """ + url = f"{self.base_url}/v1/models/prompts" + response = requests.get(url, headers=self._headers) + if response.status_code != 200: + raise RuntimeError( + f"Failed to get builtin prompts, details: {_get_error_string(response)}" + ) + response_data = response.json() + return response_data + + def list_builtin_code_prompts(self): + """ + Get the builtin code prompts + :return: List[Dict[str, Any]] + The builtin code prompts + """ + url = f"{self.base_url}/v1/models/code_prompts" + response = requests.get(url, headers=self._headers) + if response.status_code != 200: + raise RuntimeError( + f"Failed to get builtin code prompts, details: {_get_error_string(response)}" + ) + response_data = response.json() + return response_data + def get_workers_info(self): url = f"{self.base_url}/v1/workers" response = requests.get(url, headers=self._headers) diff --git a/xinference/core/chat_interface.py b/xinference/core/chat_interface.py index 8738141f90..3d98e344e0 100644 --- a/xinference/core/chat_interface.py +++ b/xinference/core/chat_interface.py @@ -20,11 +20,12 @@ import gradio as gr import PIL.Image -from gradio.components import Markdown, Textbox +from gradio.components import Code, Dropdown, File, Markdown, Textbox from gradio.layouts import Accordion, Column, Row from ..client.restful.restful_client import ( RESTfulChatModelHandle, + RESTfulCodeModelHandle, RESTfulGenerateModelHandle, ) from ..types import ChatCompletionMessage @@ -32,6 +33,34 @@ logger = logging.getLogger(__name__) +def compare_history(current, hist): + if current["mode"] != hist["mode"]: + return False + + if current["prompt"] != hist["prompt"]: + return False + + if current["file_path"] != hist["file_path"]: + return False + + if current["suffix"] != hist["suffix"]: + return False + + if current["files"] != current["files"]: + return False + + return True + + +EMPTY = { + "mode": "Code Completion", + "prompt": "", + "file_path": "", + "suffix": "", + "files": None, +} + + class GradioInterface: def __init__( self, @@ -47,6 +76,8 @@ def __init__( model_description: str, model_lang: List[str], access_token: Optional[str], + infill_supported: Optional[bool], + repo_level_supported: Optional[bool], ): self.endpoint = endpoint self.model_uid = model_uid @@ -62,12 +93,16 @@ def 
__init__( self._access_token = ( access_token.replace("Bearer ", "") if access_token is not None else None ) + self.infill_supported = infill_supported + self.repo_level_supported = repo_level_supported def build(self) -> "gr.Blocks": if "vision" in self.model_ability: interface = self.build_chat_vl_interface() elif "chat" in self.model_ability: interface = self.build_chat_interface() + elif "code" in self.model_ability: + interface = self.build_code_generate_interface() else: interface = self.build_generate_interface() @@ -401,6 +436,414 @@ def update_button(text): return chat_vl_interface + def build_code_generate_interface( + self, + ): + def undo(g_mode, text, g_file_path, g_suffix, g_files, hist): + current = { + "mode": g_mode, + "prompt": text, + "file_path": g_file_path, + "suffix": g_suffix, + "files": g_files, + } + + if len(hist) == 0: + return { + generate_mode: "Code Completion", + prompt: "", + file_path: "", + suffix: "", + files: None, + history: [current], + } + if compare_history(current, hist[-1]): + hist = hist[:-1] + + req = hist[-1] if len(hist) > 0 else EMPTY + + return { + generate_mode: req["mode"], + prompt: req["prompt"], + file_path: req["file_path"], + suffix: req["suffix"], + files: g_files, + history: hist, + } + + def clear(g_mode, text, g_file_path, g_suffix, g_files, hist): + current = { + "mode": g_mode, + "prompt": text, + "file_path": g_file_path, + "suffix": g_suffix, + "files": g_files, + } + if len(hist) == 0 or ( + len(hist) > 0 and not compare_history(current, hist[-1]) + ): + hist.append(current) + hist.append(EMPTY) + return { + generate_mode: "Code Completion", + prompt: "", + file_path: "", + suffix: "", + files: None, + history: hist, + } + + def complete( + g_mode, text, g_file_path, g_suffix, g_files, hist, max_tokens, temperature + ): + from ..client import RESTfulClient + + client = RESTfulClient(self.endpoint) + client._set_token(self._access_token) + + model = client.get_model(self.model_uid) + assert isinstance(model, RESTfulCodeModelHandle) + + repo_files = ( + {k: open(k, mode="r", encoding="utf8").read() for k in g_files} + if g_files + else None + ) + + current = { + "mode": g_mode, + "prompt": text, + "file_path": g_file_path, + "suffix": g_suffix, + "files": g_files, + } + if len(hist) == 0 or ( + len(hist) > 0 and not compare_history(current, hist[-1]) + ): + hist.append(current) + + response_content = text + + if g_mode == "Code Completion": + if self.repo_level_supported: + resp = model.code_generate( + "completion", + prompt=text, + file_path=g_file_path, + files=repo_files, + generate_config={ + "max_tokens": max_tokens, + "temperature": temperature, + }, + ) + else: + resp = model.code_generate( + mode="completion", + prompt=text, + file_path=g_file_path, + generate_config={ + "max_tokens": max_tokens, + "temperature": temperature, + }, + ) + else: + resp = model.code_generate( + mode="infill", + prompt=text, + suffix=g_suffix, + generate_config={ + "max_tokens": max_tokens, + "temperature": temperature, + }, + ) + assert isinstance(resp, dict) + choice = resp["choices"][0] + + response_content += choice["text"] + + current = { + "mode": g_mode, + "prompt": response_content, + "file_path": g_file_path, + "suffix": g_suffix, + "files": g_files, + } + + hist.append(current) + return { + prompt: response_content, + history: hist, + } + + def retry( + g_mode, text, g_suffix, g_file_path, g_files, hist, max_tokens, temperature + ): + from ..client import RESTfulClient + + client = RESTfulClient(self.endpoint) + 
client._set_token(self._access_token) + + model = client.get_model(self.model_uid) + assert isinstance(model, RESTfulCodeModelHandle) + + current = { + "mode": g_mode, + "prompt": text, + "file_path": g_file_path, + "suffix": g_suffix, + "files": g_files, + } + + if len(hist) == 0 or ( + len(hist) > 0 and not compare_history(current, hist[-1]) + ): + hist.append(current) + + req = hist[-2] if len(hist) > 1 else EMPTY + + response_content = req["prompt"] + + repo_files = ( + {k: open(k, mode="r", encoding="utf8").read() for k in req["files"]} + if req["files"] + else None + ) + + resp = model.code_generate( + mode="completion" if req["mode"] == "Code Completion" else "infill", + prompt=req["prompt"], + file_path=req["file_path"], + suffix=req["suffix"], + files=repo_files, + generate_config={ + "max_tokens": max_tokens, + "temperature": temperature, + }, + ) + assert isinstance(resp, dict) + choice = resp["choices"][0] + response_content += choice["text"] + + req["prompt"] = response_content + + hist.append(req) + return { + generate_mode: req["mode"], + prompt: response_content, + file_path: req["file_path"], + suffix: req["suffix"], + files: req["files"], + history: hist, + } + + def mode_change(generate_mode): + if generate_mode == "Code Completion": + return { + file_path: Textbox( + container=True, + show_label=True, + label="Prompt file path", + interactive=self.repo_level_supported, + ), + suffix: Code( + container=True, + show_label=True, + label="Suffix", + lines=21, + visible=False, + ), + files: File( + container=False, + show_label=False, + label="Files", + file_count="multiple", + visible=self.repo_level_supported, + ), + } + else: + return { + file_path: Textbox( + container=True, + show_label=True, + label="Prompt file path", + interactive=True, + visible=False, + ), + suffix: Code( + container=True, + show_label=True, + label="Suffix", + lines=21, + interactive=True, + visible=self.infill_supported, + ), + files: File( + container=False, + show_label=False, + label="Files", + file_count="multiple", + visible=False, + ), + } + + with gr.Blocks( + title=f"🚀 Xinference Code Generate Bot : {self.model_name} 🚀", + css=""" + .center{ + display: flex; + justify-content: center; + align-items: center; + padding: 0px; + color: #9ea4b0 !important; + } + """, + analytics_enabled=False, + ) as code_generate_interface: + modes = ( + ["Code Completion", "Code Infill Completion"] + if self.infill_supported + else ["Code Completion"] + ) + + history = gr.State([]) + + Markdown( + f""" +

+                <h1 class="center">🚀 Xinference Code Generate Bot : {self.model_name} 🚀</h1>
+                """
+            )
+            Markdown(
+                f"""
+                <div class="center">
+                Model ID: {self.model_uid}
+                </div>
+                <div class="center">
+                Model Size: {self.model_size_in_billions} Billion Parameters
+                </div>
+                <div class="center">
+                Model Format: {self.model_format}
+                </div>
+                <div class="center">
+                Model Quantization: {self.quantization}
+                </div>
+                <div class="center">
+                Support Infill Code Completion: {self.infill_supported}
+                </div>
+                <div class="center">
+                Support Repository Level Code Completion: {self.repo_level_supported}
+                </div>
+ """ + ) + + with Column(variant="panel"): + generate_mode = Dropdown( + container=True, + show_label=True, + label="Code Generate Mode", + choices=modes, + value=modes[0], + ) + + prompt = Code( + container=True, + show_label=True, + label="Prompt", + lines=21, + interactive=True, + ) + + suffix = Code( + container=True, + show_label=True, + label="Suffix", + lines=21, + visible=False, + ) + + file_path = Textbox( + container=True, + show_label=True, + label="Prompt file path", + interactive=True, + ) + + files = File( + container=True, + show_label=True, + label="Repository Files", + file_count="multiple", + interactive=True, + visible=self.repo_level_supported, + ) + + with Row(): + btn_generate = gr.Button("Generate", variant="primary") + with Row(): + btn_undo = gr.Button("↩️ Undo") + btn_retry = gr.Button("🔄 Retry") + btn_clear = gr.Button("🗑️ Clear") + with Accordion("Additional Inputs", open=False): + length = gr.Slider( + minimum=1, + maximum=self.context_length, + value=1024, + step=1, + label="Max Tokens", + ) + temperature = gr.Slider( + minimum=0, maximum=2, value=1, step=0.01, label="Temperature" + ) + + generate_mode.change( + fn=mode_change, + inputs=[generate_mode], + outputs=[file_path, suffix, files], + ) + + btn_generate.click( + fn=complete, + inputs=[ + generate_mode, + prompt, + file_path, + suffix, + files, + history, + length, + temperature, + ], + outputs=[prompt, history], + ) + + btn_undo.click( + fn=undo, + inputs=[generate_mode, prompt, file_path, suffix, files, history], + outputs=[generate_mode, prompt, file_path, suffix, files, history], + ) + + btn_retry.click( + fn=retry, + inputs=[ + generate_mode, + prompt, + file_path, + suffix, + files, + history, + length, + temperature, + ], + outputs=[generate_mode, prompt, file_path, suffix, files, history], + ) + + btn_clear.click( + fn=clear, + inputs=[generate_mode, prompt, file_path, suffix, files, history], + outputs=[generate_mode, prompt, file_path, suffix, files, history], + ) + + return code_generate_interface + def build_generate_interface( self, ): diff --git a/xinference/core/model.py b/xinference/core/model.py index 602f712514..ad32528673 100644 --- a/xinference/core/model.py +++ b/xinference/core/model.py @@ -33,6 +33,7 @@ Generator, Iterator, List, + Mapping, Optional, Union, ) @@ -41,6 +42,7 @@ import xoscar as xo from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING +from ..types import CodeGenerateMode if TYPE_CHECKING: from .worker import WorkerActor @@ -565,6 +567,99 @@ async def abort_request(self, request_id: str) -> str: return await self._scheduler_ref.abort_request(request_id) return AbortRequestMessage.NO_OP.name + @log_async(logger=logger) + @request_limit + @xo.generator + async def code_generate( + self, + mode: CodeGenerateMode, + prompt: str, + file_path: Optional[str], + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], + *args, + **kwargs, + ): + start_time = time.time() + response = None + try: + if hasattr(self._model, "code_generate"): + response = await self._call_wrapper_json( + self._model.code_generate, + mode, + prompt, + file_path, + suffix, + repo_name, + files, + *args, + **kwargs, + ) + return response + if hasattr(self._model, "async_code_generate"): + response = await self._call_wrapper_json( + self._model.async_code_generate, + mode, + prompt, + file_path, + suffix, + repo_name, + files, + *args, + **kwargs, + ) + return response + raise AttributeError( + f"Model {self._model.model_spec} is not for code generate." 
+ ) + finally: + # For the non stream result. + record = None + if isinstance(response, (Generator, AsyncGenerator)): + record = response + elif isinstance(response, bytes): + record = json.loads(response) + if record and isinstance(record, dict): + usage = record["usage"] + # Some backends may not have a valid usage, we just skip them. + completion_tokens = usage["completion_tokens"] + prompt_tokens = usage["prompt_tokens"] + await self._record_completion_metrics( + time.time() - start_time, + completion_tokens, + prompt_tokens, + ) + + @log_async(logger=logger) + @request_limit + @xo.generator + async def get_code_prompt( + self, + mode: CodeGenerateMode, + prompt: str, + file_path: Optional[str], + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], + ): + from ..model.llm.utils import CodeModelMixin + + if isinstance(self._model, CodeModelMixin): + return await self._call_wrapper_json( + self._model.get_code_prompt, + mode, + prompt, + file_path, + suffix, + repo_name, + files, + ) + else: + raise ValueError( + f"Model {self._model.model_family.model_name} does not support code generating" + ) + @log_async(logger=logger) @request_limit async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs): diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py index 2b6f7b9fc5..d2c71265cc 100644 --- a/xinference/core/supervisor.py +++ b/xinference/core/supervisor.py @@ -313,10 +313,20 @@ async def get_builtin_prompts() -> Dict[str, Any]: data[k] = v.dict() return data + @staticmethod + async def get_builtin_code_prompts() -> Dict[str, Any]: + from ..model.llm.llm_family import BUILTIN_LLM_CODE_PROMPT_STYLE + + data = {} + for k, v in BUILTIN_LLM_CODE_PROMPT_STYLE.items(): + data[k] = v.dict() + return data + @staticmethod async def get_builtin_families() -> Dict[str, List[str]]: from ..model.llm.llm_family import ( BUILTIN_LLM_MODEL_CHAT_FAMILIES, + BUILTIN_LLM_MODEL_CODE_FAMILIES, BUILTIN_LLM_MODEL_GENERATE_FAMILIES, BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES, ) @@ -325,6 +335,7 @@ async def get_builtin_families() -> Dict[str, List[str]]: "chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES), "generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES), "tools": list(BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES), + "code": list(BUILTIN_LLM_MODEL_CODE_FAMILIES), } async def get_devices_count(self) -> int: diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py index cd47b98cc5..7bb44ef164 100644 --- a/xinference/core/tests/test_restful_api.py +++ b/xinference/core/tests/test_restful_api.py @@ -1239,3 +1239,192 @@ def test_launch_model_by_version(setup): # delete again url = f"{endpoint}/v1/models/test_qwen15" requests.delete(url) + + +@pytest.mark.skipif(bool(os.environ.get("GITHUB_ACTIONS")), reason="Skip windows") +def test_cluster_info(setup): + endpoint, _ = setup + url = f"{endpoint}/v1/cluster/info" + + response = requests.get(url) + assert response.status_code == 200 + result = response.json() + assert isinstance(result, list) + assert len(result) == 2 + assert result[0]["node_type"] == "Supervisor" + assert result[0]["gpu_count"] == 0 + assert result[0]["gpu_vram_total"] == 0 + assert result[1]["node_type"] == "Worker" + assert result[1]["gpu_count"] == 0 + assert result[1]["gpu_vram_total"] == 0 + + +def test_restful_api_for_code_prompt(setup): + model_name = "deepseek-coder" + + endpoint, _ = setup + url = f"{endpoint}/v1/models" + + # list + response = requests.get(url) + response_data = 
response.json() + assert len(response_data["data"]) == 0 + + # launch + payload = { + "model_uid": "deepseek-coder", + "model_name": model_name, + "model_type": "LLM", + "model_engine": "llama.cpp", + "model_size_in_billions": "1_3", + "quantization": "q4_k_m", + } + + response = requests.post(url, json=payload) + response_data = response.json() + model_uid_res = response_data["model_uid"] + assert model_uid_res == "deepseek-coder" + + response = requests.get(url) + response_data = response.json() + assert len(response_data["data"]) == 1 + + # test embedding + url = f"{endpoint}/v1/code/prompt" + payload = { + "model": "deepseek-coder", + "prompt": "#write a quick sort algorithm", + } + response = requests.post(url, json=payload) + coding_res = response.json() + + assert "prompt" in coding_res + + assert "#write a quick sort algorithm" == coding_res["prompt"] + + # test multiple + payload = { + "model": "deepseek-coder", + "mode": "infill", + "prompt": """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +""", + "suffix": """ + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""", + } + response = requests.post(url, json=payload) + coding_res = response.json() + + assert "prompt" in coding_res + assert ( + coding_res["prompt"] + == """<|fim▁begin|>def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +<|fim▁hole|> + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)<|fim▁end|>""" + ) + + # delete model + url = f"{endpoint}/v1/models/deepseek-coder" + response = requests.delete(url) + assert response.status_code == 200 + + response = requests.get(f"{endpoint}/v1/models") + response_data = response.json() + assert len(response_data["data"]) == 0 + + +def test_restful_api_for_code_completions(setup): + model_name = "deepseek-coder" + + endpoint, _ = setup + url = f"{endpoint}/v1/models" + + # list + response = requests.get(url) + response_data = response.json() + assert len(response_data["data"]) == 0 + + # launch + payload = { + "model_uid": "deepseek-coder", + "model_name": model_name, + "model_type": "LLM", + "model_engine": "llama.cpp", + "model_size_in_billions": "1_3", + "quantization": "q4_k_m", + } + + response = requests.post(url, json=payload) + response_data = response.json() + model_uid_res = response_data["model_uid"] + assert model_uid_res == "deepseek-coder" + + response = requests.get(url) + response_data = response.json() + assert len(response_data["data"]) == 1 + + # test embedding + url = f"{endpoint}/v1/code/completions" + payload = { + "model": "deepseek-coder", + "prompt": "#write a quick sort algorithm", + "max_tokens": 4096, + } + response = requests.post(url, json=payload) + coding_res = response.json() + + assert len(coding_res["choices"]) == 1 + assert "text" in coding_res["choices"][0] + assert coding_res["choices"][0]["finish_reason"] == "stop" + + # test multiple + payload = { + "model": "deepseek-coder", + "mode": "infill", + "prompt": """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +""", + "suffix": """ + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""", + "max_tokens": 4096, + } + response = requests.post(url, json=payload) + coding_res = response.json() + + assert len(coding_res["choices"]) 
== 1 + assert "text" in coding_res["choices"][0] + assert coding_res["choices"][0]["finish_reason"] == "stop" + + # delete model + url = f"{endpoint}/v1/models/deepseek-coder" + response = requests.delete(url) + assert response.status_code == 200 + + response = requests.get(f"{endpoint}/v1/models") + response_data = response.json() + assert len(response_data["data"]) == 0 diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py index 9909addebb..843302bd15 100644 --- a/xinference/model/llm/__init__.py +++ b/xinference/model/llm/__init__.py @@ -26,8 +26,10 @@ ) from .llm_family import ( BUILTIN_CSGHUB_LLM_FAMILIES, + BUILTIN_LLM_CODE_PROMPT_STYLE, BUILTIN_LLM_FAMILIES, BUILTIN_LLM_MODEL_CHAT_FAMILIES, + BUILTIN_LLM_MODEL_CODE_FAMILIES, BUILTIN_LLM_MODEL_GENERATE_FAMILIES, BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES, BUILTIN_LLM_PROMPT_STYLE, @@ -40,13 +42,16 @@ SUPPORTED_ENGINES, TRANSFORMERS_CLASSES, VLLM_CLASSES, + CodePromptStyleV1, CustomLLMFamilyV1, + FIMSpecV1, LlamaCppLLMSpecV1, LLMFamilyV1, LLMSpecV1, MLXLLMSpecV1, PromptStyleV1, PytorchLLMSpecV1, + RepoLevelCodeCompletionSpecV1, get_cache_status, get_user_defined_llm_families, match_llm, @@ -113,14 +118,14 @@ def generate_engine_config_by_model_family(model_family): def _install(): - from .llama_cpp.core import LlamaCppChatModel, LlamaCppModel + from .llama_cpp.core import LlamaCppChatModel, LlamaCppCodeModel, LlamaCppModel from .lmdeploy.core import LMDeployChatModel, LMDeployModel from .mlx.core import MLXChatModel, MLXModel from .sglang.core import SGLANGChatModel, SGLANGModel from .transformers.chatglm import ChatglmPytorchChatModel from .transformers.cogvlm2 import CogVLM2Model from .transformers.cogvlm2_video import CogVLM2VideoModel - from .transformers.core import PytorchChatModel, PytorchModel + from .transformers.core import PytorchChatModel, PytorchCodeModel, PytorchModel from .transformers.deepseek_vl import DeepSeekVLChatModel from .transformers.glm4v import Glm4VModel from .transformers.intern_vl import InternVLChatModel @@ -130,7 +135,7 @@ def _install(): from .transformers.minicpmv26 import MiniCPMV26Model from .transformers.qwen_vl import QwenVLChatModel from .transformers.yi_vl import YiVLChatModel - from .vllm.core import VLLMChatModel, VLLMModel, VLLMVisionModel + from .vllm.core import VLLMChatModel, VLLMCodeModel, VLLMModel, VLLMVisionModel try: from .transformers.omnilmm import OmniLMMModel @@ -144,11 +149,12 @@ def _install(): LLAMA_CLASSES.extend( [ LlamaCppChatModel, + LlamaCppCodeModel, LlamaCppModel, ] ) SGLANG_CLASSES.extend([SGLANGModel, SGLANGChatModel]) - VLLM_CLASSES.extend([VLLMModel, VLLMChatModel, VLLMVisionModel]) + VLLM_CLASSES.extend([VLLMModel, VLLMChatModel, VLLMCodeModel, VLLMVisionModel]) MLX_CLASSES.extend([MLXModel, MLXChatModel]) LMDEPLOY_CLASSES.extend([LMDeployModel, LMDeployChatModel]) TRANSFORMERS_CLASSES.extend( @@ -157,6 +163,7 @@ def _install(): LlamaPytorchModel, LlamaPytorchChatModel, PytorchChatModel, + PytorchCodeModel, Internlm2PytorchChatModel, QwenVLChatModel, YiVLChatModel, @@ -203,6 +210,16 @@ def _install(): if "tools" in model_spec.model_ability: BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name) + if "code" in model_spec.model_ability and isinstance( + model_spec.code_prompt_style, CodePromptStyleV1 + ): + BUILTIN_LLM_CODE_PROMPT_STYLE[ + model_spec.model_name + ] = model_spec.code_prompt_style + + if "code" in model_spec.model_ability: + BUILTIN_LLM_MODEL_CODE_FAMILIES.add(model_spec.model_name) + modelscope_json_path = os.path.join( 
os.path.dirname(os.path.abspath(__file__)), "llm_family_modelscope.json" ) @@ -225,6 +242,8 @@ def _install(): BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name) if "tools" in model_spec.model_ability: BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name) + if "code" in model_spec.model_ability: + BUILTIN_LLM_MODEL_CODE_FAMILIES.add(model_spec.model_name) csghub_json_path = os.path.join( os.path.dirname(os.path.abspath(__file__)), "llm_family_csghub.json" diff --git a/xinference/model/llm/lang_utils.py b/xinference/model/llm/lang_utils.py new file mode 100644 index 0000000000..2007601bd9 --- /dev/null +++ b/xinference/model/llm/lang_utils.py @@ -0,0 +1,1131 @@ +# Copyright 2022-2023 XProbe Inc. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import logging +import os.path + +from . import RepoLevelCodeCompletionSpecV1 + +logger = logging.getLogger(__name__) + +# NOTICE: these code referenced by aixcoder + +# https://github.com/aixcoder-plugin/aiXcoder-7B/blob/main/hf_mini/utils.py + +LANGUAGE_WRAPPER = { + "c": "// ", + "c++": "// ", + "cpp": "// ", + "c#": "// ", + "csharp": "// ", + "c-sharp": "// ", + "css": "/* */", + "cuda": "// ", + "dart": "// ", + "lua": "// ", + "objectivec": "// ", + "objective-c": "// ", + "objective-c++": "// ", + "python": "# ", + "perl": "# ", + "prolog": "% ", + "swift": "// ", + "lisp": "; ", + "java": "// ", + "scala": "// ", + "tex": "% ", + "vue": "", + "markdown": "", + "html": "", + "php": "// ", + "js": "// ", + "javascript": "// ", + "typescript": "// ", + "go": "// ", + "shell": "# ", + "rust": "// ", + "sql": "-- ", + "kotlin": "// ", + "vb": "' ", + "ruby": "# ", + "pascal": "// ", + "r": "# ", + "fortran": "!", + "lean": "-- ", + "matlab": "% ", + "delphi": "{}", + "scheme": "; ", + "basic": "' ", + "assembly": "; ", + "groovy": "// ", + "abap": "* ", + "gdscript": "# ", + "haskell": "-- ", + "julia": "# ", + "elixir": "# ", + "excel": "' ", + "clojure": "; ", + "actionscript": "// ", + "solidity": "// ", + "powershell": "# ", + "erlang": "% ", + "cobol": "// ", + "alloy": "/* */", + "awk": "// ", + "thrift": "/* */", + "sparql": "# ", + "augeas": "// ", + "cmake": "# ", + "f-sharp": "// ", + "stan": "// ", + "isabelle": "(**)", + "dockerfile": "# ", + "rmarkdown": "# ", + "literate-agda": "-- ", + "tcl": "// ", + "glsl": "// ", + "antlr": "// ", + "verilog": "// ", + "racket": "; ", + "standard-ml": "(**)", + "elm": "-- ", + "yaml": "# ", + "smalltalk": "'' ", + "ocaml": "(**)", + "idris": "-- ", + "visual-basic": "' ", + "protocol-buffer": "// ", + "bluespec": "// ", + "applescript": "-- ", + "makefile": "# ", + "tcsh": "# ", + "maple": "# ", + "systemverilog": "// ", + "literate-coffeescript": "# ", + "vhdl": "-- ", + "restructuredtext": ".. 
", + "sas": "* ", + "literate-haskell": "> ", + "java-server-pages": "// ", + "coffeescript": "# ", + "emacs-lisp": "; ", + "mathematica": "// ", + "xslt": "", + "zig": "// ", + "common-lisp": "; ", + "stata": "* ", + "agda": "-- ", + "ada": "-- ", + "jsx": "// ", + "tsx": "// ", +} + +EXT2LANG = { + ".abap": "abap", + ".ash": "ags script", + ".ampl": "ampl", + ".g4": "antlr", + ".apib": "api blueprint", + ".apl": "apl", + ".dyalog": "apl", + ".asp": "asp", + ".asax": "asp", + ".ascx": "asp", + ".ashx": "asp", + ".asmx": "asp", + ".aspx": "asp", + ".axd": "asp", + ".dats": "ats", + ".hats": "ats", + ".sats": "ats", + ".as": "actionscript", + ".adb": "ada", + ".ada": "ada", + ".ads": "ada", + ".agda": "agda", + ".als": "alloy", + ".apacheconf": "apacheconf", + ".vhost": "apacheconf", + ".applescript": "applescript", + ".scpt": "applescript", + ".arc": "arc", + ".ino": "arduino", + ".asciidoc": "asciidoc", + ".adoc": "asciidoc", + ".aj": "aspectj", + ".asm": "assembly", + ".a51": "assembly", + ".nasm": "assembly", + ".aug": "augeas", + ".ahk": "autohotkey", + ".ahkl": "autohotkey", + ".au3": "autoit", + ".awk": "awk", + ".auk": "awk", + ".gawk": "awk", + ".mawk": "awk", + ".nawk": "awk", + ".bat": "batchfile", + ".cmd": "batchfile", + ".befunge": "befunge", + ".bison": "bison", + ".bb": "bitbake", + ".decls": "blitzbasic", + ".bmx": "blitzmax", + ".bsv": "bluespec", + ".boo": "boo", + ".bf": "brainfuck", + ".brs": "brightscript", + ".bro": "bro", + ".c": "c", + ".cats": "c", + ".h": "c++", + ".idc": "c", + ".w": "c", + ".cs": "c#", + ".cake": "c#", + ".cshtml": "c#", + ".csx": "c#", + ".cpp": "c++", + ".c++": "c++", + ".cc": "c++", + ".cp": "c++", + ".cxx": "c++", + ".h++": "c++", + ".hh": "c++", + ".hpp": "c++", + ".hxx": "c++", + ".inl": "c++", + ".ipp": "c++", + ".tcc": "c++", + ".tpp": "c++", + ".C": "c++", + ".H": "c++", + ".c-objdump": "c-objdump", + ".chs": "c2hs haskell", + ".clp": "clips", + ".cmake": "cmake", + ".cmake.in": "cmake", + ".cob": "cobol", + ".cbl": "cobol", + ".ccp": "cobol", + ".cobol": "cobol", + ".cpy": "cobol", + ".css": "css", + ".csv": "csv", + ".capnp": "cap'n proto", + ".mss": "cartocss", + ".ceylon": "ceylon", + ".chpl": "chapel", + ".ck": "chuck", + ".cirru": "cirru", + ".clw": "clarion", + ".icl": "clean", + ".dcl": "clean", + ".click": "click", + ".clj": "clojure", + ".boot": "clojure", + ".cl2": "clojure", + ".cljc": "clojure", + ".cljs": "clojure", + ".cljs.hl": "clojure", + ".cljscm": "clojure", + ".cljx": "clojure", + ".hic": "clojure", + ".coffee": "coffeescript", + "._coffee": "coffeescript", + ".cjsx": "coffeescript", + ".cson": "coffeescript", + ".iced": "coffeescript", + ".cfm": "coldfusion", + ".cfml": "coldfusion", + ".cfc": "coldfusion cfc", + ".lisp": "common lisp", + ".asd": "common lisp", + ".lsp": "common lisp", + ".ny": "common lisp", + ".podsl": "common lisp", + ".sexp": "common lisp", + ".cps": "component pascal", + ".coq": "coq", + ".cppobjdump": "cpp-objdump", + ".c++-objdump": "cpp-objdump", + ".c++objdump": "cpp-objdump", + ".cpp-objdump": "cpp-objdump", + ".cxx-objdump": "cpp-objdump", + ".creole": "creole", + ".cr": "crystal", + ".csd": "csound", + ".feature": "cucumber", + ".cu": "cuda", + ".cuh": "cuda", + ".cy": "cycript", + ".pyx": "cython", + ".pxd": "cython", + ".pxi": "cython", + ".di": "d", + ".d-objdump": "d-objdump", + ".com": "digital command language", + ".dm": "dm", + ".zone": "dns zone", + ".arpa": "dns zone", + ".darcspatch": "darcs patch", + ".dpatch": "darcs patch", + ".dart": "dart", + ".diff": "diff", + ".patch": 
"diff", + ".dockerfile": "dockerfile", + "Dockerfile": "dockerfile", + ".djs": "dogescript", + ".dylan": "dylan", + ".dyl": "dylan", + ".intr": "dylan", + ".lid": "dylan", + ".E": "e", + ".ecl": "ecl", + ".eclxml": "ecl", + ".sch": "eagle", + ".brd": "eagle", + ".epj": "ecere projects", + ".e": "eiffel", + ".ex": "elixir", + ".exs": "elixir", + ".elm": "elm", + ".el": "emacs lisp", + ".emacs": "emacs lisp", + ".emacs.desktop": "emacs lisp", + ".em": "emberscript", + ".emberscript": "emberscript", + ".erl": "erlang", + ".escript": "erlang", + ".hrl": "erlang", + ".xrl": "erlang", + ".yrl": "erlang", + ".fs": "f#", + ".fsi": "f#", + ".fsx": "f#", + ".flux": "flux", + ".f90": "fortran", + ".f": "fortran", + ".f03": "fortran", + ".f08": "fortran", + ".f77": "fortran", + ".f95": "fortran", + ".for": "fortran", + ".fpp": "fortran", + ".factor": "factor", + ".fy": "fancy", + ".fancypack": "fancy", + ".fan": "fantom", + ".eam.fs": "formatted", + ".fth": "forth", + ".4th": "forth", + ".forth": "forth", + ".frt": "forth", + ".ftl": "freemarker", + ".g": "g-code", + ".gco": "g-code", + ".gcode": "g-code", + ".gms": "gams", + ".gap": "gap", + ".gi": "gap", + ".s": "gas", + ".gd": "gdscript", + ".glsl": "glsl", + ".fp": "glsl", + ".frag": "glsl", + ".frg": "glsl", + ".fsh": "glsl", + ".fshader": "glsl", + ".geo": "glsl", + ".geom": "glsl", + ".glslv": "glsl", + ".gshader": "glsl", + ".shader": "glsl", + ".vert": "glsl", + ".vrx": "glsl", + ".vsh": "glsl", + ".vshader": "glsl", + ".kid": "genshi", + ".ebuild": "gentoo ebuild", + ".eclass": "gentoo eclass", + ".po": "gettext catalog", + ".pot": "gettext catalog", + ".glf": "glyph", + ".gp": "gnuplot", + ".gnu": "gnuplot", + ".gnuplot": "gnuplot", + ".plot": "gnuplot", + ".plt": "gnuplot", + ".go": "go", + ".golo": "golo", + ".gst": "gosu", + ".gsx": "gosu", + ".vark": "gosu", + ".grace": "grace", + ".gradle": "gradle", + ".gf": "grammatical framework", + ".graphql": "graphql", + ".dot": "graphviz (dot)", + ".gv": "graphviz (dot)", + ".man": "groff", + ".1": "groff", + ".1in": "groff", + ".1m": "groff", + ".1x": "groff", + ".2": "groff", + ".3": "groff", + ".3in": "groff", + ".3m": "groff", + ".3qt": "groff", + ".3x": "groff", + ".4": "groff", + ".5": "groff", + ".6": "groff", + ".7": "groff", + ".8": "groff", + ".9": "groff", + ".me": "groff", + ".rno": "groff", + ".roff": "groff", + ".groovy": "groovy", + ".grt": "groovy", + ".gtpl": "groovy", + ".gvy": "groovy", + ".gsp": "groovy server pages", + ".hcl": "hcl", + ".tf": "hcl", + ".hlsl": "hlsl", + ".fxh": "hlsl", + ".hlsli": "hlsl", + ".html": "html", + ".htm": "html", + ".html.hl": "html", + ".xht": "html", + ".xhtml": "html", + ".mustache": "html+django", + ".jinja": "html+django", + ".eex": "html+eex", + ".erb": "html+erb", + ".erb.deface": "html+erb", + ".phtml": "html+php", + ".http": "http", + ".haml": "haml", + ".haml.deface": "haml", + ".handlebars": "handlebars", + ".hbs": "handlebars", + ".hb": "harbour", + ".hs": "haskell", + ".hsc": "haskell", + ".hx": "haxe", + ".hxsl": "haxe", + ".hy": "hy", + ".dlm": "idl", + ".ipf": "igor pro", + ".ini": "ini", + ".cfg": "ini", + ".prefs": "ini", + ".properties": "ini", + ".irclog": "irc log", + ".weechatlog": "irc log", + ".idr": "idris", + ".lidr": "idris", + ".ni": "inform 7", + ".i7x": "inform 7", + ".iss": "inno setup", + ".io": "io", + ".ik": "ioke", + ".thy": "isabelle", + ".ijs": "j", + ".flex": "jflex", + ".jflex": "jflex", + ".json": "json", + ".geojson": "json", + ".lock": "json", + ".topojson": "json", + ".json5": "json5", + ".jsonld": 
"jsonld", + ".jq": "jsoniq", + ".jsx": "jsx", + ".jade": "jade", + ".j": "jasmin", + ".java": "java", + ".jsp": "java server pages", + ".js": "javascript", + "._js": "javascript", + ".bones": "javascript", + ".es6": "javascript", + ".jake": "javascript", + ".jsb": "javascript", + ".jscad": "javascript", + ".jsfl": "javascript", + ".jsm": "javascript", + ".jss": "javascript", + ".njs": "javascript", + ".pac": "javascript", + ".sjs": "javascript", + ".ssjs": "javascript", + ".xsjs": "javascript", + ".xsjslib": "javascript", + ".jl": "julia", + ".ipynb": "jupyter notebook", + ".krl": "krl", + ".kicad_pcb": "kicad", + ".kit": "kit", + ".kt": "kotlin", + ".ktm": "kotlin", + ".kts": "kotlin", + ".lfe": "lfe", + ".ll": "llvm", + ".lol": "lolcode", + ".lsl": "lsl", + ".lslp": "lsl", + ".lvproj": "labview", + ".lasso": "lasso", + ".las": "lasso", + ".lasso8": "lasso", + ".lasso9": "lasso", + ".ldml": "lasso", + ".latte": "latte", + ".lean": "lean", + ".hlean": "lean", + ".less": "less", + ".lex": "lex", + ".ly": "lilypond", + ".ily": "lilypond", + ".ld": "linker script", + ".lds": "linker script", + ".liquid": "liquid", + ".lagda": "literate agda", + ".litcoffee": "literate coffeescript", + ".lhs": "literate haskell", + ".ls": "livescript", + "._ls": "livescript", + ".xm": "logos", + ".x": "logos", + ".xi": "logos", + ".lgt": "logtalk", + ".logtalk": "logtalk", + ".lookml": "lookml", + ".lua": "lua", + ".nse": "lua", + ".pd_lua": "lua", + ".rbxs": "lua", + ".wlua": "lua", + ".mumps": "m", + ".m4": "m4", + ".mcr": "maxscript", + ".mtml": "mtml", + ".muf": "muf", + ".mak": "makefile", + ".mk": "makefile", + ".mkfile": "makefile", + "Makefile": "makefile", + ".mako": "mako", + ".mao": "mako", + ".mpl": "maple", + ".md": "markdown", + ".markdown": "markdown", + ".mkd": "markdown", + ".mkdn": "markdown", + ".mkdown": "markdown", + ".ron": "markdown", + ".mask": "mask", + ".mathematica": "mathematica", + ".cdf": "mathematica", + ".ma": "mathematica", + ".mt": "mathematica", + ".nb": "mathematica", + ".nbp": "mathematica", + ".wl": "mathematica", + ".wlt": "mathematica", + ".matlab": "matlab", + ".maxpat": "max", + ".maxhelp": "max", + ".maxproj": "max", + ".mxt": "max", + ".pat": "max", + ".mediawiki": "mediawiki", + ".wiki": "mediawiki", + ".metal": "metal", + ".minid": "minid", + ".druby": "mirah", + ".duby": "mirah", + ".mir": "mirah", + ".mirah": "mirah", + ".mo": "modelica", + ".mms": "module management system", + ".mmk": "module management system", + ".monkey": "monkey", + ".moon": "moonscript", + ".myt": "myghty", + ".nsi": "nsis", + ".nsh": "nsis", + ".axs": "netlinx", + ".axi": "netlinx", + ".axs.erb": "netlinx+erb", + ".axi.erb": "netlinx+erb", + ".nlogo": "netlogo", + ".nginxconf": "nginx", + ".nim": "nimrod", + ".nimrod": "nimrod", + ".ninja": "ninja", + ".nit": "nit", + ".nix": "nix", + ".nu": "nu", + ".numpy": "numpy", + ".numpyw": "numpy", + ".numsc": "numpy", + ".ml": "ocaml", + ".eliom": "ocaml", + ".eliomi": "ocaml", + ".ml4": "ocaml", + ".mli": "ocaml", + ".mll": "ocaml", + ".mly": "ocaml", + ".objdump": "objdump", + ".mm": "objective-c++", + ".sj": "objective-j", + ".oct": "octave", + ".omgrofl": "omgrofl", + ".opa": "opa", + ".opal": "opal", + ".cl": "opencl", + ".opencl": "opencl", + ".p": "openedge abl", + ".scad": "openscad", + ".org": "org", + ".ox": "ox", + ".oxh": "ox", + ".oxo": "ox", + ".oxygene": "oxygene", + ".oz": "oz", + ".pwn": "pawn", + ".php": "php", + ".aw": "php", + ".ctp": "php", + ".php3": "php", + ".php4": "php", + ".php5": "php", + ".phps": "php", + ".phpt": 
"php", + ".pov": "pov-ray sdl", + ".pan": "pan", + ".psc": "papyrus", + ".parrot": "parrot", + ".pasm": "parrot assembly", + ".pir": "parrot internal representation", + ".pas": "pascal", + ".dfm": "pascal", + ".dpr": "pascal", + ".lpr": "pascal", + ".pl": "perl", + ".al": "perl", + ".perl": "perl", + ".ph": "perl", + ".plx": "perl", + ".pm": "perl", + ".psgi": "perl", + ".t": "perl", + ".6pl": "perl6", + ".6pm": "perl6", + ".nqp": "perl6", + ".p6": "perl6", + ".p6l": "perl6", + ".p6m": "perl6", + ".pl6": "perl6", + ".pm6": "perl6", + ".pkl": "pickle", + ".pig": "piglatin", + ".pike": "pike", + ".pmod": "pike", + ".pod": "pod", + ".pogo": "pogoscript", + ".pony": "pony", + ".ps": "postscript", + ".eps": "postscript", + ".ps1": "powershell", + ".psd1": "powershell", + ".psm1": "powershell", + ".pde": "processing", + ".prolog": "prolog", + ".yap": "prolog", + ".spin": "propeller spin", + ".proto": "protocol buffer", + ".pub": "public key", + ".pd": "pure data", + ".pb": "purebasic", + ".pbi": "purebasic", + ".purs": "purescript", + ".py": "python", + ".bzl": "python", + ".gyp": "python", + ".lmi": "python", + ".pyde": "python", + ".pyp": "python", + ".pyt": "python", + ".pyw": "python", + ".tac": "python", + ".wsgi": "python", + ".xpy": "python", + ".pytb": "python traceback", + ".qml": "qml", + ".qbs": "qml", + ".pri": "qmake", + ".r": "r", + ".rd": "r", + ".rsx": "r", + ".raml": "raml", + ".rdoc": "rdoc", + ".rbbas": "realbasic", + ".rbfrm": "realbasic", + ".rbmnu": "realbasic", + ".rbres": "realbasic", + ".rbtbar": "realbasic", + ".rbuistate": "realbasic", + ".rhtml": "rhtml", + ".rmd": "rmarkdown", + ".rkt": "racket", + ".rktd": "racket", + ".rktl": "racket", + ".scrbl": "racket", + ".rl": "ragel in ruby host", + ".raw": "raw token data", + ".reb": "rebol", + ".r2": "rebol", + ".r3": "rebol", + ".rebol": "rebol", + ".red": "red", + ".reds": "red", + ".cw": "redcode", + ".rpy": "ren'py", + ".rsh": "renderscript", + ".robot": "robotframework", + ".rg": "rouge", + ".rb": "ruby", + ".builder": "ruby", + ".gemspec": "ruby", + ".god": "ruby", + ".irbrc": "ruby", + ".jbuilder": "ruby", + ".mspec": "ruby", + ".podspec": "ruby", + ".rabl": "ruby", + ".rake": "ruby", + ".rbuild": "ruby", + ".rbw": "ruby", + ".rbx": "ruby", + ".ru": "ruby", + ".ruby": "ruby", + ".thor": "ruby", + ".watchr": "ruby", + ".rs": "rust", + ".rs.in": "rust", + ".sas": "sas", + ".scss": "scss", + ".smt2": "smt", + ".smt": "smt", + ".sparql": "sparql", + ".rq": "sparql", + ".sqf": "sqf", + ".hqf": "sqf", + ".pls": "sql", + ".pck": "sql", + ".pkb": "sql", + ".pks": "sql", + ".plb": "sql", + ".plsql": "sql", + ".sql": "sql", + ".cql": "sql", + ".ddl": "sql", + ".prc": "sql", + ".tab": "sql", + ".udf": "sql", + ".viw": "sql", + ".db2": "sql", + ".ston": "ston", + ".svg": "svg", + ".sage": "sage", + ".sagews": "sage", + ".sls": "saltstack", + ".sass": "sass", + ".scala": "scala", + ".sbt": "scala", + ".scaml": "scaml", + ".scm": "scheme", + ".sld": "scheme", + ".sps": "scheme", + ".ss": "scheme", + ".sci": "scilab", + ".sce": "scilab", + ".self": "self", + ".sh": "shell", + ".bash": "shell", + ".bats": "shell", + ".command": "shell", + ".ksh": "shell", + ".sh.in": "shell", + ".tmux": "shell", + ".tool": "shell", + ".zsh": "shell", + ".sh-session": "shellsession", + ".shen": "shen", + ".sl": "slash", + ".slim": "slim", + ".smali": "smali", + ".st": "smalltalk", + ".tpl": "smarty", + ".sol": "solidity", + ".sp": "sourcepawn", + ".sma": "sourcepawn", + ".nut": "squirrel", + ".stan": "stan", + ".ML": "standard ml", + ".fun": 
"standard ml", + ".sig": "standard ml", + ".sml": "standard ml", + ".do": "stata", + ".ado": "stata", + ".doh": "stata", + ".ihlp": "stata", + ".mata": "stata", + ".matah": "stata", + ".sthlp": "stata", + ".styl": "stylus", + ".scd": "supercollider", + ".swift": "swift", + ".sv": "systemverilog", + ".svh": "systemverilog", + ".vh": "systemverilog", + ".toml": "toml", + ".txl": "txl", + ".tcl": "tcl", + ".adp": "tcl", + ".tm": "tcl", + ".tcsh": "tcsh", + ".csh": "tcsh", + ".tex": "tex", + ".aux": "tex", + ".bbx": "tex", + ".bib": "tex", + ".cbx": "tex", + ".dtx": "tex", + ".ins": "tex", + ".lbx": "tex", + ".ltx": "tex", + ".mkii": "tex", + ".mkiv": "tex", + ".mkvi": "tex", + ".sty": "tex", + ".toc": "tex", + ".tea": "tea", + ".txt": "text", + ".no": "text", + ".textile": "textile", + ".thrift": "thrift", + ".tu": "turing", + ".ttl": "turtle", + ".twig": "twig", + ".ts": "typescript", + ".tsx": "tsx", + ".upc": "unified parallel c", + ".anim": "unity3d asset", + ".asset": "unity3d asset", + ".mat": "unity3d asset", + ".meta": "unity3d asset", + ".prefab": "unity3d asset", + ".unity": "unity3d asset", + ".uno": "uno", + ".uc": "unrealscript", + ".ur": "urweb", + ".urs": "urweb", + ".vcl": "vcl", + ".vhdl": "vhdl", + ".vhd": "vhdl", + ".vhf": "vhdl", + ".vhi": "vhdl", + ".vho": "vhdl", + ".vhs": "vhdl", + ".vht": "vhdl", + ".vhw": "vhdl", + ".vala": "vala", + ".vapi": "vala", + ".veo": "verilog", + ".vim": "viml", + ".vb": "visual basic", + ".bas": "visual basic", + ".frm": "visual basic", + ".frx": "visual basic", + ".vba": "visual basic", + ".vbhtml": "visual basic", + ".vbs": "visual basic", + ".volt": "volt", + ".vue": "vue", + ".owl": "web ontology language", + ".wat": "webassembly", + ".webidl": "webidl", + ".x10": "x10", + ".xc": "xc", + ".xml": "xml", + ".ant": "xml", + ".axml": "xml", + ".ccxml": "xml", + ".clixml": "xml", + ".cproject": "xml", + ".csl": "xml", + ".csproj": "xml", + ".ct": "xml", + ".dita": "xml", + ".ditamap": "xml", + ".ditaval": "xml", + ".dll.config": "xml", + ".dotsettings": "xml", + ".filters": "xml", + ".fsproj": "xml", + ".fxml": "xml", + ".glade": "xml", + ".grxml": "xml", + ".iml": "xml", + ".ivy": "xml", + ".jelly": "xml", + ".jsproj": "xml", + ".kml": "xml", + ".launch": "xml", + ".mdpolicy": "xml", + ".mxml": "xml", + ".nproj": "xml", + ".nuspec": "xml", + ".odd": "xml", + ".osm": "xml", + ".plist": "xml", + ".props": "xml", + ".ps1xml": "xml", + ".psc1": "xml", + ".pt": "xml", + ".rdf": "xml", + ".rss": "xml", + ".scxml": "xml", + ".srdf": "xml", + ".storyboard": "xml", + ".stTheme": "xml", + ".sublime-snippet": "xml", + ".targets": "xml", + ".tmCommand": "xml", + ".tml": "xml", + ".tmLanguage": "xml", + ".tmPreferences": "xml", + ".tmSnippet": "xml", + ".tmTheme": "xml", + ".ui": "xml", + ".urdf": "xml", + ".ux": "xml", + ".vbproj": "xml", + ".vcxproj": "xml", + ".vssettings": "xml", + ".vxml": "xml", + ".wsdl": "xml", + ".wsf": "xml", + ".wxi": "xml", + ".wxl": "xml", + ".wxs": "xml", + ".x3d": "xml", + ".xacro": "xml", + ".xaml": "xml", + ".xib": "xml", + ".xlf": "xml", + ".xliff": "xml", + ".xmi": "xml", + ".xml.dist": "xml", + ".xproj": "xml", + ".xsd": "xml", + ".xul": "xml", + ".zcml": "xml", + ".xsp-config": "xpages", + ".xsp.metadata": "xpages", + ".xpl": "xproc", + ".xproc": "xproc", + ".xquery": "xquery", + ".xq": "xquery", + ".xql": "xquery", + ".xqm": "xquery", + ".xqy": "xquery", + ".xs": "xs", + ".xslt": "xslt", + ".xsl": "xslt", + ".xojo_code": "xojo", + ".xojo_menu": "xojo", + ".xojo_report": "xojo", + ".xojo_script": "xojo", + 
".xojo_toolbar": "xojo", + ".xojo_window": "xojo", + ".xtend": "xtend", + ".yml": "yaml", + ".reek": "yaml", + ".rviz": "yaml", + ".sublime-syntax": "yaml", + ".syntax": "yaml", + ".yaml": "yaml", + ".yaml-tmlanguage": "yaml", + ".yang": "yang", + ".y": "yacc", + ".yacc": "yacc", + ".yy": "yacc", + ".zep": "zephir", + ".zig": "zig", + ".zimpl": "zimpl", + ".zmpl": "zimpl", + ".zpl": "zimpl", + ".desktop": "desktop", + ".desktop.in": "desktop", + ".ec": "ec", + ".eh": "ec", + ".edn": "edn", + ".fish": "fish", + ".mu": "mupad", + ".nc": "nesc", + ".ooc": "ooc", + ".rst": "restructuredtext", + ".rest": "restructuredtext", + ".rest.txt": "restructuredtext", + ".rst.txt": "restructuredtext", + ".wisp": "wisp", + ".prg": "xbase", + ".prw": "xbase", +} + + +LANGUAGE_TAG = { + "c": "// the code file is written by C", + "c++": "// the code file is written by C++", + "cpp": "// the code file is written by C++", + "c#": "// the code file is written by C#", + "csharp": "// the code file is written by C#", + "c-sharp": "// the code file is written by C#", + "css": "/* the code file is written by CSS */", + "cuda": "// the code file is written by Cuda", + "dart": "// the code file is written by Dart", + "lua": "// the code file is written by Lua", + "objectivec": "// the code file is written by Objective-C", + "objective-c": "// the code file is written by Objective-C", + "objective-c++": "// the code file is written by Objective-C++", + "python": "# the code file is written by Python", + "perl": "# the code file is written by Perl", + "prolog": "% the code file is written by Prolog", + "swift": "// the code file is written by swift", + "lisp": "; the code file is written by Lisp", + "java": "// the code file is written by Java", + "scala": "// the code file is written by Scala", + "tex": "% the code file is written by TeX", + "vue": "", + "markdown": "", + "html": "", + "php": "// the code file is written by PHP", + "js": "// the code file is written by JavaScript", + "javascript": "// the code file is written by JavaScript", + "typescript": "// the code file is written by TypeScript", + "go": "// the code file is written by Go", + "shell": "# the code file is written by Shell", + "rust": "// the code file is written by Rust", + "sql": "-- the code file is written by SQL", + "kotlin": "// the code file is written by Kotlin", + "vb": "' the code file is written by Visual Basic", + "ruby": "# the code file is written by Ruby", + "pascal": "// the code file is written by Pascal", + "r": "# the code file is written by R", + "fortran": "!the code file is written by Fortran", + "lean": "-- the code file is written by Lean", + "matlab": "% the code file is written by Matlab", + "delphi": "{the code file is written by Delphi}", + "scheme": "; the code file is written by Scheme", + "basic": "' the code file is written by Basic", + "assembly": "; the code file is written by Assembly", + "groovy": "// the code file is written by Groovy", + "abap": "* the code file is written by Abap", + "gdscript": "# the code file is written by GDScript", + "haskell": "-- the code file is written by Haskell", + "julia": "# the code file is written by Julia", + "elixir": "# the code file is written by Elixir", + "excel": "' the code file is written by Excel", + "clojure": "; the code file is written by Clojure", + "actionscript": "// the code file is written by ActionScript", + "solidity": "// the code file is written by Solidity", + "powershell": "# the code file is written by PowerShell", + "erlang": "% the code file is written 
by Erlang", + "cobol": "// the code file is written by Cobol", + "alloy": "/* the code file is written by Alloy */", + "awk": "// the code file is written by AWK", + "thrift": "/* the code file is written by Thrift */", + "sparql": "# the code file is written by SPARQL", + "augeas": "// the code file is written by Augeas", + "cmake": "# the code file is written by CMake", + "f-sharp": "// the code file is written by F#", + "stan": "// the code file is written by Stan", + "isabelle": "(*the code file is written by Isabelle*)", + "dockerfile": "# the code file is written by Dockerfile", + "rmarkdown": "# the code file is written by RMarkdown", + "literate-agda": "-- the code file is written by Literate Agda", + "tcl": "// the code file is written by Augeas", + "glsl": "// the code file is written by GLSL", + "antlr": "// the code file is written by ANTLR", + "verilog": "// the code file is written by Verilog", + "racket": "; the code file is written by Racket", + "standard-ml": "(*the code file is written byStandard ML*)", + "elm": "-- the code file is written by Elm", + "yaml": "# the code file is written by YAML", + "smalltalk": "'' the code file is written by Smalltalk", + "ocaml": "(*the code file is written by OCaml*)", + "idris": "-- the code file is written by Idris", + "visual-basic": "' the code file is written by Visual Basic", + "protocol-buffer": "// the code file is written by Protocol Buffer", + "bluespec": "// the code file is written by Bluespec", + "applescript": "-- the code file is written by AppleScript", + "makefile": "# the code file is written by Makefile", + "tcsh": "# the code file is written by TCSH", + "maple": "# the code file is written by Maple", + "systemverilog": "// the code file is written by SystemVerilog", + "literate-coffeescript": "# the code file is written by Literate CoffeeScript", + "vhdl": "-- the code file is written by VHDL", + "restructuredtext": ".. 
the code file is written by reStructuredText", + "sas": "* the code file is written by SAS", + "literate-haskell": "> the code file is written by Literate Haskell", + "java-server-pages": "// the code file is written by Java Server Pages", + "coffeescript": "# the code file is written by CoffeeScript", + "emacs-lisp": "; the code file is written by Emacs Lisp", + "mathematica": "// the code file is written by Mathematica", + "xslt": "", + "zig": "// the code file is written by Zig", + "common-lisp": "; the code file is written by Common Lisp", + "stata": "* the code file is written by Stata", + "agda": "-- the code file is written by Agda", + "ada": "-- the code file is written by Ada", + "jsx": "// the code file is written by JSX", + "tsx": "// the code file is written by TypeScript with JSX", +} + + +def get_file_separator( + repo_level_spec: RepoLevelCodeCompletionSpecV1, filename: str +) -> str: + if repo_level_spec.file_separator == "": + lang = EXT2LANG.get(os.path.splitext(filename)[-1].lower(), None) + sep = "// " + if lang is None: + logger.warning( + f"Unsupported file extension for {filename}, use '//' as file separator as fallback" + ) + else: + sep = LANGUAGE_WRAPPER[lang] + + return sep.replace("", filename) + else: + return repo_level_spec.file_separator + filename diff --git a/xinference/model/llm/llama_cpp/core.py b/xinference/model/llm/llama_cpp/core.py index b820fce466..ddca1479d0 100644 --- a/xinference/model/llm/llama_cpp/core.py +++ b/xinference/model/llm/llama_cpp/core.py @@ -14,12 +14,13 @@ import logging import os import time -from typing import Iterable, Iterator, List, Optional, Union +from typing import Iterable, Iterator, List, Mapping, Optional, Union from ....types import ( ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, + CodeGenerateMode, Completion, CompletionChunk, CompletionUsage, @@ -29,7 +30,7 @@ ) from ..core import LLM from ..llm_family import LLMFamilyV1, LLMSpecV1 -from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin +from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin, CodeModelMixin logger = logging.getLogger(__name__) @@ -309,3 +310,57 @@ def chat( self.model_family, self.model_uid, c, tools ) return self._to_chat_completion(c) + + +class LlamaCppCodeModel(LlamaCppModel, CodeModelMixin): + def __init__( + self, + model_uid: str, + model_family: "LLMFamilyV1", + model_spec: "LLMSpecV1", + quantization: str, + model_path: str, + llamacpp_model_config: Optional[LlamaCppModelConfig] = None, + ): + super().__init__( + model_uid, + model_family, + model_spec, + quantization, + model_path, + llamacpp_model_config, + ) + + @classmethod + def match( + cls, llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str + ) -> bool: + if llm_spec.model_format not in ["ggmlv3", "ggufv2"]: + return False + if "chatglm" in llm_family.model_name: + return False + return "code" in llm_family.model_ability + + def code_generate( + self, + generate_model: CodeGenerateMode, + prompt: str, + file_path: Optional[str], + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], + generate_config: Optional[LlamaCppGenerateConfig] = None, + ): + code_prompt = self.get_code_prompt( + generate_model, + prompt, + file_path, + suffix, + repo_name, + files, + )["prompt"] + + if generate_config is not None and generate_config.get("stream", False): + generate_config["stream"] = False + + return self.generate(code_prompt, generate_config) diff --git a/xinference/model/llm/llm_family.json 
b/xinference/model/llm/llm_family.json index 26f1d599a8..d8e9843049 100644 --- a/xinference/model/llm/llm_family.json +++ b/xinference/model/llm/llm_family.json @@ -2098,31 +2098,6 @@ ] } }, - { - "version": 1, - "context_length": 65536, - "model_name": "codeqwen1.5", - "model_lang": [ - "en", - "zh" - ], - "model_ability": [ - "generate" - ], - "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.", - "model_specs": [ - { - "model_format": "pytorch", - "model_size_in_billions": 7, - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - "model_id": "Qwen/CodeQwen1.5-7B" - } - ] - }, { "version": 1, "context_length": 65536, @@ -2580,6 +2555,86 @@ ] } }, + { + "version": 1, + "context_length": 65536, + "model_name": "codeqwen1.5", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "code" + ], + "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": 7, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "Qwen/CodeQwen1.5-7B" + }, + { + "model_format": "awq", + "model_size_in_billions": 7, + "quantizations": [ + "Int4" + ], + "model_id": "Qwen/CodeQwen1.5-7B-AWQ" + } + ], + "code_prompt_style": { + "style_name": "CODEQWEN", + "fim_spec": { + "style": "PSM", + "prefix": "", + "middle": "", + "suffix": "" + }, + "repo_level_spec": { + "repo_name": "", + "file_type": "filepath", + "file_separator": "" + } + } + }, + { + "version": 1, + "context_length": 8192, + "model_name": "starcoder", + "model_lang": [ + "en" + ], + "model_ability": [ + "generate", + "code" + ], + "model_description": "Starcoder is an open-source Transformer based LLM that is trained on permissively licensed data from GitHub.", + "model_specs": [ + { + "model_format": "ggufv2", + "model_size_in_billions": 16, + "quantizations": [ + "q5_k_m" + ], + "model_id": "osukhoroslov-hw/starcoder-Q5_K_M-GGUF", + "model_file_name_template": "starcoder-{quantization}.gguf" + } + ], + "code_prompt_style": { + "style_name": "STARCODER", + "fim_spec": { + "style": "PSM", + "prefix": "", + "middle": "", + "suffix": "" + } + } + }, { "version": 1, "context_length": 1024, @@ -5149,7 +5204,8 @@ "zh" ], "model_ability": [ - "generate" + "generate", + "code" ], "model_description": "Deepseek Coder is composed of a series of code language models, each trained from scratch on 2T tokens, with a composition of 87% code and 13% natural language in both English and Chinese. 
", "model_specs": [ @@ -5330,7 +5386,20 @@ "model_id": "TheBloke/deepseek-coder-33B-base-AWQ", "model_revision": "c7edb2d5868d61a5dcf2591933a8992c8cbe3ef4" } - ] + ], + "code_prompt_style": { + "style_name": "DEEPSEEK_CODER", + "fim_spec": { + "style": "PMS", + "prefix": "<|fim▁begin|>", + "middle": "<|fim▁hole|>", + "suffix": "<|fim▁end|>" + }, + "repo_level_spec": { + "file_type": "filename", + "file_separator": "" + } + } }, { "version": 1, diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py index c2ea4d7b98..feaef4f83f 100644 --- a/xinference/model/llm/llm_family.py +++ b/xinference/model/llm/llm_family.py @@ -53,9 +53,11 @@ DEFAULT_CONTEXT_LENGTH = 2048 BUILTIN_LLM_PROMPT_STYLE: Dict[str, "PromptStyleV1"] = {} +BUILTIN_LLM_CODE_PROMPT_STYLE: Dict[str, "CodePromptStyleV1"] = {} BUILTIN_LLM_MODEL_CHAT_FAMILIES: Set[str] = set() BUILTIN_LLM_MODEL_GENERATE_FAMILIES: Set[str] = set() BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES: Set[str] = set() +BUILTIN_LLM_MODEL_CODE_FAMILIES: Set[str] = set() class LlamaCppLLMSpecV1(BaseModel): @@ -137,17 +139,37 @@ class PromptStyleV1(BaseModel): stop_token_ids: Optional[List[int]] +class FIMSpecV1(BaseModel): + style: Literal["PSM", "PMS"] + prefix: str + middle: str + suffix: str + + +class RepoLevelCodeCompletionSpecV1(BaseModel): + repo_name: Optional[str] + file_type: Literal["filepath", "filename"] + file_separator: str + + +class CodePromptStyleV1(BaseModel): + style_name: str + fim_spec: Optional["FIMSpecV1"] + repo_level_spec: Optional["RepoLevelCodeCompletionSpecV1"] + + class LLMFamilyV1(BaseModel): version: Literal[1] context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH model_name: str model_lang: List[str] - model_ability: List[Literal["embed", "generate", "chat", "tools", "vision"]] + model_ability: List[Literal["embed", "generate", "chat", "tools", "vision", "code"]] model_description: Optional[str] # reason for not required str here: legacy registration model_family: Optional[str] model_specs: List["LLMSpecV1"] prompt_style: Optional["PromptStyleV1"] + code_prompt_style: Optional["CodePromptStyleV1"] class CustomLLMFamilyV1(LLMFamilyV1): diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json index 44ac3e7794..3fc23585ac 100644 --- a/xinference/model/llm/llm_family_modelscope.json +++ b/xinference/model/llm/llm_family_modelscope.json @@ -2907,32 +2907,6 @@ ] } }, - { - "version": 1, - "context_length": 65536, - "model_name": "codeqwen1.5", - "model_lang": [ - "en", - "zh" - ], - "model_ability": [ - "generate" - ], - "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.", - "model_specs": [ - { - "model_format": "pytorch", - "model_size_in_billions": 7, - "quantizations": [ - "4-bit", - "8-bit", - "none" - ], - "model_id": "qwen/CodeQwen1.5-7B", - "model_hub": "modelscope" - } - ] - }, { "version": 1, "context_length": 65536, @@ -3004,6 +2978,66 @@ ] } }, + { + "version": 1, + "context_length": 65536, + "model_name": "codeqwen1.5", + "model_lang": [ + "en", + "zh" + ], + "model_ability": [ + "generate", + "code" + ], + "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. 
It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.", + "model_specs": [ + { + "model_format": "pytorch", + "model_size_in_billions": 7, + "quantizations": [ + "4-bit", + "8-bit", + "none" + ], + "model_id": "qwen/CodeQwen1.5-7B", + "model_hub": "modelscope" + } + ], + "prompt_style": { + "style_name": "QWEN", + "system_prompt": "You are a helpful assistant.", + "roles": [ + "user", + "assistant" + ], + "intra_message_sep": "\n", + "stop_token_ids": [ + 2, + 3, + 4 + ], + "stop": [ + "<|endoftext|>", + "<|im_start|>", + "<|im_end|>" + ] + }, + "code_prompt_style": { + "style_name": "CODEQWEN", + "fim_spec": { + "style": "PSM", + "prefix": "", + "middle": "", + "suffix": "" + }, + "repo_level_spec": { + "repo_name": "", + "file_type": "filepath", + "file_separator": "" + } + } + }, { "version": 1, "context_length": 32768, @@ -3528,7 +3562,8 @@ "zh" ], "model_ability": [ - "generate" + "generate", + "code" ], "model_description": "Deepseek Coder is composed of a series of code language models, each trained from scratch on 2T tokens, with a composition of 87% code and 13% natural language in both English and Chinese.", "model_specs": [ @@ -3565,7 +3600,20 @@ "model_id": "deepseek-ai/deepseek-coder-33b-base", "model_hub": "modelscope" } - ] + ], + "code_prompt_style": { + "style_name": "DEEPSEEK_CODER", + "fim_spec": { + "style": "PMS", + "prefix": "<|fim▁begin|>", + "middle": "<|fim▁hole|>", + "suffix": "<|fim▁end|>" + }, + "repo_level_spec": { + "file_type": "filename", + "file_separator": "" + } + } }, { "version": 1, @@ -4066,7 +4114,7 @@ "chat", "vision" ], - "model_description":"OmniLMM is a family of open-source large multimodal models (LMMs) adept at vision & language modeling.", + "model_description":"mniLMM is a family of open-source large multimodal models (LMMs) adept at vision & language modeling.", "model_specs":[ { "model_format":"pytorch", diff --git a/xinference/model/llm/tests/test_llm_family.py b/xinference/model/llm/tests/test_llm_family.py index 252491282c..6a75b1bec4 100644 --- a/xinference/model/llm/tests/test_llm_family.py +++ b/xinference/model/llm/tests/test_llm_family.py @@ -159,7 +159,7 @@ def test_serialize_llm_family_v1(): prompt_style=prompt_style, ) - expected = """{"version": 1, "context_length": 2048, "model_name": "TestModel", "model_lang": ["en"], "model_ability": ["embed", "generate"], "model_description": null, "model_family": null, "model_specs": [{"model_format": "ggufv2", "model_hub": "huggingface", "model_size_in_billions": 2, "quantizations": ["q4_0", "q4_1"], "quantization_parts": {"q4_2": ["a", "b"]}, "model_id": "example/TestModel", "model_revision": "123", "model_file_name_template": "TestModel.{quantization}.bin", "model_file_name_split_template": "TestModel.{quantization}.bin.{part}", "model_uri": null}, {"model_format": "pytorch", "model_hub": "huggingface", "model_size_in_billions": 3, "quantizations": ["int8", "int4", "none"], "model_id": "example/TestModel", "model_revision": "456", "model_uri": null}], "prompt_style": {"style_name": "ADD_COLON_SINGLE", "system_prompt": "A chat between a curious human and an artificial intelligence assistant. 
The assistant gives helpful, detailed, and polite answers to the human's questions.", "roles": ["user", "assistant"], "intra_message_sep": "\\n### ", "inter_message_sep": "\\n### ", "stop": null, "stop_token_ids": null}}""" + expected = """{"version": 1, "context_length": 2048, "model_name": "TestModel", "model_lang": ["en"], "model_ability": ["embed", "generate"], "model_description": null, "model_family": null, "model_specs": [{"model_format": "ggufv2", "model_hub": "huggingface", "model_size_in_billions": 2, "quantizations": ["q4_0", "q4_1"], "quantization_parts": {"q4_2": ["a", "b"]}, "model_id": "example/TestModel", "model_revision": "123", "model_file_name_template": "TestModel.{quantization}.bin", "model_file_name_split_template": "TestModel.{quantization}.bin.{part}", "model_uri": null}, {"model_format": "pytorch", "model_hub": "huggingface", "model_size_in_billions": 3, "quantizations": ["int8", "int4", "none"], "model_id": "example/TestModel", "model_revision": "456", "model_uri": null}], "prompt_style": {"style_name": "ADD_COLON_SINGLE", "system_prompt": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", "roles": ["user", "assistant"], "intra_message_sep": "\\n### ", "inter_message_sep": "\\n### ", "stop": null, "stop_token_ids": null}, "code_prompt_style": null}""" assert json.loads(llm_family.json()) == json.loads(expected) llm_family_context_length = LLMFamilyV1( diff --git a/xinference/model/llm/tests/test_utils.py b/xinference/model/llm/tests/test_utils.py index 42125a0048..786af4e413 100644 --- a/xinference/model/llm/tests/test_utils.py +++ b/xinference/model/llm/tests/test_utils.py @@ -11,10 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
+import pytest from ....types import ChatCompletionMessage -from ..llm_family import PromptStyleV1 -from ..utils import ChatModelMixin +from ..llm_family import CodePromptStyleV1, FIMSpecV1, PromptStyleV1 +from ..llm_family import RepoLevelCodeCompletionSpecV1 as RepoLevelSpecV1 +from ..utils import ChatModelMixin, CodeModelMixin def test_prompt_style_add_colon_single(): @@ -330,3 +332,490 @@ def test_is_valid_model_name(): assert not is_valid_model_name("foo/bar") assert not is_valid_model_name(" ") assert not is_valid_model_name("") + + +def test_path_to_name(): + path = "/home/test/works/project/main.py" + assert "main.py" == CodeModelMixin._path_to_name(path) + + path = "/main.py" + assert "main.py" == CodeModelMixin._path_to_name(path) + + path = "main.py" + assert "main.py" == CodeModelMixin._path_to_name(path) + + path = ".main.py" + assert ".main.py" == CodeModelMixin._path_to_name(path) + + path = r"\main.py" + assert "main.py" == CodeModelMixin._path_to_name(path) + + path = r"C:\works\main.py" + assert "main.py" == CodeModelMixin._path_to_name(path) + + +def test_code_prompt_style_starcoder(): + code_prompt_style = CodePromptStyleV1( + style_name="STARCODER", + fim_spec=FIMSpecV1( + style="PSM", + prefix="", + middle="", + suffix="", + ), + ) + prompt = "def print_hello_world():" + expected = prompt + assert expected == CodeModelMixin._get_code_prompt( + "completion", prompt, code_prompt_style + ) + + prompt = "def print_hello_world():\n " + suffix = "\n print('Hello world!')" + expected = "def print_hello_world():\n \n print('Hello world!')" + assert expected == CodeModelMixin._get_code_prompt( + "infill", prompt, code_prompt_style, None, suffix + ) + + suffix = None + with pytest.raises(ValueError) as exc_info: + CodeModelMixin._get_code_prompt( + "infill", prompt, code_prompt_style, None, suffix + ) + assert exc_info.value == ValueError("suffix is required in infill mode") + + with pytest.raises(ValueError) as exc_info: + CodeModelMixin._get_code_prompt("test", prompt, code_prompt_style) + assert exc_info.value == ValueError( + "Unsupported generate mode: test, only 'PSM' and 'PMS' are supported now" + ) + + +def test_code_prompt_style_deepseek_coder(): + code_prompt_style = CodePromptStyleV1( + style_name="DEEPSEEK_CODER", + fim_spec=FIMSpecV1( + style="PMS", + prefix="<|fim▁begin|>", + middle="<|fim▁hole|>", + suffix="<|fim▁end|>", + ), + repo_level_spec=RepoLevelSpecV1( + file_type="filename", file_separator="" + ), + ) + + prompt = "#write a quick sort algorithm" + expected = prompt + + assert expected == CodeModelMixin._get_code_prompt( + "completion", prompt, code_prompt_style + ) + + prompt = """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +""" + suffix = """ + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""" + + expected = """<|fim▁begin|>def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +<|fim▁hole|> + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)<|fim▁end|>""" + + assert expected == CodeModelMixin._get_code_prompt( + "infill", prompt, code_prompt_style, None, suffix + ) + + files = { + "utils.py": """import torch +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import accuracy_score + +def load_data(): + 
iris = datasets.load_iris() + X = iris.data + y = iris.target + + # Standardize the data + scaler = StandardScaler() + X = scaler.fit_transform(X) + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + + # Convert numpy data to PyTorch tensors + X_train = torch.tensor(X_train, dtype=torch.float32) + X_test = torch.tensor(X_test, dtype=torch.float32) + y_train = torch.tensor(y_train, dtype=torch.int64) + y_test = torch.tensor(y_test, dtype=torch.int64) + + return X_train, X_test, y_train, y_test + +def evaluate_predictions(y_test, y_pred): + return accuracy_score(y_test, y_pred)""", + "model.py": """import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +class IrisClassifier(nn.Module): + def __init__(self): + super(IrisClassifier, self).__init__() + self.fc = nn.Sequential( + nn.Linear(4, 16), + nn.ReLU(), + nn.Linear(16, 3) + ) + + def forward(self, x): + return self.fc(x) + + def train_model(self, X_train, y_train, epochs, lr, batch_size): + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(self.parameters(), lr=lr) + + # Create DataLoader for batches + dataset = TensorDataset(X_train, y_train) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + for epoch in range(epochs): + for batch_X, batch_y in dataloader: + optimizer.zero_grad() + outputs = self(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + def predict(self, X_test): + with torch.no_grad(): + outputs = self(X_test) + _, predicted = outputs.max(1) + return predicted.numpy()""", + } + + prompt = """from utils import load_data, evaluate_predictions +from model import IrisClassifier as Classifier + +def main(): + # Model training and evaluation +""" + file_path = "/home/test/works/proj01/main.py" + + expected = """# utils.py +import torch +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import accuracy_score + +def load_data(): + iris = datasets.load_iris() + X = iris.data + y = iris.target + + # Standardize the data + scaler = StandardScaler() + X = scaler.fit_transform(X) + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + + # Convert numpy data to PyTorch tensors + X_train = torch.tensor(X_train, dtype=torch.float32) + X_test = torch.tensor(X_test, dtype=torch.float32) + y_train = torch.tensor(y_train, dtype=torch.int64) + y_test = torch.tensor(y_test, dtype=torch.int64) + + return X_train, X_test, y_train, y_test + +def evaluate_predictions(y_test, y_pred): + return accuracy_score(y_test, y_pred) +# model.py +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +class IrisClassifier(nn.Module): + def __init__(self): + super(IrisClassifier, self).__init__() + self.fc = nn.Sequential( + nn.Linear(4, 16), + nn.ReLU(), + nn.Linear(16, 3) + ) + + def forward(self, x): + return self.fc(x) + + def train_model(self, X_train, y_train, epochs, lr, batch_size): + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(self.parameters(), lr=lr) + + # Create DataLoader for batches + dataset = TensorDataset(X_train, y_train) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + for epoch in range(epochs): + for batch_X, batch_y in dataloader: + optimizer.zero_grad() + outputs = self(batch_X) + loss = criterion(outputs, 
batch_y) + loss.backward() + optimizer.step() + + def predict(self, X_test): + with torch.no_grad(): + outputs = self(X_test) + _, predicted = outputs.max(1) + return predicted.numpy() +# main.py +from utils import load_data, evaluate_predictions +from model import IrisClassifier as Classifier + +def main(): + # Model training and evaluation +""" + assert expected == CodeModelMixin._get_code_prompt( + "completion", prompt, code_prompt_style, file_path, None, None, files + ) + + +def test_code_prompt_style_codeqwen(): + code_prompt_style = CodePromptStyleV1( + style_name="CODEQWEN", + fim_spec=FIMSpecV1( + style="PSM", + prefix="", + middle="", + suffix="", + ), + repo_level_spec=RepoLevelSpecV1( + repo_name="", file_type="filepath", file_separator="" + ), + ) + + prompt = "#write a quick sort algorithm" + expected = prompt + + assert expected == CodeModelMixin._get_code_prompt( + "completion", prompt, code_prompt_style + ) + + prompt = """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] +""" + suffix = """ + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""" + + expected = """def quick_sort(arr): + if len(arr) <= 1: + return arr + pivot = arr[0] + left = [] + right = [] + + if arr[i] < pivot: + left.append(arr[i]) + else: + right.append(arr[i]) + return quick_sort(left) + [pivot] + quick_sort(right)""" + + assert expected == CodeModelMixin._get_code_prompt( + "infill", prompt, code_prompt_style, None, suffix + ) + + files = { + "/home/test/works/proj01/utils.py": """import torch +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import accuracy_score + +def load_data(): + iris = datasets.load_iris() + X = iris.data + y = iris.target + + # Standardize the data + scaler = StandardScaler() + X = scaler.fit_transform(X) + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + + # Convert numpy data to PyTorch tensors + X_train = torch.tensor(X_train, dtype=torch.float32) + X_test = torch.tensor(X_test, dtype=torch.float32) + y_train = torch.tensor(y_train, dtype=torch.int64) + y_test = torch.tensor(y_test, dtype=torch.int64) + + return X_train, X_test, y_train, y_test + +def evaluate_predictions(y_test, y_pred): + return accuracy_score(y_test, y_pred)""", + "/home/test/works/proj01/model.py": """import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +class IrisClassifier(nn.Module): + def __init__(self): + super(IrisClassifier, self).__init__() + self.fc = nn.Sequential( + nn.Linear(4, 16), + nn.ReLU(), + nn.Linear(16, 3) + ) + + def forward(self, x): + return self.fc(x) + + def train_model(self, X_train, y_train, epochs, lr, batch_size): + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(self.parameters(), lr=lr) + + # Create DataLoader for batches + dataset = TensorDataset(X_train, y_train) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + for epoch in range(epochs): + for batch_X, batch_y in dataloader: + optimizer.zero_grad() + outputs = self(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + def predict(self, X_test): + with torch.no_grad(): + outputs = self(X_test) + _, predicted = outputs.max(1) + return predicted.numpy()""", + } + + prompt = """from utils import load_data, 
evaluate_predictions +from model import IrisClassifier as Classifier + +def main(): + # Model training and evaluation +""" + file_path = "/home/test/works/proj01/main.py" + + expected = """project01 +/home/test/works/proj01/utils.py +import torch +from sklearn import datasets +from sklearn.model_selection import train_test_split +from sklearn.preprocessing import StandardScaler +from sklearn.metrics import accuracy_score + +def load_data(): + iris = datasets.load_iris() + X = iris.data + y = iris.target + + # Standardize the data + scaler = StandardScaler() + X = scaler.fit_transform(X) + + X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42) + + # Convert numpy data to PyTorch tensors + X_train = torch.tensor(X_train, dtype=torch.float32) + X_test = torch.tensor(X_test, dtype=torch.float32) + y_train = torch.tensor(y_train, dtype=torch.int64) + y_test = torch.tensor(y_test, dtype=torch.int64) + + return X_train, X_test, y_train, y_test + +def evaluate_predictions(y_test, y_pred): + return accuracy_score(y_test, y_pred) +/home/test/works/proj01/model.py +import torch +import torch.nn as nn +import torch.optim as optim +from torch.utils.data import DataLoader, TensorDataset + +class IrisClassifier(nn.Module): + def __init__(self): + super(IrisClassifier, self).__init__() + self.fc = nn.Sequential( + nn.Linear(4, 16), + nn.ReLU(), + nn.Linear(16, 3) + ) + + def forward(self, x): + return self.fc(x) + + def train_model(self, X_train, y_train, epochs, lr, batch_size): + criterion = nn.CrossEntropyLoss() + optimizer = optim.Adam(self.parameters(), lr=lr) + + # Create DataLoader for batches + dataset = TensorDataset(X_train, y_train) + dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True) + + for epoch in range(epochs): + for batch_X, batch_y in dataloader: + optimizer.zero_grad() + outputs = self(batch_X) + loss = criterion(outputs, batch_y) + loss.backward() + optimizer.step() + + def predict(self, X_test): + with torch.no_grad(): + outputs = self(X_test) + _, predicted = outputs.max(1) + return predicted.numpy() +/home/test/works/proj01/main.py +from utils import load_data, evaluate_predictions +from model import IrisClassifier as Classifier + +def main(): + # Model training and evaluation +""" + assert expected == CodeModelMixin._get_code_prompt( + "completion", prompt, code_prompt_style, file_path, None, "project01", files + ) + + +def test_code_prompt_style_without_fim(): + code_prompt_style = CodePromptStyleV1( + style_name="NO_FIM_CODER", + ) + prompt = "def print_hello_world():\n " + suffix = "\n print('Hello world!')" + with pytest.raises(ValueError) as exc_info: + CodeModelMixin._get_code_prompt( + "infill", prompt, code_prompt_style, None, suffix + ) + assert exc_info.value == ValueError( + "This model is not support infill mode generate" + ) diff --git a/xinference/model/llm/transformers/core.py b/xinference/model/llm/transformers/core.py index 7b427c4918..e71bc78ea1 100644 --- a/xinference/model/llm/transformers/core.py +++ b/xinference/model/llm/transformers/core.py @@ -16,7 +16,7 @@ import logging import os from functools import lru_cache -from typing import Iterable, Iterator, List, Optional, Tuple, Union +from typing import Iterable, Iterator, List, Mapping, Optional, Tuple, Union import torch @@ -30,6 +30,7 @@ ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, + CodeGenerateMode, Completion, CompletionChoice, CompletionChunk, @@ -41,7 +42,7 @@ from ...utils import select_device from ..core import LLM from 
..llm_family import LLMFamilyV1, LLMSpecV1 -from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin +from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin, CodeModelMixin from .utils import get_context_length, get_max_src_len, pad_prefill_tokens logger = logging.getLogger(__name__) @@ -809,3 +810,62 @@ def handle_batch_inference_results(self, req_list: List[InferenceRequest]): req.completion = results else: req.completion[0] = self._to_chat_completion(req.completion[0]) + + +class PytorchCodeModel(PytorchModel, CodeModelMixin): + def __init__( + self, + model_uid: str, + model_family: "LLMFamilyV1", + model_spec: "LLMSpecV1", + quantization: str, + model_path: str, + pytorch_model_config: Optional[PytorchModelConfig] = None, + peft_model: Optional[List[LoRA]] = None, + ): + super().__init__( + model_uid, + model_family, + model_spec, + quantization, + model_path, + pytorch_model_config, + peft_model, + ) + + @classmethod + def match( + cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str + ) -> bool: + if llm_spec.model_format not in ["pytorch", "gptq", "awq"]: + return False + model_family = llm_family.model_family or llm_family.model_name + if model_family in NON_DEFAULT_MODEL_LIST: + return False + if "code" not in llm_family.model_ability: + return False + return True + + def code_generate( + self, + mode: CodeGenerateMode, + prompt: str, + file_path: Optional[str], + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], + generate_config: Optional[PytorchGenerateConfig] = None, + ) -> Union[Completion, Iterator[CompletionChunk]]: + code_prompt = self.get_code_prompt( + mode, + prompt, + file_path, + suffix, + repo_name, + files, + )["prompt"] + + if generate_config is not None and generate_config.get("stream", False): + generate_config["stream"] = False + + return self.generate(code_prompt, generate_config) diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py index 7f203d0c21..03259feb84 100644 --- a/xinference/model/llm/utils.py +++ b/xinference/model/llm/utils.py @@ -19,7 +19,7 @@ import time import uuid from io import BytesIO -from typing import AsyncGenerator, Dict, Iterator, List, Optional, Tuple, cast +from typing import AsyncGenerator, Dict, Iterator, List, Mapping, Optional, Tuple, cast import requests from PIL import Image @@ -29,10 +29,14 @@ ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, + CodeGenerateMode, Completion, CompletionChunk, ) +from .core import LLM +from .lang_utils import get_file_separator from .llm_family import ( + CodePromptStyleV1, LlamaCppLLMSpecV1, LLMFamilyV1, LLMSpecV1, @@ -849,6 +853,131 @@ def get_full_prompt(cls, model_family, prompt, system_prompt, chat_history, tool return full_prompt +class CodeModelMixin: + def get_code_prompt( + self, + mode: CodeGenerateMode, + prompt: str, + file_path: Optional[str] = None, + suffix: Optional[str] = None, + repo_name: Optional[str] = None, + files: Optional[Mapping[str, str]] = None, + ): + code_prompt_style = cast(LLM, self).model_family.code_prompt_style + return { + "prompt": CodeModelMixin._get_code_prompt( + mode, prompt, code_prompt_style, file_path, suffix, repo_name, files + ) + } + + @staticmethod + def _get_code_prompt( + mode: CodeGenerateMode, + prompt: str, + code_prompt_style: Optional["CodePromptStyleV1"], + file_path: Optional[str] = None, + suffix: Optional[str] = None, + repo_name: Optional[str] = None, + files: Optional[Mapping[str, str]] = None, + ) -> str: + if code_prompt_style is None: + raise 
ValueError( + "Code prompt style is not provided; the model spec is wrong." + ) + + if mode == "completion": + if suffix is not None: + logger.warning( + "Suffix is only used in infill mode; ignoring it" + ) + + spec = code_prompt_style.repo_level_spec + + if files is None or len(files) == 0: + if file_path is None or len(file_path.strip()) == 0: + return prompt + else: + if spec is None: + logger.warning( + "Repository-level spec is not defined but file_path was provided; ignoring it" + ) + return prompt + else: + repo_file = ( + file_path + if spec.file_type == "filepath" + else CodeModelMixin._path_to_name(file_path) + ) + return "\n".join([get_file_separator(spec, repo_file), prompt]) + + if spec is None: + logger.warning( + "The model does not support repository-level code completion; 'repo_name' and 'files' are ignored" + ) + return prompt + + chunks = [] + if spec.repo_name is None: + if repo_name is not None: + logger.warning( + "'repo_name' was provided but is not used by this model; ignoring it" + ) + else: + if repo_name is None: + raise ValueError( + "'repo_name' is required for repository-level code completion with this model" + ) + chunks.append(f"{spec.repo_name}{repo_name}") + + for filepath, content in files.items(): + repo_file = ( + filepath + if spec.file_type == "filepath" + else CodeModelMixin._path_to_name(filepath) + ) + chunks.append(get_file_separator(spec, repo_file)) + chunks.append(content) + + if file_path is not None and len(file_path.strip()) > 0: + repo_file = ( + file_path + if spec.file_type == "filepath" + else CodeModelMixin._path_to_name(file_path) + ) + chunks.append(get_file_separator(spec, repo_file)) + chunks.append(prompt) + + return "\n".join(chunks) + + elif mode == "infill": + spec = code_prompt_style.fim_spec + if spec is None: + raise ValueError("This model is not support infill mode generate") + + if suffix is None: + raise ValueError("suffix is required in infill mode") + + if files is not None and len(files) > 0: + logger.warning( + "'files' is only used for repository-level code completion; ignoring it" + ) + + if spec.style == "PSM": + return f"{spec.prefix}{prompt}{spec.suffix}{suffix}{spec.middle}" + else: + return f"{spec.prefix}{prompt}{spec.middle}{suffix}{spec.suffix}" + + else: + raise ValueError( + f"Unsupported generate mode: {mode}, only 'completion' and 'infill' are supported now" + ) + + @staticmethod + def _path_to_name(filepath: str) -> str: + filepath = filepath.replace("\\", "/") + return os.path.split(filepath)[1] + + def get_file_location( llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str ) -> Tuple[str, bool]: diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py index 4b009aa646..fe2ac3e4e5 100644 --- a/xinference/model/llm/vllm/core.py +++ b/xinference/model/llm/vllm/core.py @@ -26,6 +26,7 @@ Dict, Iterable, List, + Mapping, Optional, TypedDict, Union, @@ -35,6 +36,7 @@ ChatCompletion, ChatCompletionChunk, ChatCompletionMessage, + CodeGenerateMode, Completion, CompletionChoice, CompletionChunk, @@ -45,7 +47,7 @@ ) from .. 
import LLM, LLMFamilyV1, LLMSpecV1 from ..llm_family import CustomLLMFamilyV1 -from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin +from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin, CodeModelMixin logger = logging.getLogger(__name__) @@ -130,6 +132,10 @@ class VLLMGenerateConfig(TypedDict, total=False): "deepseek-chat", "deepseek-coder-instruct", ] +VLLM_SUPPORTED_CODE_MODELS = [ + "deepseek-coder", + "codeqwen1.5", +] if VLLM_INSTALLED and vllm.__version__ >= "0.3.0": VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat") VLLM_SUPPORTED_MODELS.append("codeqwen1.5") @@ -743,3 +749,59 @@ async def async_chat( c = await self.async_generate(inputs, generate_config) assert not isinstance(c, AsyncGenerator) return self._to_chat_completion(c) + + +class VLLMCodeModel(VLLMModel, CodeModelMixin): + @classmethod + def match( + cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str + ) -> bool: + if llm_spec.model_format not in ["pytorch", "gptq", "awq"]: + return False + if llm_spec.model_format == "pytorch": + if quantization != "none" and quantization is not None: + return False + if llm_spec.model_format == "awq": + # Currently, only 4-bit weight quantization is supported for AWQ, but got 8 bits. + if "4" not in quantization: + return False + if llm_spec.model_format == "gptq": + if VLLM_INSTALLED and vllm.__version__ >= "0.3.3": + if not any(q in quantization for q in ("3", "4", "8")): + return False + else: + if "4" not in quantization: + return False + if isinstance(llm_family, CustomLLMFamilyV1): + if llm_family.model_family not in VLLM_SUPPORTED_CODE_MODELS: + return False + else: + if llm_family.model_name not in VLLM_SUPPORTED_CODE_MODELS: + return False + if "code" not in llm_family.model_ability: + return False + return VLLM_INSTALLED + + async def async_code_generate( + self, + mode: CodeGenerateMode, + prompt: str, + file_path: Optional[str], + suffix: Optional[str], + repo_name: Optional[str], + files: Optional[Mapping[str, str]], + generate_config: Optional[Dict] = None, + ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]: + code_prompt = self.get_code_prompt( + mode, + prompt, + file_path, + suffix, + repo_name, + files, + )["prompt"] + + if generate_config is not None and generate_config.get("stream", False): + generate_config["stream"] = False + + return await self.async_generate(code_prompt, generate_config) diff --git a/xinference/types.py b/xinference/types.py index 3f636d94c3..b941b5a80c 100644 --- a/xinference/types.py +++ b/xinference/types.py @@ -14,7 +14,7 @@ from typing import Any, Callable, Dict, ForwardRef, Iterable, List, Optional, Union -from typing_extensions import Literal, NotRequired, TypedDict +from typing_extensions import Literal, Mapping, NotRequired, TypedDict from ._compat import ( BaseModel, @@ -533,3 +533,13 @@ def from_dict(cls, data: Dict): image_lora_load_kwargs=data.get("image_lora_load_kwargs"), image_lora_fuse_kwargs=data.get("image_lora_fuse_kwargs"), ) + + +CodeGenerateMode = Literal["completion", "infill"] + + +class CreateCodeCompletion(CreateCompletion): + file_path: Optional[str] = none_field + mode: CodeGenerateMode = "completion" + repo_name: Optional[str] = none_field + files: Optional[Mapping[str, str]] = none_field diff --git a/xinference/web/ui/src/scenes/launch_model/modelCard.js b/xinference/web/ui/src/scenes/launch_model/modelCard.js index a4391f830a..790227f081 100644 --- a/xinference/web/ui/src/scenes/launch_model/modelCard.js +++ b/xinference/web/ui/src/scenes/launch_model/modelCard.js @@ 
-10,6 +10,7 @@ import { ExpandMore, Grade, HelpCenterOutlined, + LogoDevOutlined, RocketLaunchOutlined, StarBorder, UndoOutlined, @@ -892,6 +893,16 @@ const ModelCard = ({ chat model ) + } else if ( + modelData.model_ability && + modelData.model_ability.includes('code') + ) { + return ( +
+ + code model +
+ ) } else if ( modelData.model_ability && modelData.model_ability.includes('generate') diff --git a/xinference/web/ui/src/scenes/running_models/index.js b/xinference/web/ui/src/scenes/running_models/index.js index 4024c8f6d2..e3afd1b9e8 100644 --- a/xinference/web/ui/src/scenes/running_models/index.js +++ b/xinference/web/ui/src/scenes/running_models/index.js @@ -38,6 +38,63 @@ const RunningModels = () => { sessionStorage.setItem('runningModelType', newValue) } + function get_models(code_prompts) { + fetchWrapper + .get('/v1/models') + .then((response) => { + const newLlmData = [] + const newEmbeddingModelData = [] + const newImageModelData = [] + const newAudioModelData = [] + const newVideoModelData = [] + const newRerankModelData = [] + const newFlexibleModelData = [] + response.data.forEach((model) => { + let newValue = { + ...model, + id: model.id, + url: model.id, + } + if (newValue.model_type === 'LLM') { + if (model.model_name in code_prompts) { + newValue['infill_supported'] = + 'fim_spec' in code_prompts[model.model_name] + newValue['repo_level_supported'] = + 'repo_level_spec' in code_prompts[model.model_name] + } + newLlmData.push(newValue) + } else if (newValue.model_type === 'embedding') { + newEmbeddingModelData.push(newValue) + } else if (newValue.model_type === 'audio') { + newAudioModelData.push(newValue) + } else if (newValue.model_type === 'video') { + newVideoModelData.push(newValue) + } else if (newValue.model_type === 'image') { + newImageModelData.push(newValue) + } else if (newValue.model_type === 'rerank') { + newRerankModelData.push(newValue) + } else if (newValue.model_type === 'flexible') { + newFlexibleModelData.push(newValue) + } + }) + setLlmData(newLlmData) + setEmbeddingModelData(newEmbeddingModelData) + setAudioModelData(newAudioModelData) + setVideoModelData(newVideoModelData) + setImageModelData(newImageModelData) + setRerankModelData(newRerankModelData) + setFlexibleModelData(newFlexibleModelData) + setIsUpdatingModel(false) + }) + .catch((error) => { + console.error('Error:', error) + setIsUpdatingModel(false) + if (error.response.status !== 403 && error.response.status !== 401) { + setErrorMsg(error.message) + } + }) + } + const update = (isCallingApi) => { if ( sessionStorage.getItem('auth') === 'true' && @@ -71,45 +128,9 @@ const RunningModels = () => { setIsUpdatingModel(true) fetchWrapper - .get('/v1/models') - .then((response) => { - const newLlmData = [] - const newEmbeddingModelData = [] - const newImageModelData = [] - const newAudioModelData = [] - const newVideoModelData = [] - const newRerankModelData = [] - const newFlexibleModelData = [] - response.data.forEach((model) => { - let newValue = { - ...model, - id: model.id, - url: model.id, - } - if (newValue.model_type === 'LLM') { - newLlmData.push(newValue) - } else if (newValue.model_type === 'embedding') { - newEmbeddingModelData.push(newValue) - } else if (newValue.model_type === 'audio') { - newAudioModelData.push(newValue) - } else if (newValue.model_type === 'video') { - newVideoModelData.push(newValue) - } else if (newValue.model_type === 'image') { - newImageModelData.push(newValue) - } else if (newValue.model_type === 'rerank') { - newRerankModelData.push(newValue) - } else if (newValue.model_type === 'flexible') { - newFlexibleModelData.push(newValue) - } - }) - setLlmData(newLlmData) - setEmbeddingModelData(newEmbeddingModelData) - setAudioModelData(newAudioModelData) - setVideoModelData(newVideoModelData) - setImageModelData(newImageModelData) - 
setRerankModelData(newRerankModelData) - setFlexibleModelData(newFlexibleModelData) - setIsUpdatingModel(false) + .get('/v1/models/code_prompts') + .then((code_prompts) => { + get_models(code_prompts) }) .catch((error) => { console.error('Error:', error) @@ -228,6 +249,8 @@ const RunningModels = () => { model_ability: row.model_ability, model_description: row.model_description, model_lang: row.model_lang, + infill_supported: row.infill_supported, + repo_level_supported: row.repo_level_supported, }), }) .then((response) => response.json())