Skip to content

Commit

Permalink
ollama: ensure models on start
Browse files Browse the repository at this point in the history
  • Loading branch information
adrianliechti committed May 8, 2024
1 parent 398cc57 commit ddb0efc
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 0 deletions.
2 changes: 2 additions & 0 deletions pkg/provider/ollama/completer.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@ func NewCompleter(url string, options ...Option) (*Completer, error) {
option(c)
}

go c.ensureModel()

return &Completer{
Config: c,
}, nil
Expand Down
2 changes: 2 additions & 0 deletions pkg/provider/ollama/embedder.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ func NewEmbedder(url string, options ...Option) (*Embedder, error) {
option(c)
}

go c.ensureModel()

return &Embedder{
Config: c,
}, nil
Expand Down
91 changes: 91 additions & 0 deletions pkg/provider/ollama/model.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,91 @@
package ollama

import (
"bufio"
"encoding/json"
"errors"
"io"
"log/slog"
"net/http"
"net/url"
)

func (c *Config) ensureModel() error {
body := ModelRequest{
Name: c.model,
}

u, _ := url.JoinPath(c.url, "/api/show")
resp, err := c.client.Post(u, "application/json", jsonReader(body))

if err != nil {
return err
}

if resp.StatusCode == http.StatusOK {
return nil
}

return c.pullModel()
}

func (c *Config) pullModel() error {
body := PullRequest{
Name: c.model,
Stream: true,
}

u, _ := url.JoinPath(c.url, "/api/pull")
resp, err := c.client.Post(u, "application/json", jsonReader(body))

if err != nil {
return err
}

defer resp.Body.Close()

if resp.StatusCode != http.StatusOK {
return errors.New("unable to pull model")
}

reader := bufio.NewReader(resp.Body)

for i := 0; ; i++ {
data, err := reader.ReadBytes('\n')

if errors.Is(err, io.EOF) {
break
}

if err != nil {
return err
}

if len(data) == 0 {
continue
}

var pull PullResponse

if err := json.Unmarshal([]byte(data), &pull); err != nil {
return err
}

slog.Info("download model", "model", c.model, "status", pull.Status)
}

return nil
}

type ModelRequest struct {
Name string `json:"name"`
}

type PullRequest struct {
Name string `json:"name"`
Stream bool `json:"stream"`
}

type PullResponse struct {
Status string `json:"status"`
}

0 comments on commit ddb0efc

Please sign in to comment.