-
Notifications
You must be signed in to change notification settings - Fork 18
/
s3update.go
177 lines (148 loc) · 4.87 KB
/
s3update.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
package s3update
import (
"fmt"
"io"
"io/ioutil"
"os"
"runtime"
"strconv"
"strings"
"syscall"
"time"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/s3"
"github.com/mitchellh/ioprogress"
)
// Updater describes where s3update finds release metadata and binaries on S3
// and which version the running binary is. CurrentVersion, S3Bucket,
// S3Region, S3ReleaseKey and S3VersionKey must all be non-empty (validate
// rejects the Updater otherwise); validate does not check AWSCredentials.
type Updater struct {
	// CurrentVersion represents the current binary version.
	// This is generally set at the compilation time with -ldflags "-X main.Version=42"
	// See the README for additional information.
	// It must parse as a non-zero base-10 integer; runAutoUpdate reports
	// "invalid local version" otherwise.
	CurrentVersion string
	// S3Bucket represents the S3 bucket containing the different files used by s3update.
	S3Bucket string
	// S3Region represents the S3 region you want to work in.
	S3Region string
	// S3ReleaseKey represents the raw key on S3 to download new versions.
	// The value can be something like `cli/releases/cli-{{OS}}-{{ARCH}}`:
	// every {{OS}} and {{ARCH}} is replaced with runtime.GOOS and
	// runtime.GOARCH before the object is fetched.
	S3ReleaseKey string
	// S3VersionKey represents the key on S3 of the object holding the latest
	// released version; its content is parsed as a base-10 integer and
	// compared against CurrentVersion.
	S3VersionKey string
	// AWSCredentials represents the config to use to connect to s3.
	// It is passed as aws.Config.Credentials; NOTE(review): when nil, the AWS
	// SDK presumably falls back to its default credential chain — confirm.
	AWSCredentials *credentials.Credentials
}
// validate checks that every required configuration field of u is non-empty.
// It returns an error naming the first missing field, checked in the order
// CurrentVersion, S3Bucket, S3Region, S3ReleaseKey, S3VersionKey, or nil
// when all of them are set.
func (u Updater) validate() error {
	switch {
	case u.CurrentVersion == "":
		return fmt.Errorf("no version set")
	case u.S3Bucket == "":
		return fmt.Errorf("no bucket set")
	case u.S3Region == "":
		return fmt.Errorf("no s3 region")
	case u.S3ReleaseKey == "":
		return fmt.Errorf("no s3ReleaseKey set")
	case u.S3VersionKey == "":
		return fmt.Errorf("no s3VersionKey set")
	default:
		return nil
	}
}
// AutoUpdate synchronously checks whether a newer release of the running
// binary has been published and, if so, installs it automatically.
//
// Setting the S3UPDATE_DISABLED environment variable to any non-empty value
// skips the whole mechanism and returns nil. An incomplete Updater is
// reported on stdout and its validation error is returned to the caller.
func AutoUpdate(u Updater) error {
	if disabled := os.Getenv("S3UPDATE_DISABLED"); disabled != "" {
		fmt.Println("s3update: autoupdate disabled")
		return nil
	}

	err := u.validate()
	if err == nil {
		return runAutoUpdate(u)
	}
	fmt.Printf("s3update: %s - skipping auto update\n", err.Error())
	return err
}
// generateS3ReleaseKey expands the platform placeholders in an S3 key
// template: every "{{OS}}" becomes runtime.GOOS and every "{{ARCH}}" becomes
// runtime.GOARCH. The substitutions are applied in that order; all other text
// is returned unchanged.
func generateS3ReleaseKey(path string) string {
	substitutions := [...][2]string{
		{"{{OS}}", runtime.GOOS},
		{"{{ARCH}}", runtime.GOARCH},
	}
	key := path
	for _, sub := range substitutions {
		key = strings.Replace(key, sub[0], sub[1], -1)
	}
	return key
}
// runAutoUpdate compares u.CurrentVersion with the integer stored in the S3
// object u.S3Bucket/u.S3VersionKey. When the remote version is strictly
// greater, it downloads the release object (u.S3ReleaseKey with {{OS}} and
// {{ARCH}} expanded) over the running executable and replaces the current
// process with the new binary via syscall.Exec.
//
// It returns nil when the local binary is already up to date. When an update
// is installed successfully this function does not return; a non-nil error
// means the version check, the download/install, or the exec failed.
func runAutoUpdate(u Updater) error {
	// TrimSpace tolerates stray whitespace in the value injected via -ldflags.
	localVersion, err := strconv.ParseInt(strings.TrimSpace(u.CurrentVersion), 10, 64)
	if err != nil || localVersion <= 0 {
		return fmt.Errorf("invalid local version")
	}

	// session.New is deprecated; NewSession surfaces configuration errors
	// immediately instead of on the first request.
	sess, err := session.NewSession(&aws.Config{
		Region:      aws.String(u.S3Region),
		Credentials: u.AWSCredentials,
	})
	if err != nil {
		return err
	}
	svc := s3.New(sess)

	remoteVersion, err := fetchRemoteVersion(svc, u.S3Bucket, u.S3VersionKey)
	if err != nil {
		return err
	}
	fmt.Printf("s3update: Local Version %d - Remote Version: %d\n", localVersion, remoteVersion)
	if localVersion >= remoteVersion {
		return nil
	}

	fmt.Printf("s3update: version outdated ... \n")
	dest, err := os.Executable()
	if err != nil {
		return err
	}
	if err := installRelease(svc, u.S3Bucket, generateS3ReleaseKey(u.S3ReleaseKey), dest); err != nil {
		return err
	}
	fmt.Printf("s3update: updated with success to version %d\nRestarting application\n", remoteVersion)

	// The update completed, so restart without requiring any user action.
	// On success Exec never returns (the process image is replaced), so only
	// a failure reaches this return.
	return syscall.Exec(dest, os.Args, os.Environ())
}

