Mistral support #1

Merged
merged 10 commits on Mar 16, 2025
2 changes: 1 addition & 1 deletion LICENSE
@@ -1,6 +1,6 @@
MIT License

Copyright (c) 2024 Jelle Smet
Copyright (c) 2025 Jelle Smet

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
36 changes: 22 additions & 14 deletions README.md
@@ -52,6 +52,13 @@ backends:
model: "gpt-4o"
max_tokens: 4096
system: You are a helpful assistant.

mistral:
default:
token: ${{MISTRAL_OPENAI_TOKEN}}
max_tokens: 4096
model: mistral-small-latest
system: You are a helpful assistant.
```

### Backends
@@ -112,20 +119,6 @@ $ echo $?
0
```

### no-prose

Enabling `--noprose` is useful to make sure the LLM does not add any unwanted
explanation or formatting to the response:

```
cat test.md | clai --prompt "Convert this markdown content to asciidoc."
= Title

== Chapter 1

This is chapter 1
```

## Backends

### OpenAI
@@ -155,3 +148,18 @@ https://learn.microsoft.com/en-us/azure/ai-services/openai/
- `max_tokens`: The maximum number of tokens the prompt is allowed to generate
- `system`: The system prompt
- `temperature`: The prompt temperature (default: 0)

### Mistral

https://mistral.ai/

The following parameters are supported:

- `token`: The Mistral API token
- `max_tokens`: The maximum number of tokens the prompt is allowed to
  generate. (Caveat: the token count is only an approximation, since there is
  no tiktoken equivalent for Mistral that can tokenize locally without external
  dependencies; see the sketch after this list.)
- `model`: The model to use
- `system`: The system prompt
- `temperature`: The prompt temperature (default: 0)
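
As a rough illustration of how these parameters map onto the new backend, here is a minimal sketch that calls `clai.backend.mistral.prompt()` directly. It assumes `MISTRAL_OPENAI_TOKEN` is exported (matching the config example above) and passes an empty callable for `stdin`; the actual call site in `clai/__init__.py` differs.

```
# Sketch only: calling the new Mistral backend with the parameters documented above.
# max_tokens is enforced client-side with a rough regex-based token count
# (hence the caveat), not with an exact Mistral tokenizer.
import os

from clai.backend.mistral import prompt

answer = prompt(
    token=os.environ["MISTRAL_OPENAI_TOKEN"],
    max_tokens=4096,
    model="mistral-small-latest",
    system="You are a helpful assistant.",
    prompts=["Summarize the clai project in one sentence."],
    stdin=lambda: [],  # no piped input in this sketch
    temperature=0,
)
print(answer)
```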
51 changes: 13 additions & 38 deletions clai/__init__.py
@@ -2,7 +2,6 @@

import importlib
import sys
from clai.prompts import BOOL_PROMPT, CLOSING_PROMPT, NOPROSE_PROMPT
from clai.tools import (
get_backend_instance_config,
parse_arguments,
@@ -14,67 +13,42 @@


def process_prompt(
prompt: str, bool_prompt: bool, no_prose_prompt, backend, backend_config
prompt: str, bool_prompt: bool, debug: bool, backend, backend_config
) -> None:
"""
Main function to execute the prompt.

Args:
prompt: The user provided prompt for the LLM to process.
bool_prompt: Append `BOOL_PROMPT` content to `prompt`.
no_prose_prompt: Append `NOPROSE_PROMPT` content to `prompt`.
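debug: Have the backend print the constructed messages for troubleshooting.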
backend: The LLM backend function name to import and execute.
backend_config: The `backend` function variables.
"""
prompts = []
response_format = None
prompts = [prompt]

if bool_prompt:
response_format = {
"type": "json_schema",
"json_schema": {
"name": "true_false",
"schema": {
"type": "object",
"properties": {
"answer": {
"type": "boolean",
},
"reason": {"type": "string"},
},
"required": ["answer", "reason"],
"additionalProperties": False,
},
"strict": True,
},
}
prompts.append(cleanup(BOOL_PROMPT))

elif no_prose_prompt:
prompts.append(cleanup(NOPROSE_PROMPT))

prompts.append(cleanup(CLOSING_PROMPT))
prompts.append(prompt)

backend_func = getattr(importlib.import_module("clai.backends"), backend)
backend_func = importlib.import_module(f"clai.backend.{backend}").prompt

try:
response = backend_func(
**backend_config
| {
"prompts": prompts,
"stdin": read_stdin,
"response_format": response_format,
"debug": debug,
"bool_prompt": bool_prompt,
}
)
except Exception as err:
print(f"Failed to process prompt. Reason: f{err}", file=sys.stderr)
sys.exit(1)

print(response, file=sys.stdout)

if bool_prompt:
sys.exit(get_exit_code(response))
exit_code = get_exit_code(response)
else:
exit_code = 0

print(response, file=sys.stdout)
sys.exit(exit_code)


def main():
@@ -88,10 +62,11 @@ def main():
process_prompt(
prompt=args.prompt,
bool_prompt=args.bool,
no_prose_prompt=args.no_prose,
debug=args.debug,
backend=args.backend,
backend_config=backend_config,
)

except Exception as err:
print(f"Failed to execute command. Reason: {err}")
sys.exit(1)
1 change: 1 addition & 0 deletions clai/backend/__init__.py
@@ -0,0 +1 @@
SUPPORTED_BACKENDS = ["azure_openai", "openai", "mistral"]
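
For context, `process_prompt()` above now resolves the backend module dynamically from this list. A minimal sketch of that dispatch follows; the explicit membership check is added here for illustration, while the actual code imports the module directly.

```
# Sketch of the dynamic backend dispatch used in clai/__init__.py.
# Assumes every name in SUPPORTED_BACKENDS maps to a module under clai.backend
# exposing a prompt() function, as azure_openai.py and mistral.py below do.
import importlib

from clai.backend import SUPPORTED_BACKENDS


def load_backend(name: str):
    if name not in SUPPORTED_BACKENDS:
        raise ValueError(f"Unsupported backend: {name}")
    return importlib.import_module(f"clai.backend.{name}").prompt
```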
68 changes: 68 additions & 0 deletions clai/backend/azure_openai.py
@@ -0,0 +1,68 @@
#
# azure_openai.py
#

import tiktoken
from openai import AzureOpenAI
from clai.backend.openai import (
ValidateTokenLength,
build_messages,
RESPONSE_FORMAT,
)
from clai.prompts import AZURE_BOOL_PROMPT

from pydantic import BaseModel


class BoolResponseModel(BaseModel):
reason: str
answer: bool


def prompt(
endpoint,
api_version,
token,
max_tokens,
model,
system,
prompts,
stdin,
temperature=0,
base_model=None,
bool_prompt=False,
debug=False,
):
client = AzureOpenAI(
api_key=token,
azure_endpoint=endpoint,
api_version=api_version,
)

if base_model is None:
token_model = model
else:
token_model = base_model

if bool_prompt:
system = AZURE_BOOL_PROMPT

messages = build_messages(max_tokens, token_model, system, prompts, stdin)

if debug:
print(messages)

if bool_prompt:
response = client.chat.completions.create(
messages=messages,
model=model,
temperature=temperature,
response_format={"type": "json_object"},
)
else:
response = client.chat.completions.create(
messages=messages,
model=model,
temperature=temperature,
)
return response.choices[0].message.content
109 changes: 109 additions & 0 deletions clai/backend/mistral.py
@@ -0,0 +1,109 @@
# MIT License
#
# Copyright (c) 2025 Jelle Smet
#
# This software is released under the MIT License.
# See the LICENSE file in the project root for more information.
#
# mistral.py
#
from mistralai import Mistral
import re
from clai.prompts import BOOL_PROMPT

from pydantic import BaseModel


class BoolResponseModel(BaseModel):
reason: str
answer: bool


class ValidateTokenLength:
def __init__(self, model, max_tokens):
self.tokenizer = re.compile(
r"\b\w+\b|[^\w\s]"
).findall # very rudimentary tokenizer
self.total_tokens = 0
self.max_tokens = max_tokens

def add(self, data):
self.total_tokens += len(self.tokenizer(data))
if self.total_tokens > self.max_tokens:
raise Exception(f"Total input of {self.total_tokens} tokens exceeds the {self.max_tokens} token limit.")


def build_messages(max_tokens, model, system, prompts, stdin):
messages = []

vtl = ValidateTokenLength(model=model, max_tokens=max_tokens)

vtl.add(system)
messages.append({"role": "system", "content": system.lstrip().rstrip()})

for prompt in prompts:
vtl.add(prompt)
messages.append({"role": "user", "content": prompt.lstrip().rstrip()})

stdin_content = []
for line in stdin():
vtl.add(line)
stdin_content.append(line.lstrip().rstrip())

if len(stdin_content) > 0:
messages.append(
{"role": "user", "content": "".join(stdin_content)},
)

return messages


def prompt(
token,
max_tokens,
model,
system,
prompts,
stdin,
temperature=0,
bool_prompt=False,
debug=False,
):
client = Mistral(api_key=token)

if bool_prompt:
messages = build_messages(
max_tokens=max_tokens,
model=model,
system=BOOL_PROMPT,
prompts=prompts,
stdin=stdin,
)
else:
messages = build_messages(
max_tokens=max_tokens,
model=model,
system=system,
prompts=prompts,
stdin=stdin,
)

if debug:
print(messages)

if bool_prompt:
chat_response = client.chat.parse(
model=model,
messages=messages,
temperature=temperature,
response_format=BoolResponseModel,
)
else:
chat_response = client.chat.complete(
model=model,
messages=messages,
temperature=temperature,
)

return chat_response.choices[0].message.content