Skip to content
This repository has been archived by the owner on Sep 21, 2021. It is now read-only.

Commit

Permalink
Merge pull request #10 from gettaxi/master
Browse files Browse the repository at this point in the history
Add gorm and sqlx support, make easy add new other ORM
  • Loading branch information
clundquist-stripe authored Dec 21, 2017
2 parents 452e37e + a7e6848 commit cddf355
Show file tree
Hide file tree
Showing 2 changed files with 99 additions and 19 deletions.
9 changes: 4 additions & 5 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,8 +31,8 @@ How does it work?
-----------------

SafeSQL uses the static analysis utilities in [go/tools][tools] to search for
all call sites of each of the `query` functions in package [database/sql][sql]
(i.e., functions which accept a `string` parameter named `query`). It then makes
all call sites of each of the `query` functions in packages ([database/sql][sql],[github.com/jinzhu/gorm][gorm],[github.com/jmoiron/sqlx][sqlx])
(i.e., functions which accept a parameter named `query`,`sql`). It then makes
sure that every such call site uses a query that is a compile-time constant.

The principle behind SafeSQL's safety guarantees is that queries that are
Expand All @@ -44,6 +44,8 @@ will not be allowed.

[tools]: https://godoc.org/golang.org/x/tools/go
[sql]: http://golang.org/pkg/database/sql/
[sqlx]: https://github.com/jmoiron/sqlx
[gorm]: https://github.com/jinzhu/gorm

False positives
---------------
Expand All @@ -66,8 +68,6 @@ a fundamental limitation: SafeSQL could recursively trace the `query` argument
through every intervening helper function to ensure that its argument is always
constant, but this code has yet to be written.

If you use a wrapper for `database/sql` (e.g., [`sqlx`][sqlx]), it's likely
SafeSQL will not work for you because of this.

The second sort of false positive is based on a limitation in the sort of
analysis SafeSQL performs: there are many safe SQL statements which are not
Expand All @@ -76,4 +76,3 @@ static analysis techniques (such as taint analysis) or user-provided safety
annotations would be able to reduce the number of false positives, but this is
expected to be a significant undertaking.

