-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathclient.go
141 lines (122 loc) · 3.39 KB
/
client.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
// client.go
//
// Copyright (c) 2023 Junpei Kawamoto
//
// This software is released under the MIT License.
//
// http://opensource.org/licenses/mit-license.php
package main
import (
"context"
"encoding/hex"
"errors"
"fmt"
"io"
"mime"
"net/http"
"os"
"path/filepath"
"strings"
"github.com/cheggaaa/pb/v3"
"github.com/jkawamoto/go-civitai/client"
"github.com/jkawamoto/go-civitai/client/operations"
"github.com/jkawamoto/go-civitai/models"
"github.com/zeebo/blake3"
"golang.org/x/net/context/ctxhttp"
)
// Sentinel errors reported by Client.Download. Download returns them either directly,
// wrapped with %w, or combined via errors.Join, so callers can match them with errors.Is.
var (
	// ErrFileNotFound is returned when Download cannot pick any file to fetch from a model
	// version.
	ErrFileNotFound error = errors.New("model files are not found in this version")
	// ErrFileHashNotMatch is returned when the BLAKE3 digest of the downloaded bytes differs
	// from the digest recorded in the version metadata.
	ErrFileHashNotMatch error = errors.New("file hash doesn't match")
	// ErrGetFailure is wrapped together with the HTTP status text when the download request
	// does not answer 200 OK.
	ErrGetFailure error = errors.New("failed to get a file")
	// ErrNoFilename is returned when a file name cannot be derived from the download
	// response headers.
	ErrNoFilename error = errors.New("failed to get a filename")
)
// Client wraps the go-civitai API operations and downloads model files referenced by
// model versions.
type Client struct {
	// clientService issues the generated API operations (model and model-version lookups).
	clientService operations.ClientService
	// httpClient is attached to every API request and used for file downloads. NewClient
	// leaves it nil; presumably the underlying libraries then fall back to a default
	// client — confirm before relying on timeouts.
	httpClient *http.Client
	// PreferredFormat names the file format Download favors when a version offers
	// several files.
	PreferredFormat string
}
// NewClient returns a Client backed by the default go-civitai transport
// (client.Default.Operations) that favors files in preferredFormat when downloading.
// The HTTP client field is left nil.
func NewClient(preferredFormat string) Client {
	var c Client
	c.PreferredFormat = preferredFormat
	c.clientService = client.Default.Operations
	return c
}
// GetModelVersion looks up the model version whose file has the given hash and returns
// the decoded payload. Any error from the API call is returned unchanged.
func (cli Client) GetModelVersion(ctx context.Context, hash string) (*models.ModelVersion, error) {
	params := operations.NewGetModelVersionByHashParamsWithContext(ctx).
		WithHTTPClient(cli.httpClient).
		WithHash(hash)

	resp, err := cli.clientService.GetModelVersionByHash(params)
	if err != nil {
		return nil, err
	}
	return resp.GetPayload(), nil
}
// GetModel fetches the model with the given ID and returns the decoded payload. Any
// error from the API call is returned unchanged.
func (cli Client) GetModel(ctx context.Context, id int64) (*models.Model, error) {
	params := operations.NewGetModelParamsWithContext(ctx).
		WithHTTPClient(cli.httpClient).
		WithModelID(id)

	resp, err := cli.clientService.GetModel(params)
	if err != nil {
		return nil, err
	}
	return resp.GetPayload(), nil
}
// Download gets a model file associated with the given version and stores it into the given directory.
//
// File selection: a file whose format equals PreferredFormat (case-insensitively) is used;
// if several match, the last one listed wins. Otherwise the primary file is used.
// ErrFileNotFound is returned when ver is nil or no file qualifies.
//
// The destination name comes from the response's Content-Disposition header and is reduced
// to its final path element, so a server-supplied name cannot place the file outside dir.
// ErrNoFilename is returned when no usable name is present. A non-200 response yields an
// error wrapping ErrGetFailure. The received bytes are hashed with BLAKE3; on mismatch the
// file is removed and ErrFileHashNotMatch is returned.
func (cli Client) Download(ctx context.Context, ver *models.ModelVersion, dir string) (err error) {
	if ver == nil {
		return ErrFileNotFound
	}

	var file *models.File
	for _, f := range ver.Files {
		if f == nil {
			continue
		}
		// EqualFold so callers need not pass PreferredFormat in lowercase.
		if strings.EqualFold(f.Format, cli.PreferredFormat) {
			file = f
		}
		if f.Primary && file == nil {
			file = f
		}
	}
	if file == nil {
		return ErrFileNotFound
	}

	res, err := ctxhttp.Get(ctx, cli.httpClient, file.DownloadURL)
	if err != nil {
		return err
	}
	defer func() {
		// Drain only a bounded amount: enough to let the connection be reused after a short
		// body (e.g. an error page), but never the remainder of a multi-gigabyte model when
		// we bail out early (bad status, missing name, destination already exists).
		const maxDrain = 64 << 10
		if _, e := io.Copy(io.Discard, io.LimitReader(res.Body, maxDrain)); e != nil {
			err = errors.Join(err, e)
		}
		err = errors.Join(err, res.Body.Close())
	}()

	if res.StatusCode != http.StatusOK {
		return fmt.Errorf("%w: %v", ErrGetFailure, res.Status)
	}

	_, params, err := mime.ParseMediaType(res.Header.Get("Content-Disposition"))
	if err != nil {
		return errors.Join(ErrNoFilename, err)
	}
	// The header is server-controlled: keep only the last path element so "../" segments or
	// an absolute path cannot escape dir. filepath.Base maps "" to ".", which is rejected too.
	name := filepath.Base(params["filename"])
	if name == "." || name == ".." || name == string(filepath.Separator) {
		return ErrNoFilename
	}

	bar := pb.New(int(file.SizeKB * 1024))
	bar.Set(pb.SIBytesPrefix, true)
	bar.Set("prefix", name+" ")
	bar.Start()
	defer bar.Finish()

	hash := blake3.New()
	dest := filepath.Join(dir, name)
	if err = writeFile(dest, io.TeeReader(bar.NewProxyReader(res.Body), hash)); err != nil {
		return err
	}

	// NOTE(review): assumes file.Hashes is non-nil and carries a BLAKE3 digest; an empty
	// digest is treated as a mismatch and the file is deleted — confirm that is intended.
	if hex.EncodeToString(hash.Sum(nil)) != strings.ToLower(file.Hashes.BLAKE3) {
		// if hash doesn't match, remove the downloaded file.
		return errors.Join(ErrFileHashNotMatch, os.Remove(dest))
	}
	return nil
}
// writeFile creates the file name and streams r into it.
//
// It never overwrites: if name already exists, it returns an error wrapping os.ErrExist
// and leaves the existing file untouched. O_EXCL makes the existence check and the
// creation a single atomic step, so a file appearing concurrently is not truncated.
// If copying or closing fails, the partially written file is removed so a later retry
// is not blocked by a truncated leftover; a removal failure is joined into the error.
func writeFile(name string, r io.Reader) (err error) {
	f, err := os.OpenFile(name, os.O_RDWR|os.O_CREATE|os.O_EXCL, 0o666)
	if err != nil {
		if errors.Is(err, os.ErrExist) {
			return fmt.Errorf("%v already exists: %w", name, os.ErrExist)
		}
		return err
	}
	defer func() {
		err = errors.Join(err, f.Close())
		if err != nil {
			// We created this file ourselves, so deleting it on failure is safe.
			err = errors.Join(err, os.Remove(name))
		}
	}()
	_, err = io.Copy(f, r)
	return err
}