From ddb0efc3833ef33388d4ae1c965f33193d0df2f8 Mon Sep 17 00:00:00 2001 From: Adrian Liechti Date: Wed, 8 May 2024 10:43:48 +0200 Subject: [PATCH] ollama: ensure models on start --- pkg/provider/ollama/completer.go | 2 + pkg/provider/ollama/embedder.go | 2 + pkg/provider/ollama/model.go | 91 ++++++++++++++++++++++++++++++++ 3 files changed, 95 insertions(+) create mode 100644 pkg/provider/ollama/model.go diff --git a/pkg/provider/ollama/completer.go b/pkg/provider/ollama/completer.go index c27fe70..bd27c71 100644 --- a/pkg/provider/ollama/completer.go +++ b/pkg/provider/ollama/completer.go @@ -36,6 +36,8 @@ func NewCompleter(url string, options ...Option) (*Completer, error) { option(c) } + go c.ensureModel() + return &Completer{ Config: c, }, nil diff --git a/pkg/provider/ollama/embedder.go b/pkg/provider/ollama/embedder.go index f2458a9..1ccf841 100644 --- a/pkg/provider/ollama/embedder.go +++ b/pkg/provider/ollama/embedder.go @@ -31,6 +31,8 @@ func NewEmbedder(url string, options ...Option) (*Embedder, error) { option(c) } + go c.ensureModel() + return &Embedder{ Config: c, }, nil diff --git a/pkg/provider/ollama/model.go b/pkg/provider/ollama/model.go new file mode 100644 index 0000000..46fa344 --- /dev/null +++ b/pkg/provider/ollama/model.go @@ -0,0 +1,91 @@ +package ollama + +import ( + "bufio" + "encoding/json" + "errors" + "io" + "log/slog" + "net/http" + "net/url" +) + +func (c *Config) ensureModel() error { + body := ModelRequest{ + Name: c.model, + } + + u, _ := url.JoinPath(c.url, "/api/show") + resp, err := c.client.Post(u, "application/json", jsonReader(body)) + + if err != nil { + return err + } + + if resp.StatusCode == http.StatusOK { + return nil + } + + return c.pullModel() +} + +func (c *Config) pullModel() error { + body := PullRequest{ + Name: c.model, + Stream: true, + } + + u, _ := url.JoinPath(c.url, "/api/pull") + resp, err := c.client.Post(u, "application/json", jsonReader(body)) + + if err != nil { + return err + } + + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return errors.New("unable to pull model") + } + + reader := bufio.NewReader(resp.Body) + + for i := 0; ; i++ { + data, err := reader.ReadBytes('\n') + + if errors.Is(err, io.EOF) { + break + } + + if err != nil { + return err + } + + if len(data) == 0 { + continue + } + + var pull PullResponse + + if err := json.Unmarshal([]byte(data), &pull); err != nil { + return err + } + + slog.Info("download model", "model", c.model, "status", pull.Status) + } + + return nil +} + +type ModelRequest struct { + Name string `json:"name"` +} + +type PullRequest struct { + Name string `json:"name"` + Stream bool `json:"stream"` +} + +type PullResponse struct { + Status string `json:"status"` +}