diff --git a/setup.cfg b/setup.cfg
index 7b83211c84..91e2256adb 100644
--- a/setup.cfg
+++ b/setup.cfg
@@ -292,7 +292,7 @@ exclude =
[codespell]
ignore-words-list = hist,rcall,fpr,ser,nd,inout,ot,Ba,ba,asend,hart,coo,splitted,datas,fro
-skip = .idea,.git,./build,./docs/build,node_modules,static,generated,*.po,*.ts,*.json,*.c,*.cpp,*.cfg,thirdparty
+skip = .idea,.git,./build,./docs/build,node_modules,static,generated,*.po,*.ts,*.json,*.c,*.cpp,*.cfg,thirdparty,xinference/model/llm/lang_utils.py
[isort]
profile = black
diff --git a/setup.py b/setup.py
index fe7c02301a..56f19a1759 100644
--- a/setup.py
+++ b/setup.py
@@ -73,6 +73,7 @@ class CustomDevelop(ExtraCommandMixin, develop):
class CustomSDist(ExtraCommandMixin, sdist):
pass
+
class BuildWeb(Command):
"""build_web command"""
diff --git a/xinference/api/restful_api.py b/xinference/api/restful_api.py
index 9c553c4c5d..acb206d650 100644
--- a/xinference/api/restful_api.py
+++ b/xinference/api/restful_api.py
@@ -62,6 +62,7 @@
ChatCompletionMessage,
Completion,
CreateChatCompletion,
+ CreateCodeCompletion,
CreateCompletion,
ImageList,
PeftModelConfig,
@@ -158,6 +159,8 @@ class BuildGradioInterfaceRequest(BaseModel):
model_ability: List[str]
model_description: str
model_lang: List[str]
+ infill_supported: Optional[bool]
+ repo_level_supported: Optional[bool]
class BuildGradioImageInterfaceRequest(BaseModel):
@@ -258,6 +261,9 @@ async def internal_exception_handler(request: Request, exc: Exception):
self._router.add_api_route(
"/v1/models/prompts", self._get_builtin_prompts, methods=["GET"]
)
+ self._router.add_api_route(
+ "/v1/models/code_prompts", self._get_builtin_code_prompts, methods=["GET"]
+ )
self._router.add_api_route(
"/v1/models/families", self._get_builtin_families, methods=["GET"]
)
@@ -554,6 +560,29 @@ async def internal_exception_handler(request: Request, exc: Exception):
),
)
+ self._router.add_api_route(
+ "/v1/code/completions",
+ self.create_code_completion,
+ methods=["POST"],
+ response_model=Completion,
+ dependencies=(
+ [Security(self._auth_service, scopes=["models:read"])]
+ if self.is_authenticated()
+ else None
+ ),
+ )
+
+ self._router.add_api_route(
+ "/v1/code/prompt",
+ self.get_code_prompt,
+ methods=["POST"],
+ dependencies=(
+ [Security(self._auth_service, scopes=["models:read"])]
+ if self.is_authenticated()
+ else None
+ ),
+ )
+
# for custom models
self._router.add_api_route(
"/v1/model_registrations/{model_type}",
@@ -743,6 +772,18 @@ async def _get_builtin_prompts(self) -> JSONResponse:
logger.error(e, exc_info=True)
raise HTTPException(status_code=500, detail=str(e))
+ async def _get_builtin_code_prompts(self) -> JSONResponse:
+ """
+ For internal usage
+ :return:
+ """
+ try:
+ data = await (await self._get_supervisor_ref()).get_builtin_code_prompts()
+ return JSONResponse(content=data)
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ raise HTTPException(status_code=500, detail=str(e))
+
async def _get_builtin_families(self) -> JSONResponse:
"""
For internal usage
@@ -1003,6 +1044,8 @@ async def build_gradio_interface(
model_description=body.model_description,
model_lang=body.model_lang,
access_token=access_token,
+ infill_supported=body.infill_supported,
+ repo_level_supported=body.repo_level_supported,
).build()
gr.mount_gradio_app(self._app, interface, f"/{model_uid}")
except ValueError as ve:
@@ -1763,6 +1806,115 @@ async def stream_results():
self.handle_request_limit_error(e)
raise HTTPException(status_code=500, detail=str(e))
+ async def create_code_completion(self, request: Request) -> Response:
+ json_data = await request.json()
+
+ if "mode" in json_data and json_data["mode"] not in ("completion", "infill"):
+ raise HTTPException(
+ status_code=400,
+ detail="mode must be one of 'completion' or 'infill'",
+ )
+
+ if json_data.get("stream", False):
+ json_data["stream"] = False
+
+ body = CreateCodeCompletion.parse_obj(json_data)
+ exclude = {
+ "mode",
+ "prompt",
+ "file_path",
+ "suffix",
+ "repo_name",
+ "files",
+ "model",
+ "n",
+ "messages",
+ "logit_bias",
+ "logit_bias_type",
+ "user",
+ }
+
+ kwargs = body.dict(exclude_unset=True, exclude=exclude)
+
+ # TODO: Decide if this default value override is necessary #1061
+ if body.max_tokens is None:
+ kwargs["max_tokens"] = max_tokens_field.default
+
+ if body.logit_bias is not None:
+ raise HTTPException(status_code=501, detail="Not implemented")
+
+ model_uid = body.model
+
+ try:
+ model = await (await self._get_supervisor_ref()).get_model(model_uid)
+ except ValueError as ve:
+ logger.error(str(ve), exc_info=True)
+ await self._report_error_event(model_uid, str(ve))
+ raise HTTPException(status_code=400, detail=str(ve))
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ await self._report_error_event(model_uid, str(e))
+ raise HTTPException(status_code=500, detail=str(e))
+
+ assert not body.stream
+
+ try:
+ data = await model.code_generate(
+ body.mode,
+ body.prompt,
+ body.file_path,
+ body.suffix,
+ body.repo_name,
+ body.files,
+ kwargs,
+ )
+ return Response(content=data, media_type="application/json")
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
+ raise HTTPException(status_code=500, detail=str(e))
+
+ async def get_code_prompt(self, request: Request) -> Response:
+ json_data = await request.json()
+
+ if "mode" in json_data and json_data["mode"] not in ("completion", "infill"):
+ raise HTTPException(
+ status_code=400,
+ detail="mode must be one of 'completion' or 'infill'",
+ )
+
+ body = CreateCodeCompletion.parse_obj(json_data)
+
+ model_uid = body.model
+
+ try:
+ model = await (await self._get_supervisor_ref()).get_model(model_uid)
+ except ValueError as ve:
+ logger.error(str(ve), exc_info=True)
+ await self._report_error_event(model_uid, str(ve))
+ raise HTTPException(status_code=400, detail=str(ve))
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ await self._report_error_event(model_uid, str(e))
+ raise HTTPException(status_code=500, detail=str(e))
+
+ try:
+ code_prompt = await model.get_code_prompt(
+ body.mode,
+ body.prompt,
+ body.file_path,
+ body.suffix,
+ body.repo_name,
+ body.files,
+ )
+ return Response(content=code_prompt, media_type="application/json")
+ except Exception as e:
+ logger.error(e, exc_info=True)
+ await self._report_error_event(model_uid, str(e))
+ self.handle_request_limit_error(e)
+ raise HTTPException(status_code=500, detail=str(e))
+
async def query_engines_by_model_name(self, model_name: str) -> JSONResponse:
try:
content = await (
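The restful_api.py changes above register two new POST routes, /v1/code/completions and /v1/code/prompt, plus a GET /v1/models/code_prompts route. Below is a minimal sketch of calling the POST routes over plain HTTP; the paths and request fields come from the diff, while the host, port and model uid are illustrative assumptions (a code-capable model such as deepseek-coder is assumed to be launched already).

# Minimal sketch: exercising the new code endpoints over plain HTTP.
# The host, port and model uid below are illustrative assumptions; a
# code-capable model (e.g. deepseek-coder) must already be launched.
import requests

BASE = "http://127.0.0.1:9997"

# Infill ("fill in the middle") completion via the new /v1/code/completions route.
resp = requests.post(
    f"{BASE}/v1/code/completions",
    json={
        "model": "deepseek-coder",
        "mode": "infill",
        "prompt": "def add(a, b):\n    ",   # prefix content
        "suffix": "\n    return result\n",  # suffix content
        "max_tokens": 64,
    },
)
resp.raise_for_status()
print(resp.json()["choices"][0]["text"])

# Ask the server which prompt it would build, without running generation.
resp = requests.post(
    f"{BASE}/v1/code/prompt",
    json={"model": "deepseek-coder", "prompt": "#write a quick sort algorithm"},
)
resp.raise_for_status()
print(resp.json()["prompt"])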
diff --git a/xinference/client/restful/restful_client.py b/xinference/client/restful/restful_client.py
index 679f65d296..2d6f2a24d7 100644
--- a/xinference/client/restful/restful_client.py
+++ b/xinference/client/restful/restful_client.py
@@ -25,6 +25,7 @@
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
+ CodeGenerateMode,
Completion,
CompletionChunk,
Embedding,
@@ -552,6 +553,144 @@ def chat(
return response_data
+class RESTfulCodeModelHandle(RESTfulGenerateModelHandle):
+ def code_generate(
+ self,
+ mode: "CodeGenerateMode",
+ prompt: str,
+ file_path: Optional[str] = None,
+ suffix: Optional[str] = None,
+ repo_name: Optional[str] = None,
+ files: Optional[typing.Mapping] = None,
+ generate_config: Optional[
+ Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]
+ ] = None,
+ ) -> "Completion":
+ """
+ Given a code generation hint, the model completes the code and returns a response via RESTful APIs.
+
+ Parameters
+ ----------
+ mode: Literal["completion", "infill"]
+ Code generation mode.
+ Completion covers both code fragment completion and repository level code completion.
+ Infill is fill-in-the-middle completion: the code is completed from the provided prefix and suffix content.
+ prompt: str
+ The user's input; it represents the prefix content in infill mode.
+ file_path: Optional[str]
+ The file path for prompt content file.
+ suffix: Optional[str]
+ The suffix content in infill mode.
+ repo_name: Optional[str]
+ The repository name in repository level code completion mode.
+ files: Optional[Mapping]
+ A mapping from file name/path to file content in repository level code completion mode.
+ generate_config: Optional[Union["LlamaCppGenerateConfig", "PytorchGenerateConfig"]]
+ Additional configuration for the chat generation.
+ "LlamaCppGenerateConfig" -> configuration for ggml model
+ "PytorchGenerateConfig" -> configuration for pytorch model
+
+ Returns
+ -------
+ "Completion"
+
+ Raises
+ ------
+ RuntimeError
+ Report the failure to generate the code from the server. Detailed information provided in error message.
+
+ """
+
+ url = f"{self._base_url}/v1/code/completions"
+
+ request_body: Dict[str, Any] = {
+ "model": self._model_uid,
+ "mode": mode,
+ "prompt": prompt,
+ "file_path": file_path,
+ "suffix": suffix,
+ "repo_name": repo_name,
+ "files": files,
+ }
+
+ if generate_config is not None:
+ for key, value in generate_config.items():
+ request_body[key] = value
+
+ response = requests.post(url, json=request_body, headers=self.auth_headers)
+
+ if response.status_code != 200:
+ raise RuntimeError(
+ f"Failed to generate code completion, detail: {_get_error_string(response)}"
+ )
+
+ response_data = response.json()
+ return response_data
+
+ def get_code_prompt(
+ self,
+ mode: "CodeGenerateMode",
+ prompt: str,
+ file_path: Optional[str] = None,
+ suffix: Optional[str] = None,
+ repo_name: Optional[str] = None,
+ files: Optional[typing.Mapping] = None,
+ ) -> str:
+ """
+ Given the code generation inputs, build the prompt that can be used to complete the code; the prompt is
+ returned via RESTful APIs.
+
+ Parameters
+ ----------
+ mode: Literal["completion", "infill"]
+ Code generation mode.
+ Completion covers both code fragment completion and repository level code completion.
+ Infill is fill-in-the-middle completion: the code is completed from the provided prefix and suffix content.
+ prompt: str
+ The user's input; it represents the prefix content in infill mode.
+ file_path: Optional[str]
+ The file path of the prompt content file.
+ suffix: Optional[str]
+ The suffix content in infill mode.
+ repo_name: Optional[str]
+ The repository name in repository level code completion mode.
+ files: Optional[Mapping]
+ A mapping from file name/path to file content in repository level code completion mode.
+
+ Returns
+ -------
+ {"prompt": "generated prompt"}
+
+ Raises
+ ------
+ RuntimeError
+ Report the failure to generate the code prompt from the server.
+ Detailed information provided in error message.
+
+ """
+
+ url = f"{self._base_url}/v1/code/prompt"
+
+ request_body: Dict[str, Any] = {
+ "model": self._model_uid,
+ "mode": mode,
+ "prompt": prompt,
+ "file_path": file_path,
+ "suffix": suffix,
+ "repo_name": repo_name,
+ "files": files,
+ }
+
+ response = requests.post(url, json=request_body, headers=self.auth_headers)
+
+ if response.status_code != 200:
+ raise RuntimeError(
+ f"Failed to generate code prompt generating, detail: {_get_error_string(response)}"
+ )
+
+ return response.json()
+
+
class RESTfulAudioModelHandle(RESTfulModelHandle):
def transcriptions(
self,
@@ -1032,6 +1171,10 @@ def get_model(self, model_uid: str) -> RESTfulModelHandle:
return RESTfulChatModelHandle(
model_uid, self.base_url, auth_headers=self._headers
)
+ elif "code" in desc["model_ability"]:
+ return RESTfulCodeModelHandle(
+ model_uid, self.base_url, auth_headers=self._headers
+ )
elif "generate" in desc["model_ability"]:
return RESTfulGenerateModelHandle(
model_uid, self.base_url, auth_headers=self._headers
@@ -1384,6 +1527,36 @@ def abort_request(self, model_uid: str, request_id: str):
response_data = response.json()
return response_data
+ def list_builtin_prompts(self):
+ """
+ Get the builtin prompts
+ :return: List[Dict[str, Any]]
+ The builtin prompts
+ """
+ url = f"{self.base_url}/v1/models/prompts"
+ response = requests.get(url, headers=self._headers)
+ if response.status_code != 200:
+ raise RuntimeError(
+ f"Failed to get builtin prompts, details: {_get_error_string(response)}"
+ )
+ response_data = response.json()
+ return response_data
+
+ def list_builtin_code_prompts(self):
+ """
+ Get the builtin code prompts
+ :return: List[Dict[str, Any]]
+ The builtin code prompts
+ """
+ url = f"{self.base_url}/v1/models/code_prompts"
+ response = requests.get(url, headers=self._headers)
+ if response.status_code != 200:
+ raise RuntimeError(
+ f"Failed to get builtin code prompts, details: {_get_error_string(response)}"
+ )
+ response_data = response.json()
+ return response_data
+
def get_workers_info(self):
url = f"{self.base_url}/v1/workers"
response = requests.get(url, headers=self._headers)
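On the client side, the diff above makes Client.get_model return the new RESTfulCodeModelHandle whenever "code" appears in the model's abilities. A minimal usage sketch follows; the endpoint and model uid are illustrative assumptions, and a code-capable model is assumed to be launched already.

# Minimal sketch of the new client handle. Endpoint and model uid are
# illustrative; a code-capable model must already be launched.
from xinference.client import RESTfulClient

client = RESTfulClient("http://127.0.0.1:9997")
# get_model returns a RESTfulCodeModelHandle when "code" is in model_ability.
model = client.get_model("deepseek-coder")

# Plain code completion.
completion = model.code_generate(
    mode="completion",
    prompt="#write a quick sort algorithm",
    generate_config={"max_tokens": 128, "temperature": 0.2},
)
print(completion["choices"][0]["text"])

# Infill: the server assembles the model-specific FIM prompt from prefix/suffix.
infill = model.code_generate(
    mode="infill",
    prompt="def quick_sort(arr):\n    ",
    suffix="\n    return arr\n",
    generate_config={"max_tokens": 128},
)
print(infill["choices"][0]["text"])

# Inspect the rendered prompt instead of generating, and list builtin code prompts.
print(model.get_code_prompt(mode="infill", prompt="def f():\n    ", suffix="\n"))
print(client.list_builtin_code_prompts())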
diff --git a/xinference/core/chat_interface.py b/xinference/core/chat_interface.py
index 8738141f90..3d98e344e0 100644
--- a/xinference/core/chat_interface.py
+++ b/xinference/core/chat_interface.py
@@ -20,11 +20,12 @@
import gradio as gr
import PIL.Image
-from gradio.components import Markdown, Textbox
+from gradio.components import Code, Dropdown, File, Markdown, Textbox
from gradio.layouts import Accordion, Column, Row
from ..client.restful.restful_client import (
RESTfulChatModelHandle,
+ RESTfulCodeModelHandle,
RESTfulGenerateModelHandle,
)
from ..types import ChatCompletionMessage
@@ -32,6 +33,34 @@
logger = logging.getLogger(__name__)
+def compare_history(current, hist):
+ if current["mode"] != hist["mode"]:
+ return False
+
+ if current["prompt"] != hist["prompt"]:
+ return False
+
+ if current["file_path"] != hist["file_path"]:
+ return False
+
+ if current["suffix"] != hist["suffix"]:
+ return False
+
+ if current["files"] != current["files"]:
+ return False
+
+ return True
+
+
+EMPTY = {
+ "mode": "Code Completion",
+ "prompt": "",
+ "file_path": "",
+ "suffix": "",
+ "files": None,
+}
+
+
class GradioInterface:
def __init__(
self,
@@ -47,6 +76,8 @@ def __init__(
model_description: str,
model_lang: List[str],
access_token: Optional[str],
+ infill_supported: Optional[bool],
+ repo_level_supported: Optional[bool],
):
self.endpoint = endpoint
self.model_uid = model_uid
@@ -62,12 +93,16 @@ def __init__(
self._access_token = (
access_token.replace("Bearer ", "") if access_token is not None else None
)
+ self.infill_supported = infill_supported
+ self.repo_level_supported = repo_level_supported
def build(self) -> "gr.Blocks":
if "vision" in self.model_ability:
interface = self.build_chat_vl_interface()
elif "chat" in self.model_ability:
interface = self.build_chat_interface()
+ elif "code" in self.model_ability:
+ interface = self.build_code_generate_interface()
else:
interface = self.build_generate_interface()
@@ -401,6 +436,414 @@ def update_button(text):
return chat_vl_interface
+ def build_code_generate_interface(
+ self,
+ ):
+ def undo(g_mode, text, g_file_path, g_suffix, g_files, hist):
+ current = {
+ "mode": g_mode,
+ "prompt": text,
+ "file_path": g_file_path,
+ "suffix": g_suffix,
+ "files": g_files,
+ }
+
+ if len(hist) == 0:
+ return {
+ generate_mode: "Code Completion",
+ prompt: "",
+ file_path: "",
+ suffix: "",
+ files: None,
+ history: [current],
+ }
+ if compare_history(current, hist[-1]):
+ hist = hist[:-1]
+
+ req = hist[-1] if len(hist) > 0 else EMPTY
+
+ return {
+ generate_mode: req["mode"],
+ prompt: req["prompt"],
+ file_path: req["file_path"],
+ suffix: req["suffix"],
+ files: g_files,
+ history: hist,
+ }
+
+ def clear(g_mode, text, g_file_path, g_suffix, g_files, hist):
+ current = {
+ "mode": g_mode,
+ "prompt": text,
+ "file_path": g_file_path,
+ "suffix": g_suffix,
+ "files": g_files,
+ }
+ if len(hist) == 0 or (
+ len(hist) > 0 and not compare_history(current, hist[-1])
+ ):
+ hist.append(current)
+ hist.append(EMPTY)
+ return {
+ generate_mode: "Code Completion",
+ prompt: "",
+ file_path: "",
+ suffix: "",
+ files: None,
+ history: hist,
+ }
+
+ def complete(
+ g_mode, text, g_file_path, g_suffix, g_files, hist, max_tokens, temperature
+ ):
+ from ..client import RESTfulClient
+
+ client = RESTfulClient(self.endpoint)
+ client._set_token(self._access_token)
+
+ model = client.get_model(self.model_uid)
+ assert isinstance(model, RESTfulCodeModelHandle)
+
+ repo_files = (
+ {k: open(k, mode="r", encoding="utf8").read() for k in g_files}
+ if g_files
+ else None
+ )
+
+ current = {
+ "mode": g_mode,
+ "prompt": text,
+ "file_path": g_file_path,
+ "suffix": g_suffix,
+ "files": g_files,
+ }
+ if len(hist) == 0 or (
+ len(hist) > 0 and not compare_history(current, hist[-1])
+ ):
+ hist.append(current)
+
+ response_content = text
+
+ if g_mode == "Code Completion":
+ if self.repo_level_supported:
+ resp = model.code_generate(
+ "completion",
+ prompt=text,
+ file_path=g_file_path,
+ files=repo_files,
+ generate_config={
+ "max_tokens": max_tokens,
+ "temperature": temperature,
+ },
+ )
+ else:
+ resp = model.code_generate(
+ mode="completion",
+ prompt=text,
+ file_path=g_file_path,
+ generate_config={
+ "max_tokens": max_tokens,
+ "temperature": temperature,
+ },
+ )
+ else:
+ resp = model.code_generate(
+ mode="infill",
+ prompt=text,
+ suffix=g_suffix,
+ generate_config={
+ "max_tokens": max_tokens,
+ "temperature": temperature,
+ },
+ )
+ assert isinstance(resp, dict)
+ choice = resp["choices"][0]
+
+ response_content += choice["text"]
+
+ current = {
+ "mode": g_mode,
+ "prompt": response_content,
+ "file_path": g_file_path,
+ "suffix": g_suffix,
+ "files": g_files,
+ }
+
+ hist.append(current)
+ return {
+ prompt: response_content,
+ history: hist,
+ }
+
+ def retry(
+ g_mode, text, g_file_path, g_suffix, g_files, hist, max_tokens, temperature
+ ):
+ from ..client import RESTfulClient
+
+ client = RESTfulClient(self.endpoint)
+ client._set_token(self._access_token)
+
+ model = client.get_model(self.model_uid)
+ assert isinstance(model, RESTfulCodeModelHandle)
+
+ current = {
+ "mode": g_mode,
+ "prompt": text,
+ "file_path": g_file_path,
+ "suffix": g_suffix,
+ "files": g_files,
+ }
+
+ if len(hist) == 0 or (
+ len(hist) > 0 and not compare_history(current, hist[-1])
+ ):
+ hist.append(current)
+
+ req = hist[-2] if len(hist) > 1 else EMPTY
+
+ response_content = req["prompt"]
+
+ repo_files = (
+ {k: open(k, mode="r", encoding="utf8").read() for k in req["files"]}
+ if req["files"]
+ else None
+ )
+
+ resp = model.code_generate(
+ mode="completion" if req["mode"] == "Code Completion" else "infill",
+ prompt=req["prompt"],
+ file_path=req["file_path"],
+ suffix=req["suffix"],
+ files=repo_files,
+ generate_config={
+ "max_tokens": max_tokens,
+ "temperature": temperature,
+ },
+ )
+ assert isinstance(resp, dict)
+ choice = resp["choices"][0]
+ response_content += choice["text"]
+
+ req["prompt"] = response_content
+
+ hist.append(req)
+ return {
+ generate_mode: req["mode"],
+ prompt: response_content,
+ file_path: req["file_path"],
+ suffix: req["suffix"],
+ files: req["files"],
+ history: hist,
+ }
+
+ def mode_change(generate_mode):
+ if generate_mode == "Code Completion":
+ return {
+ file_path: Textbox(
+ container=True,
+ show_label=True,
+ label="Prompt file path",
+ interactive=self.repo_level_supported,
+ ),
+ suffix: Code(
+ container=True,
+ show_label=True,
+ label="Suffix",
+ lines=21,
+ visible=False,
+ ),
+ files: File(
+ container=False,
+ show_label=False,
+ label="Files",
+ file_count="multiple",
+ visible=self.repo_level_supported,
+ ),
+ }
+ else:
+ return {
+ file_path: Textbox(
+ container=True,
+ show_label=True,
+ label="Prompt file path",
+ interactive=True,
+ visible=False,
+ ),
+ suffix: Code(
+ container=True,
+ show_label=True,
+ label="Suffix",
+ lines=21,
+ interactive=True,
+ visible=self.infill_supported,
+ ),
+ files: File(
+ container=False,
+ show_label=False,
+ label="Files",
+ file_count="multiple",
+ visible=False,
+ ),
+ }
+
+ with gr.Blocks(
+ title=f"🚀 Xinference Code Generate Bot : {self.model_name} 🚀",
+ css="""
+ .center{
+ display: flex;
+ justify-content: center;
+ align-items: center;
+ padding: 0px;
+ color: #9ea4b0 !important;
+ }
+ """,
+ analytics_enabled=False,
+ ) as code_generate_interface:
+ modes = (
+ ["Code Completion", "Code Infill Completion"]
+ if self.infill_supported
+ else ["Code Completion"]
+ )
+
+ history = gr.State([])
+
+ Markdown(
+ f"""
+ <h1 style='text-align: center; margin-bottom: 1rem'>🚀 Xinference Code Generate Bot : {self.model_name} 🚀</h1>
+ """
+ )
+ Markdown(
+ f"""
+ <div class="center">
+ Model ID: {self.model_uid}
+ </div>
+ <div class="center">
+ Model Size: {self.model_size_in_billions} Billion Parameters
+ </div>
+ <div class="center">
+ Model Format: {self.model_format}
+ </div>
+ <div class="center">
+ Model Quantization: {self.quantization}
+ </div>
+ <div class="center">
+ Support Infill Code Completion: {self.infill_supported}
+ </div>
+ <div class="center">
+ Support Repository Level Code Completion: {self.repo_level_supported}
+ </div>
+ """
+ )
+
+ with Column(variant="panel"):
+ generate_mode = Dropdown(
+ container=True,
+ show_label=True,
+ label="Code Generate Mode",
+ choices=modes,
+ value=modes[0],
+ )
+
+ prompt = Code(
+ container=True,
+ show_label=True,
+ label="Prompt",
+ lines=21,
+ interactive=True,
+ )
+
+ suffix = Code(
+ container=True,
+ show_label=True,
+ label="Suffix",
+ lines=21,
+ visible=False,
+ )
+
+ file_path = Textbox(
+ container=True,
+ show_label=True,
+ label="Prompt file path",
+ interactive=True,
+ )
+
+ files = File(
+ container=True,
+ show_label=True,
+ label="Repository Files",
+ file_count="multiple",
+ interactive=True,
+ visible=self.repo_level_supported,
+ )
+
+ with Row():
+ btn_generate = gr.Button("Generate", variant="primary")
+ with Row():
+ btn_undo = gr.Button("↩️ Undo")
+ btn_retry = gr.Button("🔄 Retry")
+ btn_clear = gr.Button("🗑️ Clear")
+ with Accordion("Additional Inputs", open=False):
+ length = gr.Slider(
+ minimum=1,
+ maximum=self.context_length,
+ value=1024,
+ step=1,
+ label="Max Tokens",
+ )
+ temperature = gr.Slider(
+ minimum=0, maximum=2, value=1, step=0.01, label="Temperature"
+ )
+
+ generate_mode.change(
+ fn=mode_change,
+ inputs=[generate_mode],
+ outputs=[file_path, suffix, files],
+ )
+
+ btn_generate.click(
+ fn=complete,
+ inputs=[
+ generate_mode,
+ prompt,
+ file_path,
+ suffix,
+ files,
+ history,
+ length,
+ temperature,
+ ],
+ outputs=[prompt, history],
+ )
+
+ btn_undo.click(
+ fn=undo,
+ inputs=[generate_mode, prompt, file_path, suffix, files, history],
+ outputs=[generate_mode, prompt, file_path, suffix, files, history],
+ )
+
+ btn_retry.click(
+ fn=retry,
+ inputs=[
+ generate_mode,
+ prompt,
+ file_path,
+ suffix,
+ files,
+ history,
+ length,
+ temperature,
+ ],
+ outputs=[generate_mode, prompt, file_path, suffix, files, history],
+ )
+
+ btn_clear.click(
+ fn=clear,
+ inputs=[generate_mode, prompt, file_path, suffix, files, history],
+ outputs=[generate_mode, prompt, file_path, suffix, files, history],
+ )
+
+ return code_generate_interface
+
def build_generate_interface(
self,
):
diff --git a/xinference/core/model.py b/xinference/core/model.py
index 602f712514..ad32528673 100644
--- a/xinference/core/model.py
+++ b/xinference/core/model.py
@@ -33,6 +33,7 @@
Generator,
Iterator,
List,
+ Mapping,
Optional,
Union,
)
@@ -41,6 +42,7 @@
import xoscar as xo
from ..constants import XINFERENCE_TRANSFORMERS_ENABLE_BATCHING
+from ..types import CodeGenerateMode
if TYPE_CHECKING:
from .worker import WorkerActor
@@ -565,6 +567,99 @@ async def abort_request(self, request_id: str) -> str:
return await self._scheduler_ref.abort_request(request_id)
return AbortRequestMessage.NO_OP.name
+ @log_async(logger=logger)
+ @request_limit
+ @xo.generator
+ async def code_generate(
+ self,
+ mode: CodeGenerateMode,
+ prompt: str,
+ file_path: Optional[str],
+ suffix: Optional[str],
+ repo_name: Optional[str],
+ files: Optional[Mapping[str, str]],
+ *args,
+ **kwargs,
+ ):
+ start_time = time.time()
+ response = None
+ try:
+ if hasattr(self._model, "code_generate"):
+ response = await self._call_wrapper_json(
+ self._model.code_generate,
+ mode,
+ prompt,
+ file_path,
+ suffix,
+ repo_name,
+ files,
+ *args,
+ **kwargs,
+ )
+ return response
+ if hasattr(self._model, "async_code_generate"):
+ response = await self._call_wrapper_json(
+ self._model.async_code_generate,
+ mode,
+ prompt,
+ file_path,
+ suffix,
+ repo_name,
+ files,
+ *args,
+ **kwargs,
+ )
+ return response
+ raise AttributeError(
+ f"Model {self._model.model_spec} is not for code generate."
+ )
+ finally:
+ # For the non stream result.
+ record = None
+ if isinstance(response, (Generator, AsyncGenerator)):
+ record = response
+ elif isinstance(response, bytes):
+ record = json.loads(response)
+ if record and isinstance(record, dict):
+ usage = record["usage"]
+ # Some backends may not have a valid usage, we just skip them.
+ completion_tokens = usage["completion_tokens"]
+ prompt_tokens = usage["prompt_tokens"]
+ await self._record_completion_metrics(
+ time.time() - start_time,
+ completion_tokens,
+ prompt_tokens,
+ )
+
+ @log_async(logger=logger)
+ @request_limit
+ @xo.generator
+ async def get_code_prompt(
+ self,
+ mode: CodeGenerateMode,
+ prompt: str,
+ file_path: Optional[str],
+ suffix: Optional[str],
+ repo_name: Optional[str],
+ files: Optional[Mapping[str, str]],
+ ):
+ from ..model.llm.utils import CodeModelMixin
+
+ if isinstance(self._model, CodeModelMixin):
+ return await self._call_wrapper_json(
+ self._model.get_code_prompt,
+ mode,
+ prompt,
+ file_path,
+ suffix,
+ repo_name,
+ files,
+ )
+ else:
+ raise ValueError(
+ f"Model {self._model.model_family.model_name} does not support code generating"
+ )
+
@log_async(logger=logger)
@request_limit
async def create_embedding(self, input: Union[str, List[str]], *args, **kwargs):
diff --git a/xinference/core/supervisor.py b/xinference/core/supervisor.py
index 2b6f7b9fc5..d2c71265cc 100644
--- a/xinference/core/supervisor.py
+++ b/xinference/core/supervisor.py
@@ -313,10 +313,20 @@ async def get_builtin_prompts() -> Dict[str, Any]:
data[k] = v.dict()
return data
+ @staticmethod
+ async def get_builtin_code_prompts() -> Dict[str, Any]:
+ from ..model.llm.llm_family import BUILTIN_LLM_CODE_PROMPT_STYLE
+
+ data = {}
+ for k, v in BUILTIN_LLM_CODE_PROMPT_STYLE.items():
+ data[k] = v.dict()
+ return data
+
@staticmethod
async def get_builtin_families() -> Dict[str, List[str]]:
from ..model.llm.llm_family import (
BUILTIN_LLM_MODEL_CHAT_FAMILIES,
+ BUILTIN_LLM_MODEL_CODE_FAMILIES,
BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES,
)
@@ -325,6 +335,7 @@ async def get_builtin_families() -> Dict[str, List[str]]:
"chat": list(BUILTIN_LLM_MODEL_CHAT_FAMILIES),
"generate": list(BUILTIN_LLM_MODEL_GENERATE_FAMILIES),
"tools": list(BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES),
+ "code": list(BUILTIN_LLM_MODEL_CODE_FAMILIES),
}
async def get_devices_count(self) -> int:
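The supervisor changes above surface the builtin code prompt styles and a new "code" family list through the existing discovery routes. A small sketch of querying them over HTTP; the host and port are illustrative, and the returned model names depend on the builtin registry.

# Sketch: discovery endpoints now expose code-model metadata.
# Host/port are illustrative; a running Xinference server is assumed.
import requests

BASE = "http://127.0.0.1:9997"

families = requests.get(f"{BASE}/v1/models/families").json()
print(families["code"])  # builtin model families with the "code" ability

code_prompts = requests.get(f"{BASE}/v1/models/code_prompts").json()
print(list(code_prompts))  # model names that ship a builtin code prompt style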
diff --git a/xinference/core/tests/test_restful_api.py b/xinference/core/tests/test_restful_api.py
index cd47b98cc5..7bb44ef164 100644
--- a/xinference/core/tests/test_restful_api.py
+++ b/xinference/core/tests/test_restful_api.py
@@ -1239,3 +1239,192 @@ def test_launch_model_by_version(setup):
# delete again
url = f"{endpoint}/v1/models/test_qwen15"
requests.delete(url)
+
+
+@pytest.mark.skipif(bool(os.environ.get("GITHUB_ACTIONS")), reason="Skip in GitHub Actions")
+def test_cluster_info(setup):
+ endpoint, _ = setup
+ url = f"{endpoint}/v1/cluster/info"
+
+ response = requests.get(url)
+ assert response.status_code == 200
+ result = response.json()
+ assert isinstance(result, list)
+ assert len(result) == 2
+ assert result[0]["node_type"] == "Supervisor"
+ assert result[0]["gpu_count"] == 0
+ assert result[0]["gpu_vram_total"] == 0
+ assert result[1]["node_type"] == "Worker"
+ assert result[1]["gpu_count"] == 0
+ assert result[1]["gpu_vram_total"] == 0
+
+
+def test_restful_api_for_code_prompt(setup):
+ model_name = "deepseek-coder"
+
+ endpoint, _ = setup
+ url = f"{endpoint}/v1/models"
+
+ # list
+ response = requests.get(url)
+ response_data = response.json()
+ assert len(response_data["data"]) == 0
+
+ # launch
+ payload = {
+ "model_uid": "deepseek-coder",
+ "model_name": model_name,
+ "model_type": "LLM",
+ "model_engine": "llama.cpp",
+ "model_size_in_billions": "1_3",
+ "quantization": "q4_k_m",
+ }
+
+ response = requests.post(url, json=payload)
+ response_data = response.json()
+ model_uid_res = response_data["model_uid"]
+ assert model_uid_res == "deepseek-coder"
+
+ response = requests.get(url)
+ response_data = response.json()
+ assert len(response_data["data"]) == 1
+
+ # test code prompt generation
+ url = f"{endpoint}/v1/code/prompt"
+ payload = {
+ "model": "deepseek-coder",
+ "prompt": "#write a quick sort algorithm",
+ }
+ response = requests.post(url, json=payload)
+ coding_res = response.json()
+
+ assert "prompt" in coding_res
+
+ assert "#write a quick sort algorithm" == coding_res["prompt"]
+
+ # test infill mode
+ payload = {
+ "model": "deepseek-coder",
+ "mode": "infill",
+ "prompt": """def quick_sort(arr):
+ if len(arr) <= 1:
+ return arr
+ pivot = arr[0]
+ left = []
+ right = []
+""",
+ "suffix": """
+ if arr[i] < pivot:
+ left.append(arr[i])
+ else:
+ right.append(arr[i])
+ return quick_sort(left) + [pivot] + quick_sort(right)""",
+ }
+ response = requests.post(url, json=payload)
+ coding_res = response.json()
+
+ assert "prompt" in coding_res
+ assert (
+ coding_res["prompt"]
+ == """<|fim▁begin|>def quick_sort(arr):
+ if len(arr) <= 1:
+ return arr
+ pivot = arr[0]
+ left = []
+ right = []
+<|fim▁hole|>
+ if arr[i] < pivot:
+ left.append(arr[i])
+ else:
+ right.append(arr[i])
+ return quick_sort(left) + [pivot] + quick_sort(right)<|fim▁end|>"""
+ )
+
+ # delete model
+ url = f"{endpoint}/v1/models/deepseek-coder"
+ response = requests.delete(url)
+ assert response.status_code == 200
+
+ response = requests.get(f"{endpoint}/v1/models")
+ response_data = response.json()
+ assert len(response_data["data"]) == 0
+
+
+def test_restful_api_for_code_completions(setup):
+ model_name = "deepseek-coder"
+
+ endpoint, _ = setup
+ url = f"{endpoint}/v1/models"
+
+ # list
+ response = requests.get(url)
+ response_data = response.json()
+ assert len(response_data["data"]) == 0
+
+ # launch
+ payload = {
+ "model_uid": "deepseek-coder",
+ "model_name": model_name,
+ "model_type": "LLM",
+ "model_engine": "llama.cpp",
+ "model_size_in_billions": "1_3",
+ "quantization": "q4_k_m",
+ }
+
+ response = requests.post(url, json=payload)
+ response_data = response.json()
+ model_uid_res = response_data["model_uid"]
+ assert model_uid_res == "deepseek-coder"
+
+ response = requests.get(url)
+ response_data = response.json()
+ assert len(response_data["data"]) == 1
+
+ # test code completion
+ url = f"{endpoint}/v1/code/completions"
+ payload = {
+ "model": "deepseek-coder",
+ "prompt": "#write a quick sort algorithm",
+ "max_tokens": 4096,
+ }
+ response = requests.post(url, json=payload)
+ coding_res = response.json()
+
+ assert len(coding_res["choices"]) == 1
+ assert "text" in coding_res["choices"][0]
+ assert coding_res["choices"][0]["finish_reason"] == "stop"
+
+ # test infill mode
+ payload = {
+ "model": "deepseek-coder",
+ "mode": "infill",
+ "prompt": """def quick_sort(arr):
+ if len(arr) <= 1:
+ return arr
+ pivot = arr[0]
+ left = []
+ right = []
+""",
+ "suffix": """
+ if arr[i] < pivot:
+ left.append(arr[i])
+ else:
+ right.append(arr[i])
+ return quick_sort(left) + [pivot] + quick_sort(right)""",
+ "max_tokens": 4096,
+ }
+ response = requests.post(url, json=payload)
+ coding_res = response.json()
+
+ assert len(coding_res["choices"]) == 1
+ assert "text" in coding_res["choices"][0]
+ assert coding_res["choices"][0]["finish_reason"] == "stop"
+
+ # delete model
+ url = f"{endpoint}/v1/models/deepseek-coder"
+ response = requests.delete(url)
+ assert response.status_code == 200
+
+ response = requests.get(f"{endpoint}/v1/models")
+ response_data = response.json()
+ assert len(response_data["data"]) == 0
diff --git a/xinference/model/llm/__init__.py b/xinference/model/llm/__init__.py
index 9909addebb..843302bd15 100644
--- a/xinference/model/llm/__init__.py
+++ b/xinference/model/llm/__init__.py
@@ -26,8 +26,10 @@
)
from .llm_family import (
BUILTIN_CSGHUB_LLM_FAMILIES,
+ BUILTIN_LLM_CODE_PROMPT_STYLE,
BUILTIN_LLM_FAMILIES,
BUILTIN_LLM_MODEL_CHAT_FAMILIES,
+ BUILTIN_LLM_MODEL_CODE_FAMILIES,
BUILTIN_LLM_MODEL_GENERATE_FAMILIES,
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES,
BUILTIN_LLM_PROMPT_STYLE,
@@ -40,13 +42,16 @@
SUPPORTED_ENGINES,
TRANSFORMERS_CLASSES,
VLLM_CLASSES,
+ CodePromptStyleV1,
CustomLLMFamilyV1,
+ FIMSpecV1,
LlamaCppLLMSpecV1,
LLMFamilyV1,
LLMSpecV1,
MLXLLMSpecV1,
PromptStyleV1,
PytorchLLMSpecV1,
+ RepoLevelCodeCompletionSpecV1,
get_cache_status,
get_user_defined_llm_families,
match_llm,
@@ -113,14 +118,14 @@ def generate_engine_config_by_model_family(model_family):
def _install():
- from .llama_cpp.core import LlamaCppChatModel, LlamaCppModel
+ from .llama_cpp.core import LlamaCppChatModel, LlamaCppCodeModel, LlamaCppModel
from .lmdeploy.core import LMDeployChatModel, LMDeployModel
from .mlx.core import MLXChatModel, MLXModel
from .sglang.core import SGLANGChatModel, SGLANGModel
from .transformers.chatglm import ChatglmPytorchChatModel
from .transformers.cogvlm2 import CogVLM2Model
from .transformers.cogvlm2_video import CogVLM2VideoModel
- from .transformers.core import PytorchChatModel, PytorchModel
+ from .transformers.core import PytorchChatModel, PytorchCodeModel, PytorchModel
from .transformers.deepseek_vl import DeepSeekVLChatModel
from .transformers.glm4v import Glm4VModel
from .transformers.intern_vl import InternVLChatModel
@@ -130,7 +135,7 @@ def _install():
from .transformers.minicpmv26 import MiniCPMV26Model
from .transformers.qwen_vl import QwenVLChatModel
from .transformers.yi_vl import YiVLChatModel
- from .vllm.core import VLLMChatModel, VLLMModel, VLLMVisionModel
+ from .vllm.core import VLLMChatModel, VLLMCodeModel, VLLMModel, VLLMVisionModel
try:
from .transformers.omnilmm import OmniLMMModel
@@ -144,11 +149,12 @@ def _install():
LLAMA_CLASSES.extend(
[
LlamaCppChatModel,
+ LlamaCppCodeModel,
LlamaCppModel,
]
)
SGLANG_CLASSES.extend([SGLANGModel, SGLANGChatModel])
- VLLM_CLASSES.extend([VLLMModel, VLLMChatModel, VLLMVisionModel])
+ VLLM_CLASSES.extend([VLLMModel, VLLMChatModel, VLLMCodeModel, VLLMVisionModel])
MLX_CLASSES.extend([MLXModel, MLXChatModel])
LMDEPLOY_CLASSES.extend([LMDeployModel, LMDeployChatModel])
TRANSFORMERS_CLASSES.extend(
@@ -157,6 +163,7 @@ def _install():
LlamaPytorchModel,
LlamaPytorchChatModel,
PytorchChatModel,
+ PytorchCodeModel,
Internlm2PytorchChatModel,
QwenVLChatModel,
YiVLChatModel,
@@ -203,6 +210,16 @@ def _install():
if "tools" in model_spec.model_ability:
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
+ if "code" in model_spec.model_ability and isinstance(
+ model_spec.code_prompt_style, CodePromptStyleV1
+ ):
+ BUILTIN_LLM_CODE_PROMPT_STYLE[
+ model_spec.model_name
+ ] = model_spec.code_prompt_style
+
+ if "code" in model_spec.model_ability:
+ BUILTIN_LLM_MODEL_CODE_FAMILIES.add(model_spec.model_name)
+
modelscope_json_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "llm_family_modelscope.json"
)
@@ -225,6 +242,8 @@ def _install():
BUILTIN_LLM_MODEL_GENERATE_FAMILIES.add(model_spec.model_name)
if "tools" in model_spec.model_ability:
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES.add(model_spec.model_name)
+ if "code" in model_spec.model_ability:
+ BUILTIN_LLM_MODEL_CODE_FAMILIES.add(model_spec.model_name)
csghub_json_path = os.path.join(
os.path.dirname(os.path.abspath(__file__)), "llm_family_csghub.json"
diff --git a/xinference/model/llm/lang_utils.py b/xinference/model/llm/lang_utils.py
new file mode 100644
index 0000000000..2007601bd9
--- /dev/null
+++ b/xinference/model/llm/lang_utils.py
@@ -0,0 +1,1131 @@
+# Copyright 2022-2023 XProbe Inc.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import logging
+import os.path
+
+from . import RepoLevelCodeCompletionSpecV1
+
+logger = logging.getLogger(__name__)
+
+# NOTICE: this code is adapted from aiXcoder
+
+# https://github.com/aixcoder-plugin/aiXcoder-7B/blob/main/hf_mini/utils.py
+
+LANGUAGE_WRAPPER = {
+ "c": "// ",
+ "c++": "// ",
+ "cpp": "// ",
+ "c#": "// ",
+ "csharp": "// ",
+ "c-sharp": "// ",
+ "css": "/* */",
+ "cuda": "// ",
+ "dart": "// ",
+ "lua": "// ",
+ "objectivec": "// ",
+ "objective-c": "// ",
+ "objective-c++": "// ",
+ "python": "# ",
+ "perl": "# ",
+ "prolog": "% ",
+ "swift": "// ",
+ "lisp": "; ",
+ "java": "// ",
+ "scala": "// ",
+ "tex": "% ",
+ "vue": "",
+ "markdown": "",
+ "html": "",
+ "php": "// ",
+ "js": "// ",
+ "javascript": "// ",
+ "typescript": "// ",
+ "go": "// ",
+ "shell": "# ",
+ "rust": "// ",
+ "sql": "-- ",
+ "kotlin": "// ",
+ "vb": "' ",
+ "ruby": "# ",
+ "pascal": "// ",
+ "r": "# ",
+ "fortran": "!",
+ "lean": "-- ",
+ "matlab": "% ",
+ "delphi": "{}",
+ "scheme": "; ",
+ "basic": "' ",
+ "assembly": "; ",
+ "groovy": "// ",
+ "abap": "* ",
+ "gdscript": "# ",
+ "haskell": "-- ",
+ "julia": "# ",
+ "elixir": "# ",
+ "excel": "' ",
+ "clojure": "; ",
+ "actionscript": "// ",
+ "solidity": "// ",
+ "powershell": "# ",
+ "erlang": "% ",
+ "cobol": "// ",
+ "alloy": "/* */",
+ "awk": "// ",
+ "thrift": "/* */",
+ "sparql": "# ",
+ "augeas": "// ",
+ "cmake": "# ",
+ "f-sharp": "// ",
+ "stan": "// ",
+ "isabelle": "(**)",
+ "dockerfile": "# ",
+ "rmarkdown": "# ",
+ "literate-agda": "-- ",
+ "tcl": "// ",
+ "glsl": "// ",
+ "antlr": "// ",
+ "verilog": "// ",
+ "racket": "; ",
+ "standard-ml": "(**)",
+ "elm": "-- ",
+ "yaml": "# ",
+ "smalltalk": "'' ",
+ "ocaml": "(**)",
+ "idris": "-- ",
+ "visual-basic": "' ",
+ "protocol-buffer": "// ",
+ "bluespec": "// ",
+ "applescript": "-- ",
+ "makefile": "# ",
+ "tcsh": "# ",
+ "maple": "# ",
+ "systemverilog": "// ",
+ "literate-coffeescript": "# ",
+ "vhdl": "-- ",
+ "restructuredtext": ".. ",
+ "sas": "* ",
+ "literate-haskell": "> ",
+ "java-server-pages": "// ",
+ "coffeescript": "# ",
+ "emacs-lisp": "; ",
+ "mathematica": "// ",
+ "xslt": "",
+ "zig": "// ",
+ "common-lisp": "; ",
+ "stata": "* ",
+ "agda": "-- ",
+ "ada": "-- ",
+ "jsx": "// ",
+ "tsx": "// ",
+}
+
+EXT2LANG = {
+ ".abap": "abap",
+ ".ash": "ags script",
+ ".ampl": "ampl",
+ ".g4": "antlr",
+ ".apib": "api blueprint",
+ ".apl": "apl",
+ ".dyalog": "apl",
+ ".asp": "asp",
+ ".asax": "asp",
+ ".ascx": "asp",
+ ".ashx": "asp",
+ ".asmx": "asp",
+ ".aspx": "asp",
+ ".axd": "asp",
+ ".dats": "ats",
+ ".hats": "ats",
+ ".sats": "ats",
+ ".as": "actionscript",
+ ".adb": "ada",
+ ".ada": "ada",
+ ".ads": "ada",
+ ".agda": "agda",
+ ".als": "alloy",
+ ".apacheconf": "apacheconf",
+ ".vhost": "apacheconf",
+ ".applescript": "applescript",
+ ".scpt": "applescript",
+ ".arc": "arc",
+ ".ino": "arduino",
+ ".asciidoc": "asciidoc",
+ ".adoc": "asciidoc",
+ ".aj": "aspectj",
+ ".asm": "assembly",
+ ".a51": "assembly",
+ ".nasm": "assembly",
+ ".aug": "augeas",
+ ".ahk": "autohotkey",
+ ".ahkl": "autohotkey",
+ ".au3": "autoit",
+ ".awk": "awk",
+ ".auk": "awk",
+ ".gawk": "awk",
+ ".mawk": "awk",
+ ".nawk": "awk",
+ ".bat": "batchfile",
+ ".cmd": "batchfile",
+ ".befunge": "befunge",
+ ".bison": "bison",
+ ".bb": "bitbake",
+ ".decls": "blitzbasic",
+ ".bmx": "blitzmax",
+ ".bsv": "bluespec",
+ ".boo": "boo",
+ ".bf": "brainfuck",
+ ".brs": "brightscript",
+ ".bro": "bro",
+ ".c": "c",
+ ".cats": "c",
+ ".h": "c++",
+ ".idc": "c",
+ ".w": "c",
+ ".cs": "c#",
+ ".cake": "c#",
+ ".cshtml": "c#",
+ ".csx": "c#",
+ ".cpp": "c++",
+ ".c++": "c++",
+ ".cc": "c++",
+ ".cp": "c++",
+ ".cxx": "c++",
+ ".h++": "c++",
+ ".hh": "c++",
+ ".hpp": "c++",
+ ".hxx": "c++",
+ ".inl": "c++",
+ ".ipp": "c++",
+ ".tcc": "c++",
+ ".tpp": "c++",
+ ".C": "c++",
+ ".H": "c++",
+ ".c-objdump": "c-objdump",
+ ".chs": "c2hs haskell",
+ ".clp": "clips",
+ ".cmake": "cmake",
+ ".cmake.in": "cmake",
+ ".cob": "cobol",
+ ".cbl": "cobol",
+ ".ccp": "cobol",
+ ".cobol": "cobol",
+ ".cpy": "cobol",
+ ".css": "css",
+ ".csv": "csv",
+ ".capnp": "cap'n proto",
+ ".mss": "cartocss",
+ ".ceylon": "ceylon",
+ ".chpl": "chapel",
+ ".ck": "chuck",
+ ".cirru": "cirru",
+ ".clw": "clarion",
+ ".icl": "clean",
+ ".dcl": "clean",
+ ".click": "click",
+ ".clj": "clojure",
+ ".boot": "clojure",
+ ".cl2": "clojure",
+ ".cljc": "clojure",
+ ".cljs": "clojure",
+ ".cljs.hl": "clojure",
+ ".cljscm": "clojure",
+ ".cljx": "clojure",
+ ".hic": "clojure",
+ ".coffee": "coffeescript",
+ "._coffee": "coffeescript",
+ ".cjsx": "coffeescript",
+ ".cson": "coffeescript",
+ ".iced": "coffeescript",
+ ".cfm": "coldfusion",
+ ".cfml": "coldfusion",
+ ".cfc": "coldfusion cfc",
+ ".lisp": "common lisp",
+ ".asd": "common lisp",
+ ".lsp": "common lisp",
+ ".ny": "common lisp",
+ ".podsl": "common lisp",
+ ".sexp": "common lisp",
+ ".cps": "component pascal",
+ ".coq": "coq",
+ ".cppobjdump": "cpp-objdump",
+ ".c++-objdump": "cpp-objdump",
+ ".c++objdump": "cpp-objdump",
+ ".cpp-objdump": "cpp-objdump",
+ ".cxx-objdump": "cpp-objdump",
+ ".creole": "creole",
+ ".cr": "crystal",
+ ".csd": "csound",
+ ".feature": "cucumber",
+ ".cu": "cuda",
+ ".cuh": "cuda",
+ ".cy": "cycript",
+ ".pyx": "cython",
+ ".pxd": "cython",
+ ".pxi": "cython",
+ ".di": "d",
+ ".d-objdump": "d-objdump",
+ ".com": "digital command language",
+ ".dm": "dm",
+ ".zone": "dns zone",
+ ".arpa": "dns zone",
+ ".darcspatch": "darcs patch",
+ ".dpatch": "darcs patch",
+ ".dart": "dart",
+ ".diff": "diff",
+ ".patch": "diff",
+ ".dockerfile": "dockerfile",
+ "Dockerfile": "dockerfile",
+ ".djs": "dogescript",
+ ".dylan": "dylan",
+ ".dyl": "dylan",
+ ".intr": "dylan",
+ ".lid": "dylan",
+ ".E": "e",
+ ".ecl": "ecl",
+ ".eclxml": "ecl",
+ ".sch": "eagle",
+ ".brd": "eagle",
+ ".epj": "ecere projects",
+ ".e": "eiffel",
+ ".ex": "elixir",
+ ".exs": "elixir",
+ ".elm": "elm",
+ ".el": "emacs lisp",
+ ".emacs": "emacs lisp",
+ ".emacs.desktop": "emacs lisp",
+ ".em": "emberscript",
+ ".emberscript": "emberscript",
+ ".erl": "erlang",
+ ".escript": "erlang",
+ ".hrl": "erlang",
+ ".xrl": "erlang",
+ ".yrl": "erlang",
+ ".fs": "f#",
+ ".fsi": "f#",
+ ".fsx": "f#",
+ ".flux": "flux",
+ ".f90": "fortran",
+ ".f": "fortran",
+ ".f03": "fortran",
+ ".f08": "fortran",
+ ".f77": "fortran",
+ ".f95": "fortran",
+ ".for": "fortran",
+ ".fpp": "fortran",
+ ".factor": "factor",
+ ".fy": "fancy",
+ ".fancypack": "fancy",
+ ".fan": "fantom",
+ ".eam.fs": "formatted",
+ ".fth": "forth",
+ ".4th": "forth",
+ ".forth": "forth",
+ ".frt": "forth",
+ ".ftl": "freemarker",
+ ".g": "g-code",
+ ".gco": "g-code",
+ ".gcode": "g-code",
+ ".gms": "gams",
+ ".gap": "gap",
+ ".gi": "gap",
+ ".s": "gas",
+ ".gd": "gdscript",
+ ".glsl": "glsl",
+ ".fp": "glsl",
+ ".frag": "glsl",
+ ".frg": "glsl",
+ ".fsh": "glsl",
+ ".fshader": "glsl",
+ ".geo": "glsl",
+ ".geom": "glsl",
+ ".glslv": "glsl",
+ ".gshader": "glsl",
+ ".shader": "glsl",
+ ".vert": "glsl",
+ ".vrx": "glsl",
+ ".vsh": "glsl",
+ ".vshader": "glsl",
+ ".kid": "genshi",
+ ".ebuild": "gentoo ebuild",
+ ".eclass": "gentoo eclass",
+ ".po": "gettext catalog",
+ ".pot": "gettext catalog",
+ ".glf": "glyph",
+ ".gp": "gnuplot",
+ ".gnu": "gnuplot",
+ ".gnuplot": "gnuplot",
+ ".plot": "gnuplot",
+ ".plt": "gnuplot",
+ ".go": "go",
+ ".golo": "golo",
+ ".gst": "gosu",
+ ".gsx": "gosu",
+ ".vark": "gosu",
+ ".grace": "grace",
+ ".gradle": "gradle",
+ ".gf": "grammatical framework",
+ ".graphql": "graphql",
+ ".dot": "graphviz (dot)",
+ ".gv": "graphviz (dot)",
+ ".man": "groff",
+ ".1": "groff",
+ ".1in": "groff",
+ ".1m": "groff",
+ ".1x": "groff",
+ ".2": "groff",
+ ".3": "groff",
+ ".3in": "groff",
+ ".3m": "groff",
+ ".3qt": "groff",
+ ".3x": "groff",
+ ".4": "groff",
+ ".5": "groff",
+ ".6": "groff",
+ ".7": "groff",
+ ".8": "groff",
+ ".9": "groff",
+ ".me": "groff",
+ ".rno": "groff",
+ ".roff": "groff",
+ ".groovy": "groovy",
+ ".grt": "groovy",
+ ".gtpl": "groovy",
+ ".gvy": "groovy",
+ ".gsp": "groovy server pages",
+ ".hcl": "hcl",
+ ".tf": "hcl",
+ ".hlsl": "hlsl",
+ ".fxh": "hlsl",
+ ".hlsli": "hlsl",
+ ".html": "html",
+ ".htm": "html",
+ ".html.hl": "html",
+ ".xht": "html",
+ ".xhtml": "html",
+ ".mustache": "html+django",
+ ".jinja": "html+django",
+ ".eex": "html+eex",
+ ".erb": "html+erb",
+ ".erb.deface": "html+erb",
+ ".phtml": "html+php",
+ ".http": "http",
+ ".haml": "haml",
+ ".haml.deface": "haml",
+ ".handlebars": "handlebars",
+ ".hbs": "handlebars",
+ ".hb": "harbour",
+ ".hs": "haskell",
+ ".hsc": "haskell",
+ ".hx": "haxe",
+ ".hxsl": "haxe",
+ ".hy": "hy",
+ ".dlm": "idl",
+ ".ipf": "igor pro",
+ ".ini": "ini",
+ ".cfg": "ini",
+ ".prefs": "ini",
+ ".properties": "ini",
+ ".irclog": "irc log",
+ ".weechatlog": "irc log",
+ ".idr": "idris",
+ ".lidr": "idris",
+ ".ni": "inform 7",
+ ".i7x": "inform 7",
+ ".iss": "inno setup",
+ ".io": "io",
+ ".ik": "ioke",
+ ".thy": "isabelle",
+ ".ijs": "j",
+ ".flex": "jflex",
+ ".jflex": "jflex",
+ ".json": "json",
+ ".geojson": "json",
+ ".lock": "json",
+ ".topojson": "json",
+ ".json5": "json5",
+ ".jsonld": "jsonld",
+ ".jq": "jsoniq",
+ ".jsx": "jsx",
+ ".jade": "jade",
+ ".j": "jasmin",
+ ".java": "java",
+ ".jsp": "java server pages",
+ ".js": "javascript",
+ "._js": "javascript",
+ ".bones": "javascript",
+ ".es6": "javascript",
+ ".jake": "javascript",
+ ".jsb": "javascript",
+ ".jscad": "javascript",
+ ".jsfl": "javascript",
+ ".jsm": "javascript",
+ ".jss": "javascript",
+ ".njs": "javascript",
+ ".pac": "javascript",
+ ".sjs": "javascript",
+ ".ssjs": "javascript",
+ ".xsjs": "javascript",
+ ".xsjslib": "javascript",
+ ".jl": "julia",
+ ".ipynb": "jupyter notebook",
+ ".krl": "krl",
+ ".kicad_pcb": "kicad",
+ ".kit": "kit",
+ ".kt": "kotlin",
+ ".ktm": "kotlin",
+ ".kts": "kotlin",
+ ".lfe": "lfe",
+ ".ll": "llvm",
+ ".lol": "lolcode",
+ ".lsl": "lsl",
+ ".lslp": "lsl",
+ ".lvproj": "labview",
+ ".lasso": "lasso",
+ ".las": "lasso",
+ ".lasso8": "lasso",
+ ".lasso9": "lasso",
+ ".ldml": "lasso",
+ ".latte": "latte",
+ ".lean": "lean",
+ ".hlean": "lean",
+ ".less": "less",
+ ".lex": "lex",
+ ".ly": "lilypond",
+ ".ily": "lilypond",
+ ".ld": "linker script",
+ ".lds": "linker script",
+ ".liquid": "liquid",
+ ".lagda": "literate agda",
+ ".litcoffee": "literate coffeescript",
+ ".lhs": "literate haskell",
+ ".ls": "livescript",
+ "._ls": "livescript",
+ ".xm": "logos",
+ ".x": "logos",
+ ".xi": "logos",
+ ".lgt": "logtalk",
+ ".logtalk": "logtalk",
+ ".lookml": "lookml",
+ ".lua": "lua",
+ ".nse": "lua",
+ ".pd_lua": "lua",
+ ".rbxs": "lua",
+ ".wlua": "lua",
+ ".mumps": "m",
+ ".m4": "m4",
+ ".mcr": "maxscript",
+ ".mtml": "mtml",
+ ".muf": "muf",
+ ".mak": "makefile",
+ ".mk": "makefile",
+ ".mkfile": "makefile",
+ "Makefile": "makefile",
+ ".mako": "mako",
+ ".mao": "mako",
+ ".mpl": "maple",
+ ".md": "markdown",
+ ".markdown": "markdown",
+ ".mkd": "markdown",
+ ".mkdn": "markdown",
+ ".mkdown": "markdown",
+ ".ron": "markdown",
+ ".mask": "mask",
+ ".mathematica": "mathematica",
+ ".cdf": "mathematica",
+ ".ma": "mathematica",
+ ".mt": "mathematica",
+ ".nb": "mathematica",
+ ".nbp": "mathematica",
+ ".wl": "mathematica",
+ ".wlt": "mathematica",
+ ".matlab": "matlab",
+ ".maxpat": "max",
+ ".maxhelp": "max",
+ ".maxproj": "max",
+ ".mxt": "max",
+ ".pat": "max",
+ ".mediawiki": "mediawiki",
+ ".wiki": "mediawiki",
+ ".metal": "metal",
+ ".minid": "minid",
+ ".druby": "mirah",
+ ".duby": "mirah",
+ ".mir": "mirah",
+ ".mirah": "mirah",
+ ".mo": "modelica",
+ ".mms": "module management system",
+ ".mmk": "module management system",
+ ".monkey": "monkey",
+ ".moon": "moonscript",
+ ".myt": "myghty",
+ ".nsi": "nsis",
+ ".nsh": "nsis",
+ ".axs": "netlinx",
+ ".axi": "netlinx",
+ ".axs.erb": "netlinx+erb",
+ ".axi.erb": "netlinx+erb",
+ ".nlogo": "netlogo",
+ ".nginxconf": "nginx",
+ ".nim": "nimrod",
+ ".nimrod": "nimrod",
+ ".ninja": "ninja",
+ ".nit": "nit",
+ ".nix": "nix",
+ ".nu": "nu",
+ ".numpy": "numpy",
+ ".numpyw": "numpy",
+ ".numsc": "numpy",
+ ".ml": "ocaml",
+ ".eliom": "ocaml",
+ ".eliomi": "ocaml",
+ ".ml4": "ocaml",
+ ".mli": "ocaml",
+ ".mll": "ocaml",
+ ".mly": "ocaml",
+ ".objdump": "objdump",
+ ".mm": "objective-c++",
+ ".sj": "objective-j",
+ ".oct": "octave",
+ ".omgrofl": "omgrofl",
+ ".opa": "opa",
+ ".opal": "opal",
+ ".cl": "opencl",
+ ".opencl": "opencl",
+ ".p": "openedge abl",
+ ".scad": "openscad",
+ ".org": "org",
+ ".ox": "ox",
+ ".oxh": "ox",
+ ".oxo": "ox",
+ ".oxygene": "oxygene",
+ ".oz": "oz",
+ ".pwn": "pawn",
+ ".php": "php",
+ ".aw": "php",
+ ".ctp": "php",
+ ".php3": "php",
+ ".php4": "php",
+ ".php5": "php",
+ ".phps": "php",
+ ".phpt": "php",
+ ".pov": "pov-ray sdl",
+ ".pan": "pan",
+ ".psc": "papyrus",
+ ".parrot": "parrot",
+ ".pasm": "parrot assembly",
+ ".pir": "parrot internal representation",
+ ".pas": "pascal",
+ ".dfm": "pascal",
+ ".dpr": "pascal",
+ ".lpr": "pascal",
+ ".pl": "perl",
+ ".al": "perl",
+ ".perl": "perl",
+ ".ph": "perl",
+ ".plx": "perl",
+ ".pm": "perl",
+ ".psgi": "perl",
+ ".t": "perl",
+ ".6pl": "perl6",
+ ".6pm": "perl6",
+ ".nqp": "perl6",
+ ".p6": "perl6",
+ ".p6l": "perl6",
+ ".p6m": "perl6",
+ ".pl6": "perl6",
+ ".pm6": "perl6",
+ ".pkl": "pickle",
+ ".pig": "piglatin",
+ ".pike": "pike",
+ ".pmod": "pike",
+ ".pod": "pod",
+ ".pogo": "pogoscript",
+ ".pony": "pony",
+ ".ps": "postscript",
+ ".eps": "postscript",
+ ".ps1": "powershell",
+ ".psd1": "powershell",
+ ".psm1": "powershell",
+ ".pde": "processing",
+ ".prolog": "prolog",
+ ".yap": "prolog",
+ ".spin": "propeller spin",
+ ".proto": "protocol buffer",
+ ".pub": "public key",
+ ".pd": "pure data",
+ ".pb": "purebasic",
+ ".pbi": "purebasic",
+ ".purs": "purescript",
+ ".py": "python",
+ ".bzl": "python",
+ ".gyp": "python",
+ ".lmi": "python",
+ ".pyde": "python",
+ ".pyp": "python",
+ ".pyt": "python",
+ ".pyw": "python",
+ ".tac": "python",
+ ".wsgi": "python",
+ ".xpy": "python",
+ ".pytb": "python traceback",
+ ".qml": "qml",
+ ".qbs": "qml",
+ ".pri": "qmake",
+ ".r": "r",
+ ".rd": "r",
+ ".rsx": "r",
+ ".raml": "raml",
+ ".rdoc": "rdoc",
+ ".rbbas": "realbasic",
+ ".rbfrm": "realbasic",
+ ".rbmnu": "realbasic",
+ ".rbres": "realbasic",
+ ".rbtbar": "realbasic",
+ ".rbuistate": "realbasic",
+ ".rhtml": "rhtml",
+ ".rmd": "rmarkdown",
+ ".rkt": "racket",
+ ".rktd": "racket",
+ ".rktl": "racket",
+ ".scrbl": "racket",
+ ".rl": "ragel in ruby host",
+ ".raw": "raw token data",
+ ".reb": "rebol",
+ ".r2": "rebol",
+ ".r3": "rebol",
+ ".rebol": "rebol",
+ ".red": "red",
+ ".reds": "red",
+ ".cw": "redcode",
+ ".rpy": "ren'py",
+ ".rsh": "renderscript",
+ ".robot": "robotframework",
+ ".rg": "rouge",
+ ".rb": "ruby",
+ ".builder": "ruby",
+ ".gemspec": "ruby",
+ ".god": "ruby",
+ ".irbrc": "ruby",
+ ".jbuilder": "ruby",
+ ".mspec": "ruby",
+ ".podspec": "ruby",
+ ".rabl": "ruby",
+ ".rake": "ruby",
+ ".rbuild": "ruby",
+ ".rbw": "ruby",
+ ".rbx": "ruby",
+ ".ru": "ruby",
+ ".ruby": "ruby",
+ ".thor": "ruby",
+ ".watchr": "ruby",
+ ".rs": "rust",
+ ".rs.in": "rust",
+ ".sas": "sas",
+ ".scss": "scss",
+ ".smt2": "smt",
+ ".smt": "smt",
+ ".sparql": "sparql",
+ ".rq": "sparql",
+ ".sqf": "sqf",
+ ".hqf": "sqf",
+ ".pls": "sql",
+ ".pck": "sql",
+ ".pkb": "sql",
+ ".pks": "sql",
+ ".plb": "sql",
+ ".plsql": "sql",
+ ".sql": "sql",
+ ".cql": "sql",
+ ".ddl": "sql",
+ ".prc": "sql",
+ ".tab": "sql",
+ ".udf": "sql",
+ ".viw": "sql",
+ ".db2": "sql",
+ ".ston": "ston",
+ ".svg": "svg",
+ ".sage": "sage",
+ ".sagews": "sage",
+ ".sls": "saltstack",
+ ".sass": "sass",
+ ".scala": "scala",
+ ".sbt": "scala",
+ ".scaml": "scaml",
+ ".scm": "scheme",
+ ".sld": "scheme",
+ ".sps": "scheme",
+ ".ss": "scheme",
+ ".sci": "scilab",
+ ".sce": "scilab",
+ ".self": "self",
+ ".sh": "shell",
+ ".bash": "shell",
+ ".bats": "shell",
+ ".command": "shell",
+ ".ksh": "shell",
+ ".sh.in": "shell",
+ ".tmux": "shell",
+ ".tool": "shell",
+ ".zsh": "shell",
+ ".sh-session": "shellsession",
+ ".shen": "shen",
+ ".sl": "slash",
+ ".slim": "slim",
+ ".smali": "smali",
+ ".st": "smalltalk",
+ ".tpl": "smarty",
+ ".sol": "solidity",
+ ".sp": "sourcepawn",
+ ".sma": "sourcepawn",
+ ".nut": "squirrel",
+ ".stan": "stan",
+ ".ML": "standard ml",
+ ".fun": "standard ml",
+ ".sig": "standard ml",
+ ".sml": "standard ml",
+ ".do": "stata",
+ ".ado": "stata",
+ ".doh": "stata",
+ ".ihlp": "stata",
+ ".mata": "stata",
+ ".matah": "stata",
+ ".sthlp": "stata",
+ ".styl": "stylus",
+ ".scd": "supercollider",
+ ".swift": "swift",
+ ".sv": "systemverilog",
+ ".svh": "systemverilog",
+ ".vh": "systemverilog",
+ ".toml": "toml",
+ ".txl": "txl",
+ ".tcl": "tcl",
+ ".adp": "tcl",
+ ".tm": "tcl",
+ ".tcsh": "tcsh",
+ ".csh": "tcsh",
+ ".tex": "tex",
+ ".aux": "tex",
+ ".bbx": "tex",
+ ".bib": "tex",
+ ".cbx": "tex",
+ ".dtx": "tex",
+ ".ins": "tex",
+ ".lbx": "tex",
+ ".ltx": "tex",
+ ".mkii": "tex",
+ ".mkiv": "tex",
+ ".mkvi": "tex",
+ ".sty": "tex",
+ ".toc": "tex",
+ ".tea": "tea",
+ ".txt": "text",
+ ".no": "text",
+ ".textile": "textile",
+ ".thrift": "thrift",
+ ".tu": "turing",
+ ".ttl": "turtle",
+ ".twig": "twig",
+ ".ts": "typescript",
+ ".tsx": "tsx",
+ ".upc": "unified parallel c",
+ ".anim": "unity3d asset",
+ ".asset": "unity3d asset",
+ ".mat": "unity3d asset",
+ ".meta": "unity3d asset",
+ ".prefab": "unity3d asset",
+ ".unity": "unity3d asset",
+ ".uno": "uno",
+ ".uc": "unrealscript",
+ ".ur": "urweb",
+ ".urs": "urweb",
+ ".vcl": "vcl",
+ ".vhdl": "vhdl",
+ ".vhd": "vhdl",
+ ".vhf": "vhdl",
+ ".vhi": "vhdl",
+ ".vho": "vhdl",
+ ".vhs": "vhdl",
+ ".vht": "vhdl",
+ ".vhw": "vhdl",
+ ".vala": "vala",
+ ".vapi": "vala",
+ ".veo": "verilog",
+ ".vim": "viml",
+ ".vb": "visual basic",
+ ".bas": "visual basic",
+ ".frm": "visual basic",
+ ".frx": "visual basic",
+ ".vba": "visual basic",
+ ".vbhtml": "visual basic",
+ ".vbs": "visual basic",
+ ".volt": "volt",
+ ".vue": "vue",
+ ".owl": "web ontology language",
+ ".wat": "webassembly",
+ ".webidl": "webidl",
+ ".x10": "x10",
+ ".xc": "xc",
+ ".xml": "xml",
+ ".ant": "xml",
+ ".axml": "xml",
+ ".ccxml": "xml",
+ ".clixml": "xml",
+ ".cproject": "xml",
+ ".csl": "xml",
+ ".csproj": "xml",
+ ".ct": "xml",
+ ".dita": "xml",
+ ".ditamap": "xml",
+ ".ditaval": "xml",
+ ".dll.config": "xml",
+ ".dotsettings": "xml",
+ ".filters": "xml",
+ ".fsproj": "xml",
+ ".fxml": "xml",
+ ".glade": "xml",
+ ".grxml": "xml",
+ ".iml": "xml",
+ ".ivy": "xml",
+ ".jelly": "xml",
+ ".jsproj": "xml",
+ ".kml": "xml",
+ ".launch": "xml",
+ ".mdpolicy": "xml",
+ ".mxml": "xml",
+ ".nproj": "xml",
+ ".nuspec": "xml",
+ ".odd": "xml",
+ ".osm": "xml",
+ ".plist": "xml",
+ ".props": "xml",
+ ".ps1xml": "xml",
+ ".psc1": "xml",
+ ".pt": "xml",
+ ".rdf": "xml",
+ ".rss": "xml",
+ ".scxml": "xml",
+ ".srdf": "xml",
+ ".storyboard": "xml",
+ ".stTheme": "xml",
+ ".sublime-snippet": "xml",
+ ".targets": "xml",
+ ".tmCommand": "xml",
+ ".tml": "xml",
+ ".tmLanguage": "xml",
+ ".tmPreferences": "xml",
+ ".tmSnippet": "xml",
+ ".tmTheme": "xml",
+ ".ui": "xml",
+ ".urdf": "xml",
+ ".ux": "xml",
+ ".vbproj": "xml",
+ ".vcxproj": "xml",
+ ".vssettings": "xml",
+ ".vxml": "xml",
+ ".wsdl": "xml",
+ ".wsf": "xml",
+ ".wxi": "xml",
+ ".wxl": "xml",
+ ".wxs": "xml",
+ ".x3d": "xml",
+ ".xacro": "xml",
+ ".xaml": "xml",
+ ".xib": "xml",
+ ".xlf": "xml",
+ ".xliff": "xml",
+ ".xmi": "xml",
+ ".xml.dist": "xml",
+ ".xproj": "xml",
+ ".xsd": "xml",
+ ".xul": "xml",
+ ".zcml": "xml",
+ ".xsp-config": "xpages",
+ ".xsp.metadata": "xpages",
+ ".xpl": "xproc",
+ ".xproc": "xproc",
+ ".xquery": "xquery",
+ ".xq": "xquery",
+ ".xql": "xquery",
+ ".xqm": "xquery",
+ ".xqy": "xquery",
+ ".xs": "xs",
+ ".xslt": "xslt",
+ ".xsl": "xslt",
+ ".xojo_code": "xojo",
+ ".xojo_menu": "xojo",
+ ".xojo_report": "xojo",
+ ".xojo_script": "xojo",
+ ".xojo_toolbar": "xojo",
+ ".xojo_window": "xojo",
+ ".xtend": "xtend",
+ ".yml": "yaml",
+ ".reek": "yaml",
+ ".rviz": "yaml",
+ ".sublime-syntax": "yaml",
+ ".syntax": "yaml",
+ ".yaml": "yaml",
+ ".yaml-tmlanguage": "yaml",
+ ".yang": "yang",
+ ".y": "yacc",
+ ".yacc": "yacc",
+ ".yy": "yacc",
+ ".zep": "zephir",
+ ".zig": "zig",
+ ".zimpl": "zimpl",
+ ".zmpl": "zimpl",
+ ".zpl": "zimpl",
+ ".desktop": "desktop",
+ ".desktop.in": "desktop",
+ ".ec": "ec",
+ ".eh": "ec",
+ ".edn": "edn",
+ ".fish": "fish",
+ ".mu": "mupad",
+ ".nc": "nesc",
+ ".ooc": "ooc",
+ ".rst": "restructuredtext",
+ ".rest": "restructuredtext",
+ ".rest.txt": "restructuredtext",
+ ".rst.txt": "restructuredtext",
+ ".wisp": "wisp",
+ ".prg": "xbase",
+ ".prw": "xbase",
+}
+
+
+LANGUAGE_TAG = {
+ "c": "// the code file is written by C",
+ "c++": "// the code file is written by C++",
+ "cpp": "// the code file is written by C++",
+ "c#": "// the code file is written by C#",
+ "csharp": "// the code file is written by C#",
+ "c-sharp": "// the code file is written by C#",
+ "css": "/* the code file is written by CSS */",
+ "cuda": "// the code file is written by Cuda",
+ "dart": "// the code file is written by Dart",
+ "lua": "// the code file is written by Lua",
+ "objectivec": "// the code file is written by Objective-C",
+ "objective-c": "// the code file is written by Objective-C",
+ "objective-c++": "// the code file is written by Objective-C++",
+ "python": "# the code file is written by Python",
+ "perl": "# the code file is written by Perl",
+ "prolog": "% the code file is written by Prolog",
+ "swift": "// the code file is written by swift",
+ "lisp": "; the code file is written by Lisp",
+ "java": "// the code file is written by Java",
+ "scala": "// the code file is written by Scala",
+ "tex": "% the code file is written by TeX",
+ "vue": "",
+ "markdown": "",
+ "html": "",
+ "php": "// the code file is written by PHP",
+ "js": "// the code file is written by JavaScript",
+ "javascript": "// the code file is written by JavaScript",
+ "typescript": "// the code file is written by TypeScript",
+ "go": "// the code file is written by Go",
+ "shell": "# the code file is written by Shell",
+ "rust": "// the code file is written by Rust",
+ "sql": "-- the code file is written by SQL",
+ "kotlin": "// the code file is written by Kotlin",
+ "vb": "' the code file is written by Visual Basic",
+ "ruby": "# the code file is written by Ruby",
+ "pascal": "// the code file is written by Pascal",
+ "r": "# the code file is written by R",
+ "fortran": "!the code file is written by Fortran",
+ "lean": "-- the code file is written by Lean",
+ "matlab": "% the code file is written by Matlab",
+ "delphi": "{the code file is written by Delphi}",
+ "scheme": "; the code file is written by Scheme",
+ "basic": "' the code file is written by Basic",
+ "assembly": "; the code file is written by Assembly",
+ "groovy": "// the code file is written by Groovy",
+ "abap": "* the code file is written by Abap",
+ "gdscript": "# the code file is written by GDScript",
+ "haskell": "-- the code file is written by Haskell",
+ "julia": "# the code file is written by Julia",
+ "elixir": "# the code file is written by Elixir",
+ "excel": "' the code file is written by Excel",
+ "clojure": "; the code file is written by Clojure",
+ "actionscript": "// the code file is written by ActionScript",
+ "solidity": "// the code file is written by Solidity",
+ "powershell": "# the code file is written by PowerShell",
+ "erlang": "% the code file is written by Erlang",
+ "cobol": "// the code file is written by Cobol",
+ "alloy": "/* the code file is written by Alloy */",
+ "awk": "// the code file is written by AWK",
+ "thrift": "/* the code file is written by Thrift */",
+ "sparql": "# the code file is written by SPARQL",
+ "augeas": "// the code file is written by Augeas",
+ "cmake": "# the code file is written by CMake",
+ "f-sharp": "// the code file is written by F#",
+ "stan": "// the code file is written by Stan",
+ "isabelle": "(*the code file is written by Isabelle*)",
+ "dockerfile": "# the code file is written by Dockerfile",
+ "rmarkdown": "# the code file is written by RMarkdown",
+ "literate-agda": "-- the code file is written by Literate Agda",
+ "tcl": "// the code file is written by Augeas",
+ "glsl": "// the code file is written by GLSL",
+ "antlr": "// the code file is written by ANTLR",
+ "verilog": "// the code file is written by Verilog",
+ "racket": "; the code file is written by Racket",
+ "standard-ml": "(*the code file is written byStandard ML*)",
+ "elm": "-- the code file is written by Elm",
+ "yaml": "# the code file is written by YAML",
+ "smalltalk": "'' the code file is written by Smalltalk",
+ "ocaml": "(*the code file is written by OCaml*)",
+ "idris": "-- the code file is written by Idris",
+ "visual-basic": "' the code file is written by Visual Basic",
+ "protocol-buffer": "// the code file is written by Protocol Buffer",
+ "bluespec": "// the code file is written by Bluespec",
+ "applescript": "-- the code file is written by AppleScript",
+ "makefile": "# the code file is written by Makefile",
+ "tcsh": "# the code file is written by TCSH",
+ "maple": "# the code file is written by Maple",
+ "systemverilog": "// the code file is written by SystemVerilog",
+ "literate-coffeescript": "# the code file is written by Literate CoffeeScript",
+ "vhdl": "-- the code file is written by VHDL",
+ "restructuredtext": ".. the code file is written by reStructuredText",
+ "sas": "* the code file is written by SAS",
+ "literate-haskell": "> the code file is written by Literate Haskell",
+ "java-server-pages": "// the code file is written by Java Server Pages",
+ "coffeescript": "# the code file is written by CoffeeScript",
+ "emacs-lisp": "; the code file is written by Emacs Lisp",
+ "mathematica": "// the code file is written by Mathematica",
+ "xslt": "",
+ "zig": "// the code file is written by Zig",
+ "common-lisp": "; the code file is written by Common Lisp",
+ "stata": "* the code file is written by Stata",
+ "agda": "-- the code file is written by Agda",
+ "ada": "-- the code file is written by Ada",
+ "jsx": "// the code file is written by JSX",
+ "tsx": "// the code file is written by TypeScript with JSX",
+}
+
+
+def get_file_separator(
+ repo_level_spec: RepoLevelCodeCompletionSpecV1, filename: str
+) -> str:
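+    """Build the separator line that precedes a file in a repository-level prompt.
+
+    If the spec's ``file_separator`` is empty, fall back to a comment-style
+    separator derived from the file extension; otherwise prepend the
+    configured separator to the given filename.
+    """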
+ if repo_level_spec.file_separator == "":
+ lang = EXT2LANG.get(os.path.splitext(filename)[-1].lower(), None)
+ sep = "// "
+ if lang is None:
+ logger.warning(
+ f"Unsupported file extension for {filename}, use '//' as file separator as fallback"
+ )
+ else:
+ sep = LANGUAGE_WRAPPER[lang]
+
+ return sep.replace("", filename)
+ else:
+ return repo_level_spec.file_separator + filename
diff --git a/xinference/model/llm/llama_cpp/core.py b/xinference/model/llm/llama_cpp/core.py
index b820fce466..ddca1479d0 100644
--- a/xinference/model/llm/llama_cpp/core.py
+++ b/xinference/model/llm/llama_cpp/core.py
@@ -14,12 +14,13 @@
import logging
import os
import time
-from typing import Iterable, Iterator, List, Optional, Union
+from typing import Iterable, Iterator, List, Mapping, Optional, Union
from ....types import (
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
+ CodeGenerateMode,
Completion,
CompletionChunk,
CompletionUsage,
@@ -29,7 +30,7 @@
)
from ..core import LLM
from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin
+from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin, CodeModelMixin
logger = logging.getLogger(__name__)
@@ -309,3 +310,57 @@ def chat(
self.model_family, self.model_uid, c, tools
)
return self._to_chat_completion(c)
+
+
+class LlamaCppCodeModel(LlamaCppModel, CodeModelMixin):
+ def __init__(
+ self,
+ model_uid: str,
+ model_family: "LLMFamilyV1",
+ model_spec: "LLMSpecV1",
+ quantization: str,
+ model_path: str,
+ llamacpp_model_config: Optional[LlamaCppModelConfig] = None,
+ ):
+ super().__init__(
+ model_uid,
+ model_family,
+ model_spec,
+ quantization,
+ model_path,
+ llamacpp_model_config,
+ )
+
+ @classmethod
+ def match(
+ cls, llm_family: LLMFamilyV1, llm_spec: LLMSpecV1, quantization: str
+ ) -> bool:
+ if llm_spec.model_format not in ["ggmlv3", "ggufv2"]:
+ return False
+ if "chatglm" in llm_family.model_name:
+ return False
+ return "code" in llm_family.model_ability
+
+ def code_generate(
+ self,
+        mode: CodeGenerateMode,
+ prompt: str,
+ file_path: Optional[str],
+ suffix: Optional[str],
+ repo_name: Optional[str],
+ files: Optional[Mapping[str, str]],
+ generate_config: Optional[LlamaCppGenerateConfig] = None,
+ ):
+ code_prompt = self.get_code_prompt(
+            mode,
+ prompt,
+ file_path,
+ suffix,
+ repo_name,
+ files,
+ )["prompt"]
+
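+        # Code completion returns a single, non-streaming completion; override stream=True if requested.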
+ if generate_config is not None and generate_config.get("stream", False):
+ generate_config["stream"] = False
+
+ return self.generate(code_prompt, generate_config)
diff --git a/xinference/model/llm/llm_family.json b/xinference/model/llm/llm_family.json
index 26f1d599a8..d8e9843049 100644
--- a/xinference/model/llm/llm_family.json
+++ b/xinference/model/llm/llm_family.json
@@ -2098,31 +2098,6 @@
]
}
},
- {
- "version": 1,
- "context_length": 65536,
- "model_name": "codeqwen1.5",
- "model_lang": [
- "en",
- "zh"
- ],
- "model_ability": [
- "generate"
- ],
- "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.",
- "model_specs": [
- {
- "model_format": "pytorch",
- "model_size_in_billions": 7,
- "quantizations": [
- "4-bit",
- "8-bit",
- "none"
- ],
- "model_id": "Qwen/CodeQwen1.5-7B"
- }
- ]
- },
{
"version": 1,
"context_length": 65536,
@@ -2580,6 +2555,86 @@
]
}
},
+ {
+ "version": 1,
+ "context_length": 65536,
+ "model_name": "codeqwen1.5",
+ "model_lang": [
+ "en",
+ "zh"
+ ],
+ "model_ability": [
+ "code"
+ ],
+ "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.",
+ "model_specs": [
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": 7,
+ "quantizations": [
+ "4-bit",
+ "8-bit",
+ "none"
+ ],
+ "model_id": "Qwen/CodeQwen1.5-7B"
+ },
+ {
+ "model_format": "awq",
+ "model_size_in_billions": 7,
+ "quantizations": [
+ "Int4"
+ ],
+ "model_id": "Qwen/CodeQwen1.5-7B-AWQ"
+ }
+ ],
+ "code_prompt_style": {
+ "style_name": "CODEQWEN",
+ "fim_spec": {
+ "style": "PSM",
+ "prefix": "",
+ "middle": "",
+ "suffix": ""
+ },
+ "repo_level_spec": {
+ "repo_name": "",
+ "file_type": "filepath",
+ "file_separator": ""
+ }
+ }
+ },
+ {
+ "version": 1,
+ "context_length": 8192,
+ "model_name": "starcoder",
+ "model_lang": [
+ "en"
+ ],
+ "model_ability": [
+ "generate",
+ "code"
+ ],
+ "model_description": "Starcoder is an open-source Transformer based LLM that is trained on permissively licensed data from GitHub.",
+ "model_specs": [
+ {
+ "model_format": "ggufv2",
+ "model_size_in_billions": 16,
+ "quantizations": [
+ "q5_k_m"
+ ],
+ "model_id": "osukhoroslov-hw/starcoder-Q5_K_M-GGUF",
+ "model_file_name_template": "starcoder-{quantization}.gguf"
+ }
+ ],
+ "code_prompt_style": {
+ "style_name": "STARCODER",
+ "fim_spec": {
+ "style": "PSM",
+ "prefix": "",
+ "middle": "",
+ "suffix": ""
+ }
+ }
+ },
{
"version": 1,
"context_length": 1024,
@@ -5149,7 +5204,8 @@
"zh"
],
"model_ability": [
- "generate"
+ "generate",
+ "code"
],
"model_description": "Deepseek Coder is composed of a series of code language models, each trained from scratch on 2T tokens, with a composition of 87% code and 13% natural language in both English and Chinese. ",
"model_specs": [
@@ -5330,7 +5386,20 @@
"model_id": "TheBloke/deepseek-coder-33B-base-AWQ",
"model_revision": "c7edb2d5868d61a5dcf2591933a8992c8cbe3ef4"
}
- ]
+ ],
+ "code_prompt_style": {
+ "style_name": "DEEPSEEK_CODER",
+ "fim_spec": {
+ "style": "PMS",
+ "prefix": "<|fim▁begin|>",
+ "middle": "<|fim▁hole|>",
+ "suffix": "<|fim▁end|>"
+ },
+ "repo_level_spec": {
+ "file_type": "filename",
+ "file_separator": ""
+ }
+ }
},
{
"version": 1,
diff --git a/xinference/model/llm/llm_family.py b/xinference/model/llm/llm_family.py
index c2ea4d7b98..feaef4f83f 100644
--- a/xinference/model/llm/llm_family.py
+++ b/xinference/model/llm/llm_family.py
@@ -53,9 +53,11 @@
DEFAULT_CONTEXT_LENGTH = 2048
BUILTIN_LLM_PROMPT_STYLE: Dict[str, "PromptStyleV1"] = {}
+BUILTIN_LLM_CODE_PROMPT_STYLE: Dict[str, "CodePromptStyleV1"] = {}
BUILTIN_LLM_MODEL_CHAT_FAMILIES: Set[str] = set()
BUILTIN_LLM_MODEL_GENERATE_FAMILIES: Set[str] = set()
BUILTIN_LLM_MODEL_TOOL_CALL_FAMILIES: Set[str] = set()
+BUILTIN_LLM_MODEL_CODE_FAMILIES: Set[str] = set()
class LlamaCppLLMSpecV1(BaseModel):
@@ -137,17 +139,37 @@ class PromptStyleV1(BaseModel):
stop_token_ids: Optional[List[int]]
+class FIMSpecV1(BaseModel):
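+    # Fill-in-the-middle special tokens; "style" selects the token ordering
+    # (PSM: prefix/suffix/middle, PMS: prefix/middle/suffix).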
+ style: Literal["PSM", "PMS"]
+ prefix: str
+ middle: str
+ suffix: str
+
+
+class RepoLevelCodeCompletionSpecV1(BaseModel):
+ repo_name: Optional[str]
+ file_type: Literal["filepath", "filename"]
+ file_separator: str
+
+
+class CodePromptStyleV1(BaseModel):
+ style_name: str
+ fim_spec: Optional["FIMSpecV1"]
+ repo_level_spec: Optional["RepoLevelCodeCompletionSpecV1"]
+
+
class LLMFamilyV1(BaseModel):
version: Literal[1]
context_length: Optional[int] = DEFAULT_CONTEXT_LENGTH
model_name: str
model_lang: List[str]
- model_ability: List[Literal["embed", "generate", "chat", "tools", "vision"]]
+ model_ability: List[Literal["embed", "generate", "chat", "tools", "vision", "code"]]
model_description: Optional[str]
# reason for not required str here: legacy registration
model_family: Optional[str]
model_specs: List["LLMSpecV1"]
prompt_style: Optional["PromptStyleV1"]
+ code_prompt_style: Optional["CodePromptStyleV1"]
class CustomLLMFamilyV1(LLMFamilyV1):
diff --git a/xinference/model/llm/llm_family_modelscope.json b/xinference/model/llm/llm_family_modelscope.json
index 44ac3e7794..3fc23585ac 100644
--- a/xinference/model/llm/llm_family_modelscope.json
+++ b/xinference/model/llm/llm_family_modelscope.json
@@ -2907,32 +2907,6 @@
]
}
},
- {
- "version": 1,
- "context_length": 65536,
- "model_name": "codeqwen1.5",
- "model_lang": [
- "en",
- "zh"
- ],
- "model_ability": [
- "generate"
- ],
- "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.",
- "model_specs": [
- {
- "model_format": "pytorch",
- "model_size_in_billions": 7,
- "quantizations": [
- "4-bit",
- "8-bit",
- "none"
- ],
- "model_id": "qwen/CodeQwen1.5-7B",
- "model_hub": "modelscope"
- }
- ]
- },
{
"version": 1,
"context_length": 65536,
@@ -3004,6 +2978,66 @@
]
}
},
+ {
+ "version": 1,
+ "context_length": 65536,
+ "model_name": "codeqwen1.5",
+ "model_lang": [
+ "en",
+ "zh"
+ ],
+ "model_ability": [
+ "generate",
+ "code"
+ ],
+ "model_description": "CodeQwen1.5 is the Code-Specific version of Qwen1.5. It is a transformer-based decoder-only language model pretrained on a large amount of data of codes.",
+ "model_specs": [
+ {
+ "model_format": "pytorch",
+ "model_size_in_billions": 7,
+ "quantizations": [
+ "4-bit",
+ "8-bit",
+ "none"
+ ],
+ "model_id": "qwen/CodeQwen1.5-7B",
+ "model_hub": "modelscope"
+ }
+ ],
+ "prompt_style": {
+ "style_name": "QWEN",
+ "system_prompt": "You are a helpful assistant.",
+ "roles": [
+ "user",
+ "assistant"
+ ],
+ "intra_message_sep": "\n",
+ "stop_token_ids": [
+ 2,
+ 3,
+ 4
+ ],
+ "stop": [
+ "<|endoftext|>",
+ "<|im_start|>",
+ "<|im_end|>"
+ ]
+ },
+ "code_prompt_style": {
+ "style_name": "CODEQWEN",
+ "fim_spec": {
+ "style": "PSM",
+ "prefix": "",
+ "middle": "",
+ "suffix": ""
+ },
+ "repo_level_spec": {
+ "repo_name": "",
+ "file_type": "filepath",
+ "file_separator": ""
+ }
+ }
+ },
{
"version": 1,
"context_length": 32768,
@@ -3528,7 +3562,8 @@
"zh"
],
"model_ability": [
- "generate"
+ "generate",
+ "code"
],
"model_description": "Deepseek Coder is composed of a series of code language models, each trained from scratch on 2T tokens, with a composition of 87% code and 13% natural language in both English and Chinese.",
"model_specs": [
@@ -3565,7 +3600,20 @@
"model_id": "deepseek-ai/deepseek-coder-33b-base",
"model_hub": "modelscope"
}
- ]
+ ],
+ "code_prompt_style": {
+ "style_name": "DEEPSEEK_CODER",
+ "fim_spec": {
+ "style": "PMS",
+ "prefix": "<|fim▁begin|>",
+ "middle": "<|fim▁hole|>",
+ "suffix": "<|fim▁end|>"
+ },
+ "repo_level_spec": {
+ "file_type": "filename",
+ "file_separator": ""
+ }
+ }
},
{
"version": 1,
@@ -4066,7 +4114,7 @@
"chat",
"vision"
],
- "model_description":"OmniLMM is a family of open-source large multimodal models (LMMs) adept at vision & language modeling.",
+ "model_description":"mniLMM is a family of open-source large multimodal models (LMMs) adept at vision & language modeling.",
"model_specs":[
{
"model_format":"pytorch",
diff --git a/xinference/model/llm/tests/test_llm_family.py b/xinference/model/llm/tests/test_llm_family.py
index 252491282c..6a75b1bec4 100644
--- a/xinference/model/llm/tests/test_llm_family.py
+++ b/xinference/model/llm/tests/test_llm_family.py
@@ -159,7 +159,7 @@ def test_serialize_llm_family_v1():
prompt_style=prompt_style,
)
- expected = """{"version": 1, "context_length": 2048, "model_name": "TestModel", "model_lang": ["en"], "model_ability": ["embed", "generate"], "model_description": null, "model_family": null, "model_specs": [{"model_format": "ggufv2", "model_hub": "huggingface", "model_size_in_billions": 2, "quantizations": ["q4_0", "q4_1"], "quantization_parts": {"q4_2": ["a", "b"]}, "model_id": "example/TestModel", "model_revision": "123", "model_file_name_template": "TestModel.{quantization}.bin", "model_file_name_split_template": "TestModel.{quantization}.bin.{part}", "model_uri": null}, {"model_format": "pytorch", "model_hub": "huggingface", "model_size_in_billions": 3, "quantizations": ["int8", "int4", "none"], "model_id": "example/TestModel", "model_revision": "456", "model_uri": null}], "prompt_style": {"style_name": "ADD_COLON_SINGLE", "system_prompt": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", "roles": ["user", "assistant"], "intra_message_sep": "\\n### ", "inter_message_sep": "\\n### ", "stop": null, "stop_token_ids": null}}"""
+ expected = """{"version": 1, "context_length": 2048, "model_name": "TestModel", "model_lang": ["en"], "model_ability": ["embed", "generate"], "model_description": null, "model_family": null, "model_specs": [{"model_format": "ggufv2", "model_hub": "huggingface", "model_size_in_billions": 2, "quantizations": ["q4_0", "q4_1"], "quantization_parts": {"q4_2": ["a", "b"]}, "model_id": "example/TestModel", "model_revision": "123", "model_file_name_template": "TestModel.{quantization}.bin", "model_file_name_split_template": "TestModel.{quantization}.bin.{part}", "model_uri": null}, {"model_format": "pytorch", "model_hub": "huggingface", "model_size_in_billions": 3, "quantizations": ["int8", "int4", "none"], "model_id": "example/TestModel", "model_revision": "456", "model_uri": null}], "prompt_style": {"style_name": "ADD_COLON_SINGLE", "system_prompt": "A chat between a curious human and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the human's questions.", "roles": ["user", "assistant"], "intra_message_sep": "\\n### ", "inter_message_sep": "\\n### ", "stop": null, "stop_token_ids": null}, "code_prompt_style": null}"""
assert json.loads(llm_family.json()) == json.loads(expected)
llm_family_context_length = LLMFamilyV1(
diff --git a/xinference/model/llm/tests/test_utils.py b/xinference/model/llm/tests/test_utils.py
index 42125a0048..786af4e413 100644
--- a/xinference/model/llm/tests/test_utils.py
+++ b/xinference/model/llm/tests/test_utils.py
@@ -11,10 +11,12 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+import pytest
from ....types import ChatCompletionMessage
-from ..llm_family import PromptStyleV1
-from ..utils import ChatModelMixin
+from ..llm_family import CodePromptStyleV1, FIMSpecV1, PromptStyleV1
+from ..llm_family import RepoLevelCodeCompletionSpecV1 as RepoLevelSpecV1
+from ..utils import ChatModelMixin, CodeModelMixin
def test_prompt_style_add_colon_single():
@@ -330,3 +332,490 @@ def test_is_valid_model_name():
assert not is_valid_model_name("foo/bar")
assert not is_valid_model_name(" ")
assert not is_valid_model_name("")
+
+
+def test_path_to_name():
+ path = "/home/test/works/project/main.py"
+ assert "main.py" == CodeModelMixin._path_to_name(path)
+
+ path = "/main.py"
+ assert "main.py" == CodeModelMixin._path_to_name(path)
+
+ path = "main.py"
+ assert "main.py" == CodeModelMixin._path_to_name(path)
+
+ path = ".main.py"
+ assert ".main.py" == CodeModelMixin._path_to_name(path)
+
+ path = r"\main.py"
+ assert "main.py" == CodeModelMixin._path_to_name(path)
+
+ path = r"C:\works\main.py"
+ assert "main.py" == CodeModelMixin._path_to_name(path)
+
+
+def test_code_prompt_style_starcoder():
+ code_prompt_style = CodePromptStyleV1(
+ style_name="STARCODER",
+ fim_spec=FIMSpecV1(
+ style="PSM",
+ prefix="",
+ middle="",
+ suffix="",
+ ),
+ )
+ prompt = "def print_hello_world():"
+ expected = prompt
+ assert expected == CodeModelMixin._get_code_prompt(
+ "completion", prompt, code_prompt_style
+ )
+
+ prompt = "def print_hello_world():\n "
+ suffix = "\n print('Hello world!')"
+ expected = "def print_hello_world():\n \n print('Hello world!')"
+ assert expected == CodeModelMixin._get_code_prompt(
+ "infill", prompt, code_prompt_style, None, suffix
+ )
+
+ suffix = None
+ with pytest.raises(ValueError) as exc_info:
+ CodeModelMixin._get_code_prompt(
+ "infill", prompt, code_prompt_style, None, suffix
+ )
+ assert exc_info.value == ValueError("suffix is required in infill mode")
+
+ with pytest.raises(ValueError) as exc_info:
+ CodeModelMixin._get_code_prompt("test", prompt, code_prompt_style)
+    assert str(exc_info.value) == (
+        "Unsupported generate mode: test, only 'completion' and 'infill' are supported now"
+    )
+
+
+def test_code_prompt_style_deepseek_coder():
+ code_prompt_style = CodePromptStyleV1(
+ style_name="DEEPSEEK_CODER",
+ fim_spec=FIMSpecV1(
+ style="PMS",
+ prefix="<|fim▁begin|>",
+ middle="<|fim▁hole|>",
+ suffix="<|fim▁end|>",
+ ),
+ repo_level_spec=RepoLevelSpecV1(
+ file_type="filename", file_separator=""
+ ),
+ )
+
+ prompt = "#write a quick sort algorithm"
+ expected = prompt
+
+ assert expected == CodeModelMixin._get_code_prompt(
+ "completion", prompt, code_prompt_style
+ )
+
+ prompt = """def quick_sort(arr):
+ if len(arr) <= 1:
+ return arr
+ pivot = arr[0]
+ left = []
+ right = []
+"""
+ suffix = """
+ if arr[i] < pivot:
+ left.append(arr[i])
+ else:
+ right.append(arr[i])
+ return quick_sort(left) + [pivot] + quick_sort(right)"""
+
+ expected = """<|fim▁begin|>def quick_sort(arr):
+ if len(arr) <= 1:
+ return arr
+ pivot = arr[0]
+ left = []
+ right = []
+<|fim▁hole|>
+ if arr[i] < pivot:
+ left.append(arr[i])
+ else:
+ right.append(arr[i])
+ return quick_sort(left) + [pivot] + quick_sort(right)<|fim▁end|>"""
+
+ assert expected == CodeModelMixin._get_code_prompt(
+ "infill", prompt, code_prompt_style, None, suffix
+ )
+
+ files = {
+ "utils.py": """import torch
+from sklearn import datasets
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from sklearn.metrics import accuracy_score
+
+def load_data():
+ iris = datasets.load_iris()
+ X = iris.data
+ y = iris.target
+
+ # Standardize the data
+ scaler = StandardScaler()
+ X = scaler.fit_transform(X)
+
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
+
+ # Convert numpy data to PyTorch tensors
+ X_train = torch.tensor(X_train, dtype=torch.float32)
+ X_test = torch.tensor(X_test, dtype=torch.float32)
+ y_train = torch.tensor(y_train, dtype=torch.int64)
+ y_test = torch.tensor(y_test, dtype=torch.int64)
+
+ return X_train, X_test, y_train, y_test
+
+def evaluate_predictions(y_test, y_pred):
+ return accuracy_score(y_test, y_pred)""",
+ "model.py": """import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+
+class IrisClassifier(nn.Module):
+ def __init__(self):
+ super(IrisClassifier, self).__init__()
+ self.fc = nn.Sequential(
+ nn.Linear(4, 16),
+ nn.ReLU(),
+ nn.Linear(16, 3)
+ )
+
+ def forward(self, x):
+ return self.fc(x)
+
+ def train_model(self, X_train, y_train, epochs, lr, batch_size):
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(self.parameters(), lr=lr)
+
+ # Create DataLoader for batches
+ dataset = TensorDataset(X_train, y_train)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+ for epoch in range(epochs):
+ for batch_X, batch_y in dataloader:
+ optimizer.zero_grad()
+ outputs = self(batch_X)
+ loss = criterion(outputs, batch_y)
+ loss.backward()
+ optimizer.step()
+
+ def predict(self, X_test):
+ with torch.no_grad():
+ outputs = self(X_test)
+ _, predicted = outputs.max(1)
+ return predicted.numpy()""",
+ }
+
+ prompt = """from utils import load_data, evaluate_predictions
+from model import IrisClassifier as Classifier
+
+def main():
+ # Model training and evaluation
+"""
+ file_path = "/home/test/works/proj01/main.py"
+
+ expected = """# utils.py
+import torch
+from sklearn import datasets
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from sklearn.metrics import accuracy_score
+
+def load_data():
+ iris = datasets.load_iris()
+ X = iris.data
+ y = iris.target
+
+ # Standardize the data
+ scaler = StandardScaler()
+ X = scaler.fit_transform(X)
+
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
+
+ # Convert numpy data to PyTorch tensors
+ X_train = torch.tensor(X_train, dtype=torch.float32)
+ X_test = torch.tensor(X_test, dtype=torch.float32)
+ y_train = torch.tensor(y_train, dtype=torch.int64)
+ y_test = torch.tensor(y_test, dtype=torch.int64)
+
+ return X_train, X_test, y_train, y_test
+
+def evaluate_predictions(y_test, y_pred):
+ return accuracy_score(y_test, y_pred)
+# model.py
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+
+class IrisClassifier(nn.Module):
+ def __init__(self):
+ super(IrisClassifier, self).__init__()
+ self.fc = nn.Sequential(
+ nn.Linear(4, 16),
+ nn.ReLU(),
+ nn.Linear(16, 3)
+ )
+
+ def forward(self, x):
+ return self.fc(x)
+
+ def train_model(self, X_train, y_train, epochs, lr, batch_size):
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(self.parameters(), lr=lr)
+
+ # Create DataLoader for batches
+ dataset = TensorDataset(X_train, y_train)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+ for epoch in range(epochs):
+ for batch_X, batch_y in dataloader:
+ optimizer.zero_grad()
+ outputs = self(batch_X)
+ loss = criterion(outputs, batch_y)
+ loss.backward()
+ optimizer.step()
+
+ def predict(self, X_test):
+ with torch.no_grad():
+ outputs = self(X_test)
+ _, predicted = outputs.max(1)
+ return predicted.numpy()
+# main.py
+from utils import load_data, evaluate_predictions
+from model import IrisClassifier as Classifier
+
+def main():
+ # Model training and evaluation
+"""
+ assert expected == CodeModelMixin._get_code_prompt(
+ "completion", prompt, code_prompt_style, file_path, None, None, files
+ )
+
+
+def test_code_prompt_style_codeqwen():
+ code_prompt_style = CodePromptStyleV1(
+ style_name="CODEQWEN",
+ fim_spec=FIMSpecV1(
+ style="PSM",
+ prefix="",
+ middle="",
+ suffix="",
+ ),
+ repo_level_spec=RepoLevelSpecV1(
+ repo_name="", file_type="filepath", file_separator=""
+ ),
+ )
+
+ prompt = "#write a quick sort algorithm"
+ expected = prompt
+
+ assert expected == CodeModelMixin._get_code_prompt(
+ "completion", prompt, code_prompt_style
+ )
+
+ prompt = """def quick_sort(arr):
+ if len(arr) <= 1:
+ return arr
+ pivot = arr[0]
+ left = []
+ right = []
+"""
+ suffix = """
+ if arr[i] < pivot:
+ left.append(arr[i])
+ else:
+ right.append(arr[i])
+ return quick_sort(left) + [pivot] + quick_sort(right)"""
+
+ expected = """def quick_sort(arr):
+ if len(arr) <= 1:
+ return arr
+ pivot = arr[0]
+ left = []
+ right = []
+
+ if arr[i] < pivot:
+ left.append(arr[i])
+ else:
+ right.append(arr[i])
+ return quick_sort(left) + [pivot] + quick_sort(right)"""
+
+ assert expected == CodeModelMixin._get_code_prompt(
+ "infill", prompt, code_prompt_style, None, suffix
+ )
+
+ files = {
+ "/home/test/works/proj01/utils.py": """import torch
+from sklearn import datasets
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from sklearn.metrics import accuracy_score
+
+def load_data():
+ iris = datasets.load_iris()
+ X = iris.data
+ y = iris.target
+
+ # Standardize the data
+ scaler = StandardScaler()
+ X = scaler.fit_transform(X)
+
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
+
+ # Convert numpy data to PyTorch tensors
+ X_train = torch.tensor(X_train, dtype=torch.float32)
+ X_test = torch.tensor(X_test, dtype=torch.float32)
+ y_train = torch.tensor(y_train, dtype=torch.int64)
+ y_test = torch.tensor(y_test, dtype=torch.int64)
+
+ return X_train, X_test, y_train, y_test
+
+def evaluate_predictions(y_test, y_pred):
+ return accuracy_score(y_test, y_pred)""",
+ "/home/test/works/proj01/model.py": """import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+
+class IrisClassifier(nn.Module):
+ def __init__(self):
+ super(IrisClassifier, self).__init__()
+ self.fc = nn.Sequential(
+ nn.Linear(4, 16),
+ nn.ReLU(),
+ nn.Linear(16, 3)
+ )
+
+ def forward(self, x):
+ return self.fc(x)
+
+ def train_model(self, X_train, y_train, epochs, lr, batch_size):
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(self.parameters(), lr=lr)
+
+ # Create DataLoader for batches
+ dataset = TensorDataset(X_train, y_train)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+ for epoch in range(epochs):
+ for batch_X, batch_y in dataloader:
+ optimizer.zero_grad()
+ outputs = self(batch_X)
+ loss = criterion(outputs, batch_y)
+ loss.backward()
+ optimizer.step()
+
+ def predict(self, X_test):
+ with torch.no_grad():
+ outputs = self(X_test)
+ _, predicted = outputs.max(1)
+ return predicted.numpy()""",
+ }
+
+ prompt = """from utils import load_data, evaluate_predictions
+from model import IrisClassifier as Classifier
+
+def main():
+ # Model training and evaluation
+"""
+ file_path = "/home/test/works/proj01/main.py"
+
+ expected = """project01
+/home/test/works/proj01/utils.py
+import torch
+from sklearn import datasets
+from sklearn.model_selection import train_test_split
+from sklearn.preprocessing import StandardScaler
+from sklearn.metrics import accuracy_score
+
+def load_data():
+ iris = datasets.load_iris()
+ X = iris.data
+ y = iris.target
+
+ # Standardize the data
+ scaler = StandardScaler()
+ X = scaler.fit_transform(X)
+
+ X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.3, random_state=42)
+
+ # Convert numpy data to PyTorch tensors
+ X_train = torch.tensor(X_train, dtype=torch.float32)
+ X_test = torch.tensor(X_test, dtype=torch.float32)
+ y_train = torch.tensor(y_train, dtype=torch.int64)
+ y_test = torch.tensor(y_test, dtype=torch.int64)
+
+ return X_train, X_test, y_train, y_test
+
+def evaluate_predictions(y_test, y_pred):
+ return accuracy_score(y_test, y_pred)
+/home/test/works/proj01/model.py
+import torch
+import torch.nn as nn
+import torch.optim as optim
+from torch.utils.data import DataLoader, TensorDataset
+
+class IrisClassifier(nn.Module):
+ def __init__(self):
+ super(IrisClassifier, self).__init__()
+ self.fc = nn.Sequential(
+ nn.Linear(4, 16),
+ nn.ReLU(),
+ nn.Linear(16, 3)
+ )
+
+ def forward(self, x):
+ return self.fc(x)
+
+ def train_model(self, X_train, y_train, epochs, lr, batch_size):
+ criterion = nn.CrossEntropyLoss()
+ optimizer = optim.Adam(self.parameters(), lr=lr)
+
+ # Create DataLoader for batches
+ dataset = TensorDataset(X_train, y_train)
+ dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
+
+ for epoch in range(epochs):
+ for batch_X, batch_y in dataloader:
+ optimizer.zero_grad()
+ outputs = self(batch_X)
+ loss = criterion(outputs, batch_y)
+ loss.backward()
+ optimizer.step()
+
+ def predict(self, X_test):
+ with torch.no_grad():
+ outputs = self(X_test)
+ _, predicted = outputs.max(1)
+ return predicted.numpy()
+/home/test/works/proj01/main.py
+from utils import load_data, evaluate_predictions
+from model import IrisClassifier as Classifier
+
+def main():
+ # Model training and evaluation
+"""
+ assert expected == CodeModelMixin._get_code_prompt(
+ "completion", prompt, code_prompt_style, file_path, None, "project01", files
+ )
+
+
+def test_code_prompt_style_without_fim():
+ code_prompt_style = CodePromptStyleV1(
+ style_name="NO_FIM_CODER",
+ )
+ prompt = "def print_hello_world():\n "
+ suffix = "\n print('Hello world!')"
+ with pytest.raises(ValueError) as exc_info:
+ CodeModelMixin._get_code_prompt(
+ "infill", prompt, code_prompt_style, None, suffix
+ )
+    assert str(exc_info.value) == (
+        "This model does not support infill mode generation"
+    )
diff --git a/xinference/model/llm/transformers/core.py b/xinference/model/llm/transformers/core.py
index 7b427c4918..e71bc78ea1 100644
--- a/xinference/model/llm/transformers/core.py
+++ b/xinference/model/llm/transformers/core.py
@@ -16,7 +16,7 @@
import logging
import os
from functools import lru_cache
-from typing import Iterable, Iterator, List, Optional, Tuple, Union
+from typing import Iterable, Iterator, List, Mapping, Optional, Tuple, Union
import torch
@@ -30,6 +30,7 @@
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
+ CodeGenerateMode,
Completion,
CompletionChoice,
CompletionChunk,
@@ -41,7 +42,7 @@
from ...utils import select_device
from ..core import LLM
from ..llm_family import LLMFamilyV1, LLMSpecV1
-from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin
+from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin, CodeModelMixin
from .utils import get_context_length, get_max_src_len, pad_prefill_tokens
logger = logging.getLogger(__name__)
@@ -809,3 +810,62 @@ def handle_batch_inference_results(self, req_list: List[InferenceRequest]):
req.completion = results
else:
req.completion[0] = self._to_chat_completion(req.completion[0])
+
+
+class PytorchCodeModel(PytorchModel, CodeModelMixin):
+ def __init__(
+ self,
+ model_uid: str,
+ model_family: "LLMFamilyV1",
+ model_spec: "LLMSpecV1",
+ quantization: str,
+ model_path: str,
+ pytorch_model_config: Optional[PytorchModelConfig] = None,
+ peft_model: Optional[List[LoRA]] = None,
+ ):
+ super().__init__(
+ model_uid,
+ model_family,
+ model_spec,
+ quantization,
+ model_path,
+ pytorch_model_config,
+ peft_model,
+ )
+
+ @classmethod
+ def match(
+ cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+ ) -> bool:
+ if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+ return False
+ model_family = llm_family.model_family or llm_family.model_name
+ if model_family in NON_DEFAULT_MODEL_LIST:
+ return False
+ if "code" not in llm_family.model_ability:
+ return False
+ return True
+
+ def code_generate(
+ self,
+ mode: CodeGenerateMode,
+ prompt: str,
+ file_path: Optional[str],
+ suffix: Optional[str],
+ repo_name: Optional[str],
+ files: Optional[Mapping[str, str]],
+ generate_config: Optional[PytorchGenerateConfig] = None,
+ ) -> Union[Completion, Iterator[CompletionChunk]]:
+ code_prompt = self.get_code_prompt(
+ mode,
+ prompt,
+ file_path,
+ suffix,
+ repo_name,
+ files,
+ )["prompt"]
+
+ if generate_config is not None and generate_config.get("stream", False):
+ generate_config["stream"] = False
+
+ return self.generate(code_prompt, generate_config)
diff --git a/xinference/model/llm/utils.py b/xinference/model/llm/utils.py
index 7f203d0c21..03259feb84 100644
--- a/xinference/model/llm/utils.py
+++ b/xinference/model/llm/utils.py
@@ -19,7 +19,7 @@
import time
import uuid
from io import BytesIO
-from typing import AsyncGenerator, Dict, Iterator, List, Optional, Tuple, cast
+from typing import AsyncGenerator, Dict, Iterator, List, Mapping, Optional, Tuple, cast
import requests
from PIL import Image
@@ -29,10 +29,14 @@
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
+ CodeGenerateMode,
Completion,
CompletionChunk,
)
+from .core import LLM
+from .lang_utils import get_file_separator
from .llm_family import (
+ CodePromptStyleV1,
LlamaCppLLMSpecV1,
LLMFamilyV1,
LLMSpecV1,
@@ -849,6 +853,131 @@ def get_full_prompt(cls, model_family, prompt, system_prompt, chat_history, tool
return full_prompt
+class CodeModelMixin:
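+    """Mixin that builds code prompts from a model family's ``code_prompt_style``.
+
+    Supports plain completion, fill-in-the-middle (FIM) infilling, and
+    repository-level completion with multiple context files.
+    """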
+ def get_code_prompt(
+ self,
+ mode: CodeGenerateMode,
+ prompt: str,
+ file_path: Optional[str] = None,
+ suffix: Optional[str] = None,
+ repo_name: Optional[str] = None,
+ files: Optional[Mapping[str, str]] = None,
+ ):
+ code_prompt_style = cast(LLM, self).model_family.code_prompt_style
+ return {
+ "prompt": CodeModelMixin._get_code_prompt(
+ mode, prompt, code_prompt_style, file_path, suffix, repo_name, files
+ )
+ }
+
+ @staticmethod
+ def _get_code_prompt(
+ mode: CodeGenerateMode,
+ prompt: str,
+ code_prompt_style: Optional["CodePromptStyleV1"],
+ file_path: Optional[str] = None,
+ suffix: Optional[str] = None,
+ repo_name: Optional[str] = None,
+ files: Optional[Mapping[str, str]] = None,
+ ) -> str:
+ if code_prompt_style is None:
+ raise ValueError(
+ "code prompt style is not provided, the model spec is wrong."
+ )
+
+ if mode == "completion":
+ if suffix is not None:
+ logger.warning(
+ "Suffix is only required on generate type is infill, ignored"
+ )
+
+ spec = code_prompt_style.repo_level_spec
+
+ if files is None or len(files) == 0:
+ if file_path is None or len(file_path.strip()) == 0:
+ return prompt
+ else:
+ if spec is None:
+ logger.warning(
+ "repository level file separator not defined, but file_path provided, ignored"
+ )
+ return prompt
+ else:
+ repo_file = (
+ file_path
+ if spec.file_type == "filepath"
+ else CodeModelMixin._path_to_name(file_path)
+ )
+                    return "\n".join([get_file_separator(spec, repo_file), prompt])
+
+ if spec is None:
+ logger.warning(
+ "The model does not support repository level code completion, 'repo_name' and 'files' are ignored"
+ )
+ return prompt
+
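+        # Repository-level completion: emit an optional repo-name line, then each
+        # context file as "<separator line>\n<content>", and finally the current
+        # file's separator followed by the user prompt.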
+ chunks = []
+ if spec.repo_name is None:
+ if repo_name is not None:
+ logger.warning(
+ "'repo_name' is provided but it will not be used in this model, ignored"
+ )
+ else:
+ if repo_name is None:
+ raise ValueError(
+ "The 'repo_name' is required for repository level code completion for this model"
+ )
+ chunks.append(f"{spec.repo_name}{repo_name}")
+
+ for filepath, content in files.items():
+ repo_file = (
+ filepath
+ if spec.file_type == "filepath"
+ else CodeModelMixin._path_to_name(filepath)
+ )
+ chunks.append(get_file_separator(spec, repo_file))
+ chunks.append(content)
+
+ if file_path is not None and len(file_path.strip()) > 0:
+ repo_file = (
+ file_path
+ if spec.file_type == "filepath"
+ else CodeModelMixin._path_to_name(file_path)
+ )
+ chunks.append(get_file_separator(spec, repo_file))
+ chunks.append(prompt)
+
+ return "\n".join(chunks)
+
+ elif mode == "infill":
+ spec = code_prompt_style.fim_spec
+ if spec is None:
+ raise ValueError("This model is not support infill mode generate")
+
+ if suffix is None:
+ raise ValueError("suffix is required in infill mode")
+
+ if files is not None and len(files) > 0:
+ logger.warning(
+ "files is only required in repository level code completion, ignored"
+ )
+
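+            # PSM: <prefix>{code before cursor}<suffix>{code after cursor}<middle>
+            # PMS: <prefix>{code before cursor}<middle>{code after cursor}<suffix>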
+ if spec.style == "PSM":
+ return f"{spec.prefix}{prompt}{spec.suffix}{suffix}{spec.middle}"
+ else:
+ return f"{spec.prefix}{prompt}{spec.middle}{suffix}{spec.suffix}"
+
+ else:
+ raise ValueError(
+ f"Unsupported generate mode: {mode}, only 'completion' and 'infill' are supported now"
+ )
+
+ @staticmethod
+ def _path_to_name(filepath: str) -> str:
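+        # Normalize Windows-style separators so the basename split works on both platforms.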
+ filepath = filepath.replace("\\", "/")
+ return os.path.split(filepath)[1]
+
+
def get_file_location(
llm_family: LLMFamilyV1, spec: LLMSpecV1, quantization: str
) -> Tuple[str, bool]:
diff --git a/xinference/model/llm/vllm/core.py b/xinference/model/llm/vllm/core.py
index 4b009aa646..fe2ac3e4e5 100644
--- a/xinference/model/llm/vllm/core.py
+++ b/xinference/model/llm/vllm/core.py
@@ -26,6 +26,7 @@
Dict,
Iterable,
List,
+ Mapping,
Optional,
TypedDict,
Union,
@@ -35,6 +36,7 @@
ChatCompletion,
ChatCompletionChunk,
ChatCompletionMessage,
+ CodeGenerateMode,
Completion,
CompletionChoice,
CompletionChunk,
@@ -45,7 +47,7 @@
)
from .. import LLM, LLMFamilyV1, LLMSpecV1
from ..llm_family import CustomLLMFamilyV1
-from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin
+from ..utils import QWEN_TOOL_CALL_FAMILY, ChatModelMixin, CodeModelMixin
logger = logging.getLogger(__name__)
@@ -130,6 +132,10 @@ class VLLMGenerateConfig(TypedDict, total=False):
"deepseek-chat",
"deepseek-coder-instruct",
]
+VLLM_SUPPORTED_CODE_MODELS = [
+ "deepseek-coder",
+ "codeqwen1.5",
+]
if VLLM_INSTALLED and vllm.__version__ >= "0.3.0":
VLLM_SUPPORTED_CHAT_MODELS.append("qwen1.5-chat")
VLLM_SUPPORTED_MODELS.append("codeqwen1.5")
@@ -743,3 +749,59 @@ async def async_chat(
c = await self.async_generate(inputs, generate_config)
assert not isinstance(c, AsyncGenerator)
return self._to_chat_completion(c)
+
+
+class VLLMCodeModel(VLLMModel, CodeModelMixin):
+ @classmethod
+ def match(
+ cls, llm_family: "LLMFamilyV1", llm_spec: "LLMSpecV1", quantization: str
+ ) -> bool:
+ if llm_spec.model_format not in ["pytorch", "gptq", "awq"]:
+ return False
+ if llm_spec.model_format == "pytorch":
+ if quantization != "none" and quantization is not None:
+ return False
+ if llm_spec.model_format == "awq":
+            # vLLM currently supports only 4-bit weight quantization for AWQ.
+ if "4" not in quantization:
+ return False
+ if llm_spec.model_format == "gptq":
+ if VLLM_INSTALLED and vllm.__version__ >= "0.3.3":
+ if not any(q in quantization for q in ("3", "4", "8")):
+ return False
+ else:
+ if "4" not in quantization:
+ return False
+ if isinstance(llm_family, CustomLLMFamilyV1):
+ if llm_family.model_family not in VLLM_SUPPORTED_CODE_MODELS:
+ return False
+ else:
+ if llm_family.model_name not in VLLM_SUPPORTED_CODE_MODELS:
+ return False
+ if "code" not in llm_family.model_ability:
+ return False
+ return VLLM_INSTALLED
+
+ async def async_code_generate(
+ self,
+ mode: CodeGenerateMode,
+ prompt: str,
+ file_path: Optional[str],
+ suffix: Optional[str],
+ repo_name: Optional[str],
+ files: Optional[Mapping[str, str]],
+ generate_config: Optional[Dict] = None,
+ ) -> Union[Completion, AsyncGenerator[CompletionChunk, None]]:
+ code_prompt = self.get_code_prompt(
+ mode,
+ prompt,
+ file_path,
+ suffix,
+ repo_name,
+ files,
+ )["prompt"]
+
+ if generate_config is not None and generate_config.get("stream", False):
+ generate_config["stream"] = False
+
+ return await self.async_generate(code_prompt, generate_config)
diff --git a/xinference/types.py b/xinference/types.py
index 3f636d94c3..b941b5a80c 100644
--- a/xinference/types.py
+++ b/xinference/types.py
@@ -14,7 +14,7 @@
from typing import Any, Callable, Dict, ForwardRef, Iterable, List, Optional, Union
-from typing_extensions import Literal, NotRequired, TypedDict
+from typing_extensions import Literal, Mapping, NotRequired, TypedDict
from ._compat import (
BaseModel,
@@ -533,3 +533,13 @@ def from_dict(cls, data: Dict):
image_lora_load_kwargs=data.get("image_lora_load_kwargs"),
image_lora_fuse_kwargs=data.get("image_lora_fuse_kwargs"),
)
+
+
+CodeGenerateMode = Literal["completion", "infill"]
+
+
+class CreateCodeCompletion(CreateCompletion):
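+    # Extra parameters on top of CreateCompletion: the current file path, the generation
+    # mode ("completion" or "infill"), and optional repository-level context (repo name, files).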
+ file_path: Optional[str] = none_field
+ mode: CodeGenerateMode = "completion"
+ repo_name: Optional[str] = none_field
+ files: Optional[Mapping[str, str]] = none_field
diff --git a/xinference/web/ui/src/scenes/launch_model/modelCard.js b/xinference/web/ui/src/scenes/launch_model/modelCard.js
index a4391f830a..790227f081 100644
--- a/xinference/web/ui/src/scenes/launch_model/modelCard.js
+++ b/xinference/web/ui/src/scenes/launch_model/modelCard.js
@@ -10,6 +10,7 @@ import {
ExpandMore,
Grade,
HelpCenterOutlined,
+ LogoDevOutlined,
RocketLaunchOutlined,
StarBorder,
UndoOutlined,
@@ -892,6 +893,16 @@ const ModelCard = ({
chat model
)
+ } else if (
+ modelData.model_ability &&
+ modelData.model_ability.includes('code')
+ ) {
+ return (
+
+
+ code model
+
+ )
} else if (
modelData.model_ability &&
modelData.model_ability.includes('generate')
diff --git a/xinference/web/ui/src/scenes/running_models/index.js b/xinference/web/ui/src/scenes/running_models/index.js
index 4024c8f6d2..e3afd1b9e8 100644
--- a/xinference/web/ui/src/scenes/running_models/index.js
+++ b/xinference/web/ui/src/scenes/running_models/index.js
@@ -38,6 +38,63 @@ const RunningModels = () => {
sessionStorage.setItem('runningModelType', newValue)
}
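+  // Fetch running models and, for LLMs with a builtin code prompt style, flag whether
+  // infill (fim_spec) and repository-level (repo_level_spec) completion are supported.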
+  function getModels(code_prompts) {
+ fetchWrapper
+ .get('/v1/models')
+ .then((response) => {
+ const newLlmData = []
+ const newEmbeddingModelData = []
+ const newImageModelData = []
+ const newAudioModelData = []
+ const newVideoModelData = []
+ const newRerankModelData = []
+ const newFlexibleModelData = []
+ response.data.forEach((model) => {
+ let newValue = {
+ ...model,
+ id: model.id,
+ url: model.id,
+ }
+ if (newValue.model_type === 'LLM') {
+ if (model.model_name in code_prompts) {
+ newValue['infill_supported'] =
+ 'fim_spec' in code_prompts[model.model_name]
+ newValue['repo_level_supported'] =
+ 'repo_level_spec' in code_prompts[model.model_name]
+ }
+ newLlmData.push(newValue)
+ } else if (newValue.model_type === 'embedding') {
+ newEmbeddingModelData.push(newValue)
+ } else if (newValue.model_type === 'audio') {
+ newAudioModelData.push(newValue)
+ } else if (newValue.model_type === 'video') {
+ newVideoModelData.push(newValue)
+ } else if (newValue.model_type === 'image') {
+ newImageModelData.push(newValue)
+ } else if (newValue.model_type === 'rerank') {
+ newRerankModelData.push(newValue)
+ } else if (newValue.model_type === 'flexible') {
+ newFlexibleModelData.push(newValue)
+ }
+ })
+ setLlmData(newLlmData)
+ setEmbeddingModelData(newEmbeddingModelData)
+ setAudioModelData(newAudioModelData)
+ setVideoModelData(newVideoModelData)
+ setImageModelData(newImageModelData)
+ setRerankModelData(newRerankModelData)
+ setFlexibleModelData(newFlexibleModelData)
+ setIsUpdatingModel(false)
+ })
+ .catch((error) => {
+ console.error('Error:', error)
+ setIsUpdatingModel(false)
+ if (error.response.status !== 403 && error.response.status !== 401) {
+ setErrorMsg(error.message)
+ }
+ })
+ }
+
const update = (isCallingApi) => {
if (
sessionStorage.getItem('auth') === 'true' &&
@@ -71,45 +128,9 @@ const RunningModels = () => {
setIsUpdatingModel(true)
fetchWrapper
- .get('/v1/models')
- .then((response) => {
- const newLlmData = []
- const newEmbeddingModelData = []
- const newImageModelData = []
- const newAudioModelData = []
- const newVideoModelData = []
- const newRerankModelData = []
- const newFlexibleModelData = []
- response.data.forEach((model) => {
- let newValue = {
- ...model,
- id: model.id,
- url: model.id,
- }
- if (newValue.model_type === 'LLM') {
- newLlmData.push(newValue)
- } else if (newValue.model_type === 'embedding') {
- newEmbeddingModelData.push(newValue)
- } else if (newValue.model_type === 'audio') {
- newAudioModelData.push(newValue)
- } else if (newValue.model_type === 'video') {
- newVideoModelData.push(newValue)
- } else if (newValue.model_type === 'image') {
- newImageModelData.push(newValue)
- } else if (newValue.model_type === 'rerank') {
- newRerankModelData.push(newValue)
- } else if (newValue.model_type === 'flexible') {
- newFlexibleModelData.push(newValue)
- }
- })
- setLlmData(newLlmData)
- setEmbeddingModelData(newEmbeddingModelData)
- setAudioModelData(newAudioModelData)
- setVideoModelData(newVideoModelData)
- setImageModelData(newImageModelData)
- setRerankModelData(newRerankModelData)
- setFlexibleModelData(newFlexibleModelData)
- setIsUpdatingModel(false)
+ .get('/v1/models/code_prompts')
+ .then((code_prompts) => {
+          getModels(code_prompts)
})
.catch((error) => {
console.error('Error:', error)
@@ -228,6 +249,8 @@ const RunningModels = () => {
model_ability: row.model_ability,
model_description: row.model_description,
model_lang: row.model_lang,
+ infill_supported: row.infill_supported,
+ repo_level_supported: row.repo_level_supported,
}),
})
.then((response) => response.json())