-
Notifications
You must be signed in to change notification settings - Fork 7
/
helper_test.go
148 lines (125 loc) · 3.04 KB
/
helper_test.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
package shift_test
import (
"database/sql"
"database/sql/driver"
"flag"
"log"
"os"
"strconv"
"testing"
"github.com/stretchr/testify/require"
)
// schemas holds the DDL statements that setup executes, in order, on each
// fresh connection returned by connect. Every table is created TEMPORARY, so
// it exists only for the MySQL session that created it (presumably why
// connect caps the pool at a single connection — confirm before changing it).
//
// The usersStr/eventsStr tables mirror users/events but use varchar keys
// (usersStr.id, eventsStr.foreign_id) instead of bigint ones. The nullable
// varchar `amount` columns match the string encoding used by Currency.
var schemas = []string{`
create temporary table users (
id bigint not null auto_increment,
name varchar(255) not null,
dob datetime not null,
amount varchar(255),
status tinyint not null,
created_at datetime not null,
updated_at datetime not null,
primary key (id)
);`, `
create temporary table events (
id bigint not null auto_increment,
foreign_id bigint not null,
timestamp datetime not null,
type tinyint not null,
metadata blob,
primary key (id)
);`, `
create temporary table usersStr (
id varchar(255) not null,
name varchar(255) not null,
dob datetime not null,
amount varchar(255),
status tinyint not null,
created_at datetime not null,
updated_at datetime not null,
primary key (id)
);`, `
create temporary table eventsStr (
id bigint not null auto_increment,
foreign_id varchar(255) not null,
timestamp datetime not null,
type tinyint not null,
metadata blob,
primary key (id)
);`, `
create temporary table tests (
id bigint not null auto_increment,
i1 bigint not null,
i2 varchar(255) not null,
i3 datetime not null,
u1 bool,
u2 varchar(255),
u3 datetime,
u4 varchar(255),
u5 binary(64),
status tinyint not null,
created_at datetime not null,
updated_at datetime not null,
primary key (id)
);`}
// dbTestURI is the base MySQL DSN the tests connect with; override it with
// -db_test_base. By default it connects as root over the unix socket chosen
// by getSocketFile, to the "test" database. connect appends its driver
// parameters directly onto this value, so it must end in "?" (as the default
// does) or "&".
//
// TODO: Refactor this to use SQLite.
var dbTestURI = flag.String("db_test_base", "root@unix("+getSocketFile()+")/test?", "Test database uri")
// getSocketFile returns the MySQL unix socket path used by the default test
// DSN. It selects /tmp/mysql.sock unless os.Stat reports that the file does
// not exist. In that case it falls back to the common Linux/Ubuntu location
// /var/run/mysqld/mysqld.sock, which is not itself checked. Any other Stat
// error (e.g. permission denied) still selects /tmp/mysql.sock.
func getSocketFile() string {
	const (
		preferred = "/tmp/mysql.sock"
		fallback  = "/var/run/mysqld/mysqld.sock"
	)
	_, statErr := os.Stat(preferred)
	if !os.IsNotExist(statErr) {
		return preferred
	}
	return fallback
}
// connect opens the MySQL test database named by the -db_test_base flag and
// sets the session time zone to +00:00.
//
// The base URI is used verbatim and the driver parameters are appended
// directly to it. It must therefore end in "?" (as the default does) or "&".
//
// The pool is capped at one connection. The session time_zone set here, and
// the TEMPORARY tables that setup creates, are per-session state that a
// second pooled connection would not see.
// NOTE(review): if the driver ever discards and replaces that connection,
// the replacement will have neither. Consider moving time_zone into the DSN.
//
// On failure the handle is closed and the error is returned, rather than
// exiting the process. Callers such as setup can then fail just the current
// test. The caller owns the returned *sql.DB and must Close it.
func connect() (*sql.DB, error) {
	dsn := *dbTestURI + "parseTime=true&collation=utf8mb4_general_ci"
	dbc, err := sql.Open("mysql", dsn)
	if err != nil {
		return nil, err
	}
	dbc.SetMaxOpenConns(1)
	if _, err := dbc.Exec("set time_zone='+00:00';"); err != nil {
		// The Exec error is the one worth reporting. Failing to close the
		// half-initialised handle is secondary, so it is only logged.
		if cerr := dbc.Close(); cerr != nil {
			log.Printf("closing test db after time_zone failure: %v", cerr)
		}
		return nil, err
	}
	return dbc, nil
}
// setup returns a connection to the test database on which every statement
// in schemas has been executed. It fails t immediately if connecting or any
// statement errors. The handle is closed via t.Cleanup when the test
// finishes.
func setup(t *testing.T) *sql.DB {
	// Report require failures at the caller's line, not inside this helper.
	t.Helper()

	dbc, err := connect()
	require.NoError(t, err, "connecting to test db")
	// Register the close before creating tables so the handle is released
	// even if a schema statement below fails.
	t.Cleanup(func() { require.NoError(t, dbc.Close()) })

	for i, stmt := range schemas {
		_, err = dbc.Exec(stmt)
		require.NoError(t, err, "creating test schema #%d", i)
	}
	return dbc
}
// Currency is a custom "currency" type stored as a string in the DB. The
// amount is written and read as its base-10 integer text, and SQL NULL maps
// to the zero Currency (Valid == false).
type Currency struct {
	// Valid reports whether the value is non-NULL.
	Valid bool
	// Amount is the integer amount; meaningful only when Valid is true.
	Amount int64
}

// Scan implements sql.Scanner. A NULL source produces Currency{}. Any other
// source is converted to a string via sql.NullString and parsed as a base-10
// int64. On error, *c is left unchanged.
func (c *Currency) Scan(src interface{}) error {
	var s sql.NullString
	if err := s.Scan(src); err != nil {
		return err
	}
	if !s.Valid {
		*c = Currency{}
		return nil
	}
	i, err := strconv.ParseInt(s.String, 10, 64)
	if err != nil {
		return err
	}
	*c = Currency{Valid: true, Amount: i}
	return nil
}

// Value implements driver.Valuer. An invalid Currency is written as SQL NULL,
// so that it round-trips through Scan back to Currency{}. A valid one is
// encoded as the base-10 string of Amount.
func (c Currency) Value() (driver.Value, error) {
	if !c.Valid {
		return nil, nil
	}
	return strconv.FormatInt(c.Amount, 10), nil
}