Skip to content

Commit

Permalink
mlx-server support
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed Jun 30, 2024
1 parent 1260ecf commit 6e63e84
Show file tree
Hide file tree
Showing 5 changed files with 178 additions and 89 deletions.
192 changes: 104 additions & 88 deletions config/config_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ import (
"github.com/adrianliechti/llama/pkg/provider/llama"
"github.com/adrianliechti/llama/pkg/provider/mimic"
"github.com/adrianliechti/llama/pkg/provider/mistral"
"github.com/adrianliechti/llama/pkg/provider/mlx"
"github.com/adrianliechti/llama/pkg/provider/ollama"
"github.com/adrianliechti/llama/pkg/provider/openai"
"github.com/adrianliechti/llama/pkg/provider/whisper"
Expand Down Expand Up @@ -142,6 +143,21 @@ func createProvider(cfg providerConfig, model string) (any, error) {
case "automatic1111":
return automatic1111Provider(cfg, model)

case "azure-translator":
return azuretranslatorProvider(cfg, model)

case "coqui":
return coquiProvider(cfg)

case "custom":
return customProvider(cfg, model)

case "deepl":
return deeplProvider(cfg, model)

case "groq":
return groqProvider(cfg, model)

case "huggingface":
return huggingfaceProvider(cfg, model)

Expand All @@ -151,36 +167,24 @@ func createProvider(cfg providerConfig, model string) (any, error) {
case "llama":
return llamaProvider(cfg, model)

case "ollama":
return ollamaProvider(cfg, model)

case "openai":
return openaiProvider(cfg, model)
case "mimic":
return mimicProvider(cfg)

case "mistral":
return mistralProvider(cfg, model)

case "groq":
return groqProvider(cfg, model)
case "mlx":
return mlxProvider(cfg, model)

case "coqui":
return coquiProvider(cfg)
case "ollama":
return ollamaProvider(cfg, model)

case "mimic":
return mimicProvider(cfg)
case "openai":
return openaiProvider(cfg, model)

case "whisper":
return whisperProvider(cfg)

case "azure-translator":
return azuretranslatorProvider(cfg, model)

case "deepl":
return deeplProvider(cfg, model)

case "custom":
return customProvider(cfg, model)

default:
return nil, errors.New("invalid provider type: " + cfg.Type)
}
Expand Down Expand Up @@ -214,6 +218,60 @@ func automatic1111Provider(cfg providerConfig, model string) (*automatic1111.Cli
return automatic1111.New(options...)
}

func azuretranslatorProvider(cfg providerConfig, model string) (*azuretranslator.Client, error) {
var options []azuretranslator.Option

if cfg.Token != "" {
options = append(options, azuretranslator.WithToken(cfg.Token))
}

if model != "" {
options = append(options, azuretranslator.WithLanguage(model))
}

return azuretranslator.New(cfg.URL, options...)
}

func coquiProvider(cfg providerConfig) (*coqui.Client, error) {
var options []coqui.Option

return coqui.New(cfg.URL, options...)
}

func customProvider(cfg providerConfig, model string) (*custom.Client, error) {
var options []custom.Option

return custom.New(cfg.URL, options...)
}

func deeplProvider(cfg providerConfig, model string) (*deepl.Client, error) {
var options []deepl.Option

if cfg.Token != "" {
options = append(options, deepl.WithToken(cfg.Token))
}

if model != "" {
options = append(options, deepl.WithLanguage(model))
}

return deepl.New(cfg.URL, options...)
}

func groqProvider(cfg providerConfig, model string) (*groq.Client, error) {
var options []groq.Option

if cfg.Token != "" {
options = append(options, groq.WithToken(cfg.Token))
}

if model != "" {
options = append(options, groq.WithModel(model))
}

return groq.New(options...)
}

func huggingfaceProvider(cfg providerConfig, model string) (*huggingface.Client, error) {
var options []huggingface.Option

Expand Down Expand Up @@ -244,32 +302,10 @@ func llamaProvider(cfg providerConfig, model string) (*llama.Client, error) {
return llama.New(cfg.URL, options...)
}

func ollamaProvider(cfg providerConfig, model string) (*ollama.Client, error) {
var options []ollama.Option

if model != "" {
options = append(options, ollama.WithModel(model))
}

return ollama.New(cfg.URL, options...)
}

func openaiProvider(cfg providerConfig, model string) (*openai.Client, error) {
var options []openai.Option

if cfg.URL != "" {
options = append(options, openai.WithURL(cfg.URL))
}

if cfg.Token != "" {
options = append(options, openai.WithToken(cfg.Token))
}

if model != "" {
options = append(options, openai.WithModel(model))
}
func mimicProvider(cfg providerConfig) (*mimic.Client, error) {
var options []mimic.Option

return openai.New(options...)
return mimic.New(cfg.URL, options...)
}

func mistralProvider(cfg providerConfig, model string) (*mistral.Client, error) {
Expand All @@ -286,70 +322,50 @@ func mistralProvider(cfg providerConfig, model string) (*mistral.Client, error)
return mistral.New(options...)
}

func groqProvider(cfg providerConfig, model string) (*groq.Client, error) {
var options []groq.Option

if cfg.Token != "" {
options = append(options, groq.WithToken(cfg.Token))
}

if model != "" {
options = append(options, groq.WithModel(model))
}

return groq.New(options...)
}

func coquiProvider(cfg providerConfig) (*coqui.Client, error) {
var options []coqui.Option

return coqui.New(cfg.URL, options...)
}

func mimicProvider(cfg providerConfig) (*mimic.Client, error) {
var options []mimic.Option
func mlxProvider(cfg providerConfig, model string) (*mlx.Client, error) {
var options []mlx.Option

return mimic.New(cfg.URL, options...)
}
// if model != "" {
// options = append(options, mlx.WithModel(model))
// }

func whisperProvider(cfg providerConfig) (*whisper.Client, error) {
var options []whisper.Option
options = append(options, mlx.WithModel("default_model"))

return whisper.New(cfg.URL, options...)
return mlx.New(cfg.URL, options...)
}

func azuretranslatorProvider(cfg providerConfig, model string) (*azuretranslator.Client, error) {
var options []azuretranslator.Option

if cfg.Token != "" {
options = append(options, azuretranslator.WithToken(cfg.Token))
}
func ollamaProvider(cfg providerConfig, model string) (*ollama.Client, error) {
var options []ollama.Option

if model != "" {
options = append(options, azuretranslator.WithLanguage(model))
options = append(options, ollama.WithModel(model))
}

return azuretranslator.New(cfg.URL, options...)
return ollama.New(cfg.URL, options...)
}

func deeplProvider(cfg providerConfig, model string) (*deepl.Client, error) {
var options []deepl.Option
func openaiProvider(cfg providerConfig, model string) (*openai.Client, error) {
var options []openai.Option

if cfg.URL != "" {
options = append(options, openai.WithURL(cfg.URL))
}

if cfg.Token != "" {
options = append(options, deepl.WithToken(cfg.Token))
options = append(options, openai.WithToken(cfg.Token))
}

if model != "" {
options = append(options, deepl.WithLanguage(model))
options = append(options, openai.WithModel(model))
}

return deepl.New(cfg.URL, options...)
return openai.New(options...)
}

func customProvider(cfg providerConfig, model string) (*custom.Client, error) {
var options []custom.Option
func whisperProvider(cfg providerConfig) (*whisper.Client, error) {
var options []whisper.Option

return custom.New(cfg.URL, options...)
return whisper.New(cfg.URL, options...)
}

func createCompleterAdapter(name string, completer provider.Completer) (adapter.Provider, error) {
Expand Down
17 changes: 17 additions & 0 deletions pkg/provider/mlx/client.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
package mlx

type Client struct {
*Completer
}

func New(url string, options ...Option) (*Client, error) {
c, err := NewCompleter(url, options...)

if err != nil {
return nil, err
}

return &Client{
Completer: c,
}, nil
}
31 changes: 31 additions & 0 deletions pkg/provider/mlx/completer.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package mlx

import (
"errors"
"strings"

"github.com/adrianliechti/llama/pkg/provider/openai"
)

type Completer = openai.Completer

func NewCompleter(url string, options ...Option) (*Completer, error) {
if url == "" {
return nil, errors.New("url is required")
}

url = strings.TrimRight(url, "/")
url = strings.TrimSuffix(url, "/v1")

c := &Config{
options: []openai.Option{
openai.WithURL(url + "/v1"),
},
}

for _, option := range options {
option(c)
}

return openai.NewCompleter(c.options...)
}
25 changes: 25 additions & 0 deletions pkg/provider/mlx/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
package mlx

import (
"net/http"

"github.com/adrianliechti/llama/pkg/provider/openai"
)

type Config struct {
options []openai.Option
}

type Option func(*Config)

func WithClient(client *http.Client) Option {
return func(c *Config) {
c.options = append(c.options, openai.WithClient(client))
}
}

func WithModel(model string) Option {
return func(c *Config) {
c.options = append(c.options, openai.WithModel(model))
}
}
2 changes: 1 addition & 1 deletion pkg/provider/openai/completer.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ func (c *Completer) Complete(ctx context.Context, messages []provider.Message, o
stream, err := c.client.CreateChatCompletionStream(ctx, *req)

if err != nil {
convertError(err)
return nil, err
}

result := provider.Completion{
Expand Down

0 comments on commit 6e63e84

Please sign in to comment.