@@ -19,6 +19,13 @@ import (
19
19
"golang.org/x/tools/go/ssa/ssautil"
20
20
)
21
21
22
+ type Config struct {
23
+ Sqlx bool
24
+ DatabaseSql bool
25
+ }
26
+
27
+ const SQLX string = "github.com/jmoiron/sqlx"
28
+
22
29
func main () {
23
30
var verbose , quiet bool
24
31
flag .BoolVar (& verbose , "v" , false , "Verbose mode" )
@@ -28,6 +35,11 @@ func main() {
28
35
flag .PrintDefaults ()
29
36
}
30
37
38
+ config := Config {
39
+ Sqlx : false ,
40
+ DatabaseSql : false ,
41
+ }
42
+
31
43
flag .Parse ()
32
44
pkgs := flag .Args ()
33
45
if len (pkgs ) == 0 {
@@ -38,21 +50,51 @@ func main() {
38
50
c := loader.Config {
39
51
FindPackage : FindPackage ,
40
52
}
41
- c .Import ("database/sql" )
42
53
for _ , pkg := range pkgs {
43
54
c .Import (pkg )
44
55
}
45
56
p , err := c .Load ()
57
+
58
+ imports := GetImports (p )
59
+
60
+ if _ , exist := imports ["database/sql" ]; exist {
61
+ if verbose {
62
+ fmt .Println ("Enabling support for database/sql" )
63
+ }
64
+ config .DatabaseSql = true
65
+ }
66
+
67
+ if _ , exist := imports ["github.com/jmoiron/sqlx" ]; exist {
68
+ if verbose {
69
+ fmt .Println ("Enabling support for sqlx" )
70
+ }
71
+ config .Sqlx = true
72
+ }
73
+
74
+ if ! (config .Sqlx || config .DatabaseSql ) {
75
+ fmt .Printf ("No packages in %v include a supported database driver" , pkgs )
76
+ os .Exit (2 )
77
+ }
78
+
46
79
if err != nil {
47
80
fmt .Printf ("error loading packages %v: %v\n " , pkgs , err )
48
81
os .Exit (2 )
49
82
}
83
+
50
84
s := ssautil .CreateProgram (p , 0 )
51
85
s .Build ()
52
86
53
- qms := FindQueryMethods (p .Package ("database/sql" ).Pkg , s )
87
+ qms := make ([]* QueryMethod , 0 )
88
+
89
+ if config .DatabaseSql {
90
+ qms = append (qms , FindQueryMethods (p .Package ("database/sql" ).Pkg , s )... )
91
+ }
92
+ if config .Sqlx {
93
+ qms = append (qms , FindQueryMethods (p .Package (SQLX ).Pkg , s )... )
94
+ }
95
+
54
96
if verbose {
55
- fmt .Println ("database/sql functions that accept queries:" )
97
+ fmt .Println ("database driver functions that accept queries:" )
56
98
for _ , m := range qms {
57
99
fmt .Printf ("- %s (param %d)\n " , m .Func , m .Param )
58
100
}
@@ -75,6 +117,7 @@ func main() {
75
117
}
76
118
77
119
bad := FindNonConstCalls (res .CallGraph , qms )
120
+
78
121
if len (bad ) == 0 {
79
122
if ! quiet {
80
123
fmt .Println (`You're safe from SQL injection! Yay \o/` )
@@ -164,6 +207,17 @@ func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package {
164
207
return mains
165
208
}
166
209
210
+ func GetImports (p * loader.Program ) map [string ]interface {} {
211
+ packages := make (map [string ]interface {})
212
+ for _ , info := range p .AllPackages {
213
+ // Invert the map so we can do lookups more easily
214
+ if info .Importable {
215
+ packages [info .Pkg .Path ()] = nil
216
+ }
217
+ }
218
+ return packages
219
+ }
220
+
167
221
// FindNonConstCalls returns the set of callsites of the given set of methods
168
222
// for which the "query" parameter is not a compile-time constant.
169
223
func FindNonConstCalls (cg * callgraph.Graph , qms []* QueryMethod ) []ssa.CallInstruction {
@@ -196,6 +250,13 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
196
250
}
197
251
v := args [m .Param ]
198
252
if _ , ok := v .(* ssa.Const ); ! ok {
253
+ // This is super lurky, but sqlx wants to hand query objects about under
254
+ // the hood. We could do clever taint analysis, but it's easier
255
+ // to just bless the innards of sqlx internally, and rely on it
256
+ // to do Reasonable Things under the hood.
257
+ if edge .Caller .Func .Pkg .Pkg .Path () == SQLX {
258
+ continue
259
+ }
199
260
bad = append (bad , edge .Site )
200
261
}
201
262
}
0 commit comments