[sqlx]: https://github.com/jmoiron/sqlx
109 changes: 95 additions & 14 deletions safesql.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ import (
"go/build"
"go/types"
"os"

"path/filepath"
"strings"

Expand All @@ -19,6 +20,27 @@ import (
"golang.org/x/tools/go/ssa/ssautil"
)

type sqlPackage struct {
packageName string
paramNames []string
enable bool
}

var sqlPackages = []sqlPackage{
{
packageName: "database/sql",
paramNames: []string{"query"},
},
{
packageName: "github.com/jinzhu/gorm",
paramNames: []string{"sql", "query"},
},
{
packageName: "github.com/jmoiron/sqlx",
paramNames: []string{"query"},
},
}

func main() {
var verbose, quiet bool
flag.BoolVar(&verbose, "v", false, "Verbose mode")
Expand All @@ -38,21 +60,45 @@ func main() {
c := loader.Config{
FindPackage: FindPackage,
}
c.Import("database/sql")
for _, pkg := range pkgs {
c.Import(pkg)
}
p, err := c.Load()

if err != nil {
fmt.Printf("error loading packages %v: %v\n", pkgs, err)
os.Exit(2)
}

imports := getImports(p)
existOne := false
for i := range sqlPackages {
if _, exist := imports[sqlPackages[i].packageName]; exist {
if verbose {
fmt.Printf("Enabling support for %s\n", sqlPackages[i].packageName)
}
sqlPackages[i].enable = true
existOne = true
}
}
if !existOne {
fmt.Printf("No packages in %v include a supported database driver", pkgs)
os.Exit(2)
}

s := ssautil.CreateProgram(p, 0)
s.Build()

qms := FindQueryMethods(p.Package("database/sql").Pkg, s)
qms := make([]*QueryMethod, 0)

for i := range sqlPackages {
if sqlPackages[i].enable {
qms = append(qms, FindQueryMethods(sqlPackages[i], p.Package(sqlPackages[i].packageName).Pkg, s)...)
}
}

if verbose {
fmt.Println("database/sql functions that accept queries:")
fmt.Println("database driver functions that accept queries:")
for _, m := range qms {
fmt.Printf("- %s (param %d)\n", m.Func, m.Param)
}
Expand All @@ -75,21 +121,27 @@ func main() {
}

bad := FindNonConstCalls(res.CallGraph, qms)

if len(bad) == 0 {
if !quiet {
fmt.Println(`You're safe from SQL injection! Yay \o/`)
}
return
}

fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad))
if verbose {
fmt.Printf("Found %d potentially unsafe SQL statements:\n", len(bad))
}

for _, ci := range bad {
pos := p.Fset.Position(ci.Pos())
fmt.Printf("- %s\n", pos)
}
fmt.Println("Please ensure that all SQL queries you use are compile-time constants.")
fmt.Println("You should always use parameterized queries or prepared statements")
fmt.Println("instead of building queries from strings.")
if verbose {
fmt.Println("Please ensure that all SQL queries you use are compile-time constants.")
fmt.Println("You should always use parameterized queries or prepared statements")
fmt.Println("instead of building queries from strings.")
}
os.Exit(1)
}

Expand All @@ -104,7 +156,7 @@ type QueryMethod struct {

// FindQueryMethods locates all methods in the given package (assumed to be
// package database/sql) with a string parameter named "query".
func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
func FindQueryMethods(sqlPackages sqlPackage, sql *types.Package, ssa *ssa.Program) []*QueryMethod {
methods := make([]*QueryMethod, 0)
scope := sql.Scope()
for _, name := range scope.Names() {
Expand All @@ -122,7 +174,7 @@ func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
continue
}
s := m.Type().(*types.Signature)
if num, ok := FuncHasQuery(s); ok {
if num, ok := FuncHasQuery(sqlPackages, s); ok {
methods = append(methods, &QueryMethod{
Func: m,
SSA: ssa.FuncValue(m),
Expand All @@ -135,16 +187,16 @@ func FindQueryMethods(sql *types.Package, ssa *ssa.Program) []*QueryMethod {
return methods
}

var stringType types.Type = types.Typ[types.String]

// FuncHasQuery returns the offset of the string parameter named "query", or
// none if no such parameter exists.
func FuncHasQuery(s *types.Signature) (offset int, ok bool) {
func FuncHasQuery(sqlPackages sqlPackage, s *types.Signature) (offset int, ok bool) {
params := s.Params()
for i := 0; i < params.Len(); i++ {
v := params.At(i)
if v.Name() == "query" && v.Type() == stringType {
return i, true
for _, paramName := range sqlPackages.paramNames {
if v.Name() == paramName {
return i, true
}
}
}
return 0, false
Expand All @@ -164,6 +216,16 @@ func FindMains(p *loader.Program, s *ssa.Program) []*ssa.Package {
return mains
}

func getImports(p *loader.Program) map[string]interface{} {
pkgs := make(map[string]interface{})
for _, pkg := range p.AllPackages {
if pkg.Importable {
pkgs[pkg.Pkg.Path()] = nil
}
}
return pkgs
}

// FindNonConstCalls returns the set of callsites of the given set of methods
// for which the "query" parameter is not a compile-time constant.
func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstruction {
Expand All @@ -186,6 +248,18 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
if _, ok := okFuncs[edge.Site.Parent()]; ok {
continue
}

isInternalSQLPkg := false
for _, pkg := range sqlPackages {
if pkg.packageName == edge.Caller.Func.Pkg.Pkg.Path() {
isInternalSQLPkg = true
break
}
}
if isInternalSQLPkg {
continue
}

cc := edge.Site.Common()
args := cc.Args
// The first parameter is occasionally the receiver.
Expand All @@ -195,7 +269,14 @@ func FindNonConstCalls(cg *callgraph.Graph, qms []*QueryMethod) []ssa.CallInstru
panic("arg count mismatch")
}
v := args[m.Param]

if _, ok := v.(*ssa.Const); !ok {
if inter, ok := v.(*ssa.MakeInterface); ok && types.IsInterface(v.(*ssa.MakeInterface).Type()) {
if inter.X.Referrers() == nil || inter.X.Type() != types.Typ[types.String] {
continue
}
}

bad = append(bad, edge.Site)
}
}
Expand Down

0 comments on commit cddf355

Please sign in to comment.