diff --git a/safesql.go b/safesql.go index 3071879..fbc91e7 100644 --- a/safesql.go +++ b/safesql.go @@ -19,8 +19,16 @@ import ( "golang.org/x/tools/go/ssa/ssautil" ) +type Config struct { + Sqlx bool + DatabaseSql bool +} + +const SQLX string = "github.com/jmoiron/sqlx" + func main() { var verbose, quiet bool + // var use_sqlx bool = false flag.BoolVar(&verbose, "v", false, "Verbose mode") flag.BoolVar(&quiet, "q", false, "Only print on failure") flag.Usage = func() { @@ -28,6 +36,11 @@ func main() { flag.PrintDefaults() } + config := Config{ + Sqlx: false, + DatabaseSql: false, + } + flag.Parse() pkgs := flag.Args() if len(pkgs) == 0 { @@ -38,21 +51,53 @@ func main() { c := loader.Config{ FindPackage: FindPackage, } - c.Import("database/sql") for _, pkg := range pkgs { c.Import(pkg) } p, err := c.Load() + + imports := GetImports(p) + + if _, exist := imports["database/sql"]; exist { + if verbose { + fmt.Println("Enabling support for database/sql") + } + config.DatabaseSql = true + } + + if _, exist := imports["github.com/jmoiron/sqlx"]; exist { + if verbose { + fmt.Println("Enabling support for sqlx") + } + config.Sqlx = true + } + + if !(config.Sqlx || config.DatabaseSql) { + fmt.Printf("No packages in %v include a supported database driver", pkgs) + os.Exit(2) + } + if err != nil { fmt.Printf("error loading packages %v: %v\n", pkgs, err) os.Exit(2) } + + GetImports(p) + s := ssautil.CreateProgram(p, 0) s.Build() - qms := FindQueryMethods(p.Package("database/sql").Pkg, s) + qms := make([]*QueryMethod, 0) + + if config.DatabaseSql { + qms = append(qms, FindQueryMethods(p.Package("database/sql").Pkg, s)...) + } + if config.Sqlx { + qms = append(qms, FindQueryMethods(p.Package(SQLX).Pkg, s)...) + } + if verbose { - fmt.Println("database/sql functions that accept queries:") + fmt.Println("database driver functions that accept queries:") for _, m := range qms { fmt.Printf("- %s (param %d)\n", m.Func, m.Param) } @@ -75,6 +120,7 @@ func main() { } bad := FindNonConstCalls(res.CallGraph, qms) + if len(bad) == 0 { if !quiet { fmt.Println(`You're safe from SQL injection! Yay \o/`) @@ -164,6 +210,17 @@ func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package { return mains } +func GetImports(p *loader.Program) map[string]interface{} { + packages := make(map[string]interface{}) + for _, info := range p.AllPackages { + // Invert the map so we can do lookups more easily + if info.Importable { + packages[info.Pkg.Path()] = nil + } + } + return packages +} + // FindNonConstCalls returns the set of callsites of the given set of methods // for which the "query" parameter is not a compile-time constant. func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstruction { @@ -196,6 +253,13 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru } v := args[m.Param] if _, ok := v.(*ssa.Const); !ok { + // This is super lurky, but sqlx wants to hand query objects about under + // the hood. We could do clever taint analysis, but it's easier + // to just bless the innards of sqlx internally, and rely on it + // to do Reasonable Things under the hood. + if edge.Caller.Func.Pkg.Pkg.Path() == SQLX { + continue + } bad = append(bad, edge.Site) } }