diff --git a/config/config_provider.go b/config/config_provider.go index 3d34bf9..c344eb8 100644 --- a/config/config_provider.go +++ b/config/config_provider.go @@ -17,6 +17,7 @@ import ( "github.com/adrianliechti/llama/pkg/provider/llama" "github.com/adrianliechti/llama/pkg/provider/mimic" "github.com/adrianliechti/llama/pkg/provider/mistral" + "github.com/adrianliechti/llama/pkg/provider/mlx" "github.com/adrianliechti/llama/pkg/provider/ollama" "github.com/adrianliechti/llama/pkg/provider/openai" "github.com/adrianliechti/llama/pkg/provider/whisper" @@ -142,6 +143,21 @@ func createProvider(cfg providerConfig, model string) (any, error) { case "automatic1111": return automatic1111Provider(cfg, model) + case "azure-translator": + return azuretranslatorProvider(cfg, model) + + case "coqui": + return coquiProvider(cfg) + + case "custom": + return customProvider(cfg, model) + + case "deepl": + return deeplProvider(cfg, model) + + case "groq": + return groqProvider(cfg, model) + case "huggingface": return huggingfaceProvider(cfg, model) @@ -151,36 +167,24 @@ func createProvider(cfg providerConfig, model string) (any, error) { case "llama": return llamaProvider(cfg, model) - case "ollama": - return ollamaProvider(cfg, model) - - case "openai": - return openaiProvider(cfg, model) + case "mimic": + return mimicProvider(cfg) case "mistral": return mistralProvider(cfg, model) - case "groq": - return groqProvider(cfg, model) + case "mlx": + return mlxProvider(cfg, model) - case "coqui": - return coquiProvider(cfg) + case "ollama": + return ollamaProvider(cfg, model) - case "mimic": - return mimicProvider(cfg) + case "openai": + return openaiProvider(cfg, model) case "whisper": return whisperProvider(cfg) - case "azure-translator": - return azuretranslatorProvider(cfg, model) - - case "deepl": - return deeplProvider(cfg, model) - - case "custom": - return customProvider(cfg, model) - default: return nil, errors.New("invalid provider type: " + cfg.Type) } @@ -214,6 +218,60 @@ func automatic1111Provider(cfg providerConfig, model string) (*automatic1111.Cli return automatic1111.New(options...) } +func azuretranslatorProvider(cfg providerConfig, model string) (*azuretranslator.Client, error) { + var options []azuretranslator.Option + + if cfg.Token != "" { + options = append(options, azuretranslator.WithToken(cfg.Token)) + } + + if model != "" { + options = append(options, azuretranslator.WithLanguage(model)) + } + + return azuretranslator.New(cfg.URL, options...) +} + +func coquiProvider(cfg providerConfig) (*coqui.Client, error) { + var options []coqui.Option + + return coqui.New(cfg.URL, options...) +} + +func customProvider(cfg providerConfig, model string) (*custom.Client, error) { + var options []custom.Option + + return custom.New(cfg.URL, options...) +} + +func deeplProvider(cfg providerConfig, model string) (*deepl.Client, error) { + var options []deepl.Option + + if cfg.Token != "" { + options = append(options, deepl.WithToken(cfg.Token)) + } + + if model != "" { + options = append(options, deepl.WithLanguage(model)) + } + + return deepl.New(cfg.URL, options...) +} + +func groqProvider(cfg providerConfig, model string) (*groq.Client, error) { + var options []groq.Option + + if cfg.Token != "" { + options = append(options, groq.WithToken(cfg.Token)) + } + + if model != "" { + options = append(options, groq.WithModel(model)) + } + + return groq.New(options...) +} + func huggingfaceProvider(cfg providerConfig, model string) (*huggingface.Client, error) { var options []huggingface.Option @@ -244,32 +302,10 @@ func llamaProvider(cfg providerConfig, model string) (*llama.Client, error) { return llama.New(cfg.URL, options...) } -func ollamaProvider(cfg providerConfig, model string) (*ollama.Client, error) { - var options []ollama.Option - - if model != "" { - options = append(options, ollama.WithModel(model)) - } - - return ollama.New(cfg.URL, options...) -} - -func openaiProvider(cfg providerConfig, model string) (*openai.Client, error) { - var options []openai.Option - - if cfg.URL != "" { - options = append(options, openai.WithURL(cfg.URL)) - } - - if cfg.Token != "" { - options = append(options, openai.WithToken(cfg.Token)) - } - - if model != "" { - options = append(options, openai.WithModel(model)) - } +func mimicProvider(cfg providerConfig) (*mimic.Client, error) { + var options []mimic.Option - return openai.New(options...) + return mimic.New(cfg.URL, options...) } func mistralProvider(cfg providerConfig, model string) (*mistral.Client, error) { @@ -286,70 +322,50 @@ func mistralProvider(cfg providerConfig, model string) (*mistral.Client, error) return mistral.New(options...) } -func groqProvider(cfg providerConfig, model string) (*groq.Client, error) { - var options []groq.Option - - if cfg.Token != "" { - options = append(options, groq.WithToken(cfg.Token)) - } - - if model != "" { - options = append(options, groq.WithModel(model)) - } - - return groq.New(options...) -} - -func coquiProvider(cfg providerConfig) (*coqui.Client, error) { - var options []coqui.Option - - return coqui.New(cfg.URL, options...) -} - -func mimicProvider(cfg providerConfig) (*mimic.Client, error) { - var options []mimic.Option +func mlxProvider(cfg providerConfig, model string) (*mlx.Client, error) { + var options []mlx.Option - return mimic.New(cfg.URL, options...) -} + // if model != "" { + // options = append(options, mlx.WithModel(model)) + // } -func whisperProvider(cfg providerConfig) (*whisper.Client, error) { - var options []whisper.Option + options = append(options, mlx.WithModel("default_model")) - return whisper.New(cfg.URL, options...) + return mlx.New(cfg.URL, options...) } -func azuretranslatorProvider(cfg providerConfig, model string) (*azuretranslator.Client, error) { - var options []azuretranslator.Option - - if cfg.Token != "" { - options = append(options, azuretranslator.WithToken(cfg.Token)) - } +func ollamaProvider(cfg providerConfig, model string) (*ollama.Client, error) { + var options []ollama.Option if model != "" { - options = append(options, azuretranslator.WithLanguage(model)) + options = append(options, ollama.WithModel(model)) } - return azuretranslator.New(cfg.URL, options...) + return ollama.New(cfg.URL, options...) } -func deeplProvider(cfg providerConfig, model string) (*deepl.Client, error) { - var options []deepl.Option +func openaiProvider(cfg providerConfig, model string) (*openai.Client, error) { + var options []openai.Option + + if cfg.URL != "" { + options = append(options, openai.WithURL(cfg.URL)) + } if cfg.Token != "" { - options = append(options, deepl.WithToken(cfg.Token)) + options = append(options, openai.WithToken(cfg.Token)) } if model != "" { - options = append(options, deepl.WithLanguage(model)) + options = append(options, openai.WithModel(model)) } - return deepl.New(cfg.URL, options...) + return openai.New(options...) } -func customProvider(cfg providerConfig, model string) (*custom.Client, error) { - var options []custom.Option +func whisperProvider(cfg providerConfig) (*whisper.Client, error) { + var options []whisper.Option - return custom.New(cfg.URL, options...) + return whisper.New(cfg.URL, options...) } func createCompleterAdapter(name string, completer provider.Completer) (adapter.Provider, error) { diff --git a/pkg/provider/mlx/client.go b/pkg/provider/mlx/client.go new file mode 100644 index 0000000..8703499 --- /dev/null +++ b/pkg/provider/mlx/client.go @@ -0,0 +1,17 @@ +package mlx + +type Client struct { + *Completer +} + +func New(url string, options ...Option) (*Client, error) { + c, err := NewCompleter(url, options...) + + if err != nil { + return nil, err + } + + return &Client{ + Completer: c, + }, nil +} diff --git a/pkg/provider/mlx/completer.go b/pkg/provider/mlx/completer.go new file mode 100644 index 0000000..b084360 --- /dev/null +++ b/pkg/provider/mlx/completer.go @@ -0,0 +1,31 @@ +package mlx + +import ( + "errors" + "strings" + + "github.com/adrianliechti/llama/pkg/provider/openai" +) + +type Completer = openai.Completer + +func NewCompleter(url string, options ...Option) (*Completer, error) { + if url == "" { + return nil, errors.New("url is required") + } + + url = strings.TrimRight(url, "/") + url = strings.TrimSuffix(url, "/v1") + + c := &Config{ + options: []openai.Option{ + openai.WithURL(url + "/v1"), + }, + } + + for _, option := range options { + option(c) + } + + return openai.NewCompleter(c.options...) +} diff --git a/pkg/provider/mlx/config.go b/pkg/provider/mlx/config.go new file mode 100644 index 0000000..c46a36a --- /dev/null +++ b/pkg/provider/mlx/config.go @@ -0,0 +1,25 @@ +package mlx + +import ( + "net/http" + + "github.com/adrianliechti/llama/pkg/provider/openai" +) + +type Config struct { + options []openai.Option +} + +type Option func(*Config) + +func WithClient(client *http.Client) Option { + return func(c *Config) { + c.options = append(c.options, openai.WithClient(client)) + } +} + +func WithModel(model string) Option { + return func(c *Config) { + c.options = append(c.options, openai.WithModel(model)) + } +} diff --git a/pkg/provider/openai/completer.go b/pkg/provider/openai/completer.go index 304afd2..10109e5 100644 --- a/pkg/provider/openai/completer.go +++ b/pkg/provider/openai/completer.go @@ -71,7 +71,7 @@ func (c *Completer) Complete(ctx context.Context, messages []provider.Message, o stream, err := c.client.CreateChatCompletionStream(ctx, *req) if err != nil { - convertError(err) + return nil, err } result := provider.Completion{