// fetchRemoteVersion downloads the object bucket/key and parses its content as
// a positive base-10 integer. Surrounding whitespace is ignored, so a version
// file written with a trailing newline (e.g. `echo 42 > VERSION`) is accepted.
func fetchRemoteVersion(svc *s3.S3, bucket, key string) (int64, error) {
	resp, err := svc.GetObject(&s3.GetObjectInput{Bucket: aws.String(bucket), Key: aws.String(key)})
	if err != nil {
		return 0, err
	}
	defer resp.Body.Close()

	b, err := ioutil.ReadAll(resp.Body)
	if err != nil {
		return 0, err
	}
	v, err := strconv.ParseInt(strings.TrimSpace(string(b)), 10, 64)
	if err != nil || v <= 0 {
		return 0, fmt.Errorf("invalid remote version")
	}
	return v, nil
}

// installRelease streams the S3 object bucket/key into dest. An existing file
// at dest is first moved to dest+".bak". That backup is restored if creating,
// writing or closing the new file fails, and it is removed once the new file
// is fully written and closed. A progress bar is drawn on stdout when the
// object size is known.
func installRelease(svc *s3.S3, bucket, key, dest string) error {
	resp, err := svc.GetObject(&s3.GetObjectInput{Bucket: aws.String(bucket), Key: aws.String(key)})
	if err != nil {
		return err
	}
	defer resp.Body.Close()

	// ContentLength is an optional pointer in the SDK response: dereferencing
	// it blindly panics when it is nil, and a zero size leaves the bar nothing
	// to scale against. Without a usable size, copy the body directly.
	var src io.Reader = resp.Body
	if resp.ContentLength != nil && *resp.ContentLength > 0 {
		bar := ioprogress.DrawTextFormatBar(40)
		src = &ioprogress.Reader{
			Reader:       resp.Body,
			Size:         *resp.ContentLength,
			DrawInterval: 500 * time.Millisecond,
			DrawFunc: ioprogress.DrawTerminalf(os.Stdout, func(progress, total int64) string {
				return fmt.Sprintf("%s %20s", bar(progress, total), ioprogress.DrawTextFormatBytes(progress, total))
			}),
		}
	}

	// Move the old version to a backup path that we can recover from in case
	// the upgrade fails. If this move fails, abort before touching dest.
	backup := dest + ".bak"
	hasBackup := false
	if _, err := os.Stat(dest); err == nil {
		if err := os.Rename(dest, backup); err != nil {
			return err
		}
		hasBackup = true
	}

	// rollback undoes a partial install: it restores the backup when one was
	// made, otherwise it removes whatever was written at dest. It returns
	// cause, annotated when the rollback itself fails.
	rollback := func(cause error) error {
		var rerr error
		if hasBackup {
			rerr = os.Rename(backup, dest)
		} else {
			rerr = os.Remove(dest)
		}
		if rerr != nil && !os.IsNotExist(rerr) {
			return fmt.Errorf("%w (rollback of %s also failed: %v)", cause, dest, rerr)
		}
		return cause
	}

	// Same flags ioutil.WriteFile uses; a newly created file gets mode 0755
	// (before umask) so it stays executable.
	f, err := os.OpenFile(dest, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0755)
	if err != nil {
		return rollback(err)
	}
	fmt.Printf("s3update: downloading new version to %s\n", dest)
	if _, err := io.Copy(f, src); err != nil {
		_ = f.Close() // already failing; the copy error is the one to report
		return rollback(err)
	}
	// Close explicitly rather than deferring: the file must be closed before
	// it can be exec'd, and Close can report a write that did not complete.
	if err := f.Close(); err != nil {
		return rollback(err)
	}

	// Best effort: a leftover backup does not affect the installed binary.
	_ = os.Remove(backup)
	return nil
}