From 6bf8ef18ef08b46fcbe5e19af18335978b0ee35b Mon Sep 17 00:00:00 2001 From: Sergey Lanzman Date: Mon, 6 Mar 2017 01:17:02 +0200 Subject: [PATCH 1/3] add gorm and support sqlx --- safesql.go | 109 ++++++++++++++++++++++++++++++++++++++++++++++------- 1 file changed, 95 insertions(+), 14 deletions(-) diff --git a/safesql.go b/safesql.go index 3071879..f26d65c 100644 --- a/safesql.go +++ b/safesql.go @@ -9,6 +9,7 @@ import ( "go/build" "go/types" "os" + "path/filepath" "strings" @@ -19,6 +20,27 @@ import ( "golang.org/x/tools/go/ssa/ssautil" ) +type sqlPackage struct { + packageName string + paramNames []string + enable bool +} + +var sqlPackages = []sqlPackage{ + { + packageName: "database/sql", + paramNames: []string{"query"}, + }, + { + packageName: "github.com/jinzhu/gorm", + paramNames: []string{"sql", "query"}, + }, + { + packageName: "github.com/jmoiron/sqlx", + paramNames: []string{"query"}, + }, +} + func main() { var verbose, quiet bool flag.BoolVar(&verbose, "v", false, "Verbose mode") @@ -38,21 +60,45 @@ func main() { c := loader.Config{ FindPackage: FindPackage, } - c.Import("database/sql") for _, pkg := range pkgs { c.Import(pkg) } p, err := c.Load() + + imports := getImports(p) + existOne := false + for i := range sqlPackages { + if _, exist := imports[sqlPackages[i].packageName]; exist { + if verbose { + fmt.Printf("Enabling support for %s\n", sqlPackages[i].packageName) + } + sqlPackages[i].enable = true + existOne = true + } + } + if !existOne { + fmt.Printf("No packages in %v include a supported database driver", pkgs) + os.Exit(2) + } + if err != nil { fmt.Printf("error loading packages %v: %v\n", pkgs, err) os.Exit(2) } + s := ssautil.CreateProgram(p, 0) s.Build() - qms := FindQueryMethods(p.Package("database/sql").Pkg, s) + qms := make([]*QueryMethod, 0) + + for i := range sqlPackages { + if sqlPackages[i].enable { + qms = append(qms, FindQueryMethods(sqlPackages[i], p.Package(sqlPackages[i].packageName).Pkg, s)...) + } + } + if verbose { - fmt.Println("database/sql functions that accept queries:") + fmt.Println("database driver functions that accept queries:") for _, m := range qms { fmt.Printf("- %s (param %d)\n", m.Func, m.Param) } @@ -75,6 +121,7 @@ func main() { } bad := FindNonConstCalls(res.CallGraph, qms) + if len(bad) == 0 { if !quiet { fmt.Println(`You're safe from SQL injection! Yay \o/`) @@ -82,14 +129,19 @@ func main() { return } - fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad)) + if verbose { + fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad)) + } + for _, ci := range bad { pos := p.Fset.Position(ci.Pos()) fmt.Printf("- %s\n", pos) } - fmt.Println("Please ensure that all SQL queries you use are compile-time constants.") - fmt.Println("You should always use parameterized queries or prepared statements") - fmt.Println("instead of building queries from strings.") + if verbose { + fmt.Println("Please ensure that all SQL queries you use are compile-time constants.") + fmt.Println("You should always use parameterized queries or prepared statements") + fmt.Println("instead of building queries from strings.") + } os.Exit(1) } @@ -104,7 +156,7 @@ type QueryMethod struct { // FindQueryMethods locates all methods in the given package (assumed to be // package database/sql) with a string parameter named "query". -func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod { +func FindQueryMethods(sqlPackages sqlPackage, sql *types.Package, ssa *ssa.Program) []*QueryMethod { methods := make([]*QueryMethod, 0) scope := sql.Scope() for _, name := range scope.Names() { @@ -122,7 +174,7 @@ func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod { continue } s := m.Type().(*types.Signature) - if num, ok := FuncHasQuery(s); ok { + if num, ok := FuncHasQuery(sqlPackages, s); ok { methods = append(methods, &QueryMethod{ Func: m, SSA: ssa.FuncValue(m), @@ -135,16 +187,16 @@ func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod { return methods } -var stringType types.Type = types.Typ[types.String] - // FuncHasQuery returns the offset of the string parameter named "query", or // none if no such parameter exists. -func FuncHasQuery(s *types.Signature) (offset int, ok bool) { +func FuncHasQuery(sqlPackages sqlPackage, s *types.Signature) (offset int, ok bool) { params := s.Params() for i := 0; i < params.Len(); i++ { v := params.At(i) - if v.Name() == "query" && v.Type() == stringType { - return i, true + for _, paramName := range sqlPackages.paramNames { + if v.Name() == paramName { + return i, true + } } } return 0, false @@ -164,6 +216,16 @@ func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package { return mains } +func getImports(p *loader.Program) map[string]interface{} { + pkgs := make(map[string]interface{}) + for _, pkg := range p.AllPackages { + if pkg.Importable { + pkgs[pkg.Pkg.Path()] = nil + } + } + return pkgs +} + // FindNonConstCalls returns the set of callsites of the given set of methods // for which the "query" parameter is not a compile-time constant. func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstruction { @@ -186,6 +248,18 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru if _, ok := okFuncs[edge.Site.Parent()]; ok { continue } + + isInternalSQLPkg := false + for _, pkg := range sqlPackages { + if pkg.packageName == edge.Caller.Func.Pkg.Pkg.Path() { + isInternalSQLPkg = true + break + } + } + if isInternalSQLPkg { + continue + } + cc := edge.Site.Common() args := cc.Args // The first parameter is occasionally the receiver. @@ -195,7 +269,14 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru panic("arg count mismatch") } v := args[m.Param] + if _, ok := v.(*ssa.Const); !ok { + if inter, ok := v.(*ssa.MakeInterface); ok && types.IsInterface(v.(*ssa.MakeInterface).Type()) { + if inter.X.Referrers() == nil || inter.X.Type() != types.Typ[types.String] { + continue + } + } + bad = append(bad, edge.Site) } } From b1a8e6337aa1223fadef94e5fd348a51f9355c8a Mon Sep 17 00:00:00 2001 From: Sergey Lanzman Date: Mon, 6 Mar 2017 01:23:38 +0200 Subject: [PATCH 2/3] Update README.md --- README.md | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index c3e7478..b36f422 100644 --- a/README.md +++ b/README.md @@ -31,8 +31,8 @@ How does it work? ----------------- SafeSQL uses the static analysis utilities in [go/tools][tools] to search for -all call sites of each of the `query` functions in package [database/sql][sql] -(i.e., functions which accept a `string` parameter named `query`). It then makes +all call sites of each of the `query` functions in packages ([database/sql][sql],[github.com/jinzhu/gorm][gorm],[github.com/jmoiron/sqlx][sqlx]) +(i.e., functions which accept a parameter named `query`,`sql`). It then makes sure that every such call site uses a query that is a compile-time constant. The principle behind SafeSQL's safety guarantees is that queries that are @@ -44,6 +44,8 @@ will not be allowed. [tools]: https://godoc.org/golang.org/x/tools/go [sql]: http://golang.org/pkg/database/sql/ +[sqlx]: https://github.com/jmoiron/sqlx +[gorm]: https://github.com/jinzhu/gorm False positives --------------- @@ -66,8 +68,6 @@ a fundamental limitation: SafeSQL could recursively trace the `query` argument through every intervening helper function to ensure that its argument is always constant, but this code has yet to be written. -If you use a wrapper for `database/sql` (e.g., [`sqlx`][sqlx]), it's likely -SafeSQL will not work for you because of this. The second sort of false positive is based on a limitation in the sort of analysis SafeSQL performs: there are many safe SQL statements which are not @@ -76,4 +76,3 @@ static analysis techniques (such as taint analysis) or user-provided safety annotations would be able to reduce the number of false positives, but this is expected to be a significant undertaking. -[sqlx]: https://github.com/jmoiron/sqlx From a7e6848e8f2121c48c095a012370644ef5f810b9 Mon Sep 17 00:00:00 2001 From: Sergey Lanzman Date: Thu, 21 Dec 2017 20:44:53 +0200 Subject: [PATCH 3/3] check error after load --- safesql.go | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/safesql.go b/safesql.go index f26d65c..adf8bb8 100644 --- a/safesql.go +++ b/safesql.go @@ -65,6 +65,11 @@ func main() { } p, err := c.Load() + if err != nil { + fmt.Printf("error loading packages %v: %v\n", pkgs, err) + os.Exit(2) + } + imports := getImports(p) existOne := false for i := range sqlPackages { @@ -81,11 +86,6 @@ func main() { os.Exit(2) } - if err != nil { - fmt.Printf("error loading packages %v: %v\n", pkgs, err) - os.Exit(2) - } - s := ssautil.CreateProgram(p, 0) s.Build()