Skip to content
This repository was archived by the owner on Sep 21, 2021. It is now read-only.

Commit 7f5bb6b

Browse files
richo-stripericho
authored andcommitted
Add support for sqlx
1 parent 452e37e commit 7f5bb6b

File tree

1 file changed

+64
-3
lines changed

1 file changed

+64
-3
lines changed

safesql.go

Lines changed: 64 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,13 @@ import (
1919
"golang.org/x/tools/go/ssa/ssautil"
2020
)
2121

22+
type Config struct {
23+
Sqlx bool
24+
DatabaseSql bool
25+
}
26+
27+
const SQLX string = "github.com/jmoiron/sqlx"
28+
2229
func main() {
2330
var verbose, quiet bool
2431
flag.BoolVar(&verbose, "v", false, "Verbose mode")
@@ -28,6 +35,11 @@ func main() {
2835
flag.PrintDefaults()
2936
}
3037

38+
config := Config{
39+
Sqlx: false,
40+
DatabaseSql: false,
41+
}
42+
3143
flag.Parse()
3244
pkgs := flag.Args()
3345
if len(pkgs) == 0 {
@@ -38,21 +50,51 @@ func main() {
3850
c := loader.Config{
3951
FindPackage: FindPackage,
4052
}
41-
c.Import("database/sql")
4253
for _, pkg := range pkgs {
4354
c.Import(pkg)
4455
}
4556
p, err := c.Load()
57+
58+
imports := GetImports(p)
59+
60+
if _, exist := imports["database/sql"]; exist {
61+
if verbose {
62+
fmt.Println("Enabling support for database/sql")
63+
}
64+
config.DatabaseSql = true
65+
}
66+
67+
if _, exist := imports["github.com/jmoiron/sqlx"]; exist {
68+
if verbose {
69+
fmt.Println("Enabling support for sqlx")
70+
}
71+
config.Sqlx = true
72+
}
73+
74+
if !(config.Sqlx || config.DatabaseSql) {
75+
fmt.Printf("No packages in %v include a supported database driver", pkgs)
76+
os.Exit(2)
77+
}
78+
4679
if err != nil {
4780
fmt.Printf("error loading packages %v: %v\n", pkgs, err)
4881
os.Exit(2)
4982
}
83+
5084
s := ssautil.CreateProgram(p, 0)
5185
s.Build()
5286

53-
qms := FindQueryMethods(p.Package("database/sql").Pkg, s)
87+
qms := make([]*QueryMethod, 0)
88+
89+
if config.DatabaseSql {
90+
qms = append(qms, FindQueryMethods(p.Package("database/sql").Pkg, s)...)
91+
}
92+
if config.Sqlx {
93+
qms = append(qms, FindQueryMethods(p.Package(SQLX).Pkg, s)...)
94+
}
95+
5496
if verbose {
55-
fmt.Println("database/sql functions that accept queries:")
97+
fmt.Println("database driver functions that accept queries:")
5698
for _, m := range qms {
5799
fmt.Printf("- %s (param %d)\n", m.Func, m.Param)
58100
}
@@ -75,6 +117,7 @@ func main() {
75117
}
76118

77119
bad := FindNonConstCalls(res.CallGraph, qms)
120+
78121
if len(bad) == 0 {
79122
if !quiet {
80123
fmt.Println(`You're safe from SQL injection! Yay \o/`)
@@ -164,6 +207,17 @@ func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package {
164207
return mains
165208
}
166209

210+
func GetImports(p *loader.Program) map[string]interface{} {
211+
packages := make(map[string]interface{})
212+
for _, info := range p.AllPackages {
213+
// Invert the map so we can do lookups more easily
214+
if info.Importable {
215+
packages[info.Pkg.Path()] = nil
216+
}
217+
}
218+
return packages
219+
}
220+
167221
// FindNonConstCalls returns the set of callsites of the given set of methods
168222
// for which the "query" parameter is not a compile-time constant.
169223
func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstruction {
@@ -196,6 +250,13 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
196250
}
197251
v := args[m.Param]
198252
if _, ok := v.(*ssa.Const); !ok {
253+
// This is super lurky, but sqlx wants to hand query objects about under
254+
// the hood. We could do clever taint analysis, but it's easier
255+
// to just bless the innards of sqlx internally, and rely on it
256+
// to do Reasonable Things under the hood.
257+
if edge.Caller.Func.Pkg.Pkg.Path() == SQLX {
258+
continue
259+
}
199260
bad = append(bad, edge.Site)
200261
}
201262
}

0 commit comments

Comments
 (0)