Skip to content

Commit

Permalink
SA9009: flag defer foo() where foo returns a closure
Browse files Browse the repository at this point in the history
Add a new checker to flag cases where a deferred function returns
exactly one function, and that function isn't called. For example:

	func f() func() {
		// Do stuff.
		return func() {
			// Do stuff
		}
	}

	func main() {
		defer f() // Error: should be f()()
	}

As mentioned in dominikh#466, there's a change for false positives. I ran the
check on Go stdlib, all my own code, and about 200 popular Go packages,
and haven't found any false positives (and one match:
grpc/grpc-go#7270)

Fixes dominikh#466
  • Loading branch information
arp242 committed May 25, 2024
1 parent 5275b91 commit 83f639f
Show file tree
Hide file tree
Showing 4 changed files with 231 additions and 0 deletions.
2 changes: 2 additions & 0 deletions staticcheck/analysis.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

104 changes: 104 additions & 0 deletions staticcheck/sa9009/sa9009.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
package sa9009

import (
"go/ast"
"go/types"

"honnef.co/go/tools/analysis/code"
"honnef.co/go/tools/analysis/lint"
"honnef.co/go/tools/analysis/report"

"golang.org/x/tools/go/analysis"
"golang.org/x/tools/go/analysis/passes/inspect"
)

var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{
Analyzer: &analysis.Analyzer{
Name: "SA9009",
Run: run,
Requires: []*analysis.Analyzer{inspect.Analyzer},
},
Doc: &lint.Documentation{
Title: `Returned function should be called in defer`,
Text: `
If you have a function such as:
func f() func() {
// Do something.
return func() {
// Do something.
}
}
Then calling that in defer:
defer f()
Is almost always a mistake, since you typically want to call the returned
function:
defer f()()
`,
Since: "Unreleased",
Severity: lint.SeverityWarning,
MergeIf: lint.MergeIfAny,
},
})

var Analyzer = SCAnalyzer.Analyzer

func run(pass *analysis.Pass) (any, error) {
checkIdent := func(c *ast.Ident) bool {
var (
obj = pass.TypesInfo.ObjectOf(c)
sig *types.Signature
)
switch f := obj.(type) {
case *types.Builtin:
return false
case *types.Func:
sig = f.Type().(*types.Signature)
case *types.Var:
switch ff := f.Type().(type) {
case *types.Signature:
sig = ff
case *types.Named:
sig = ff.Underlying().(*types.Signature)
}
}
r := sig.Results()
if r != nil && r.Len() == 1 {
_, ok := r.At(0).Type().(*types.Signature)
return ok
}
return false
}

fn := func(n ast.Node) {
var (
returnsFunc bool
def = n.(*ast.DeferStmt)
)
switch c := def.Call.Fun.(type) {
case *ast.FuncLit: // defer func() { }()
r := c.Type.Results
if r != nil && len(r.List) == 1 {
_, returnsFunc = r.List[0].Type.(*ast.FuncType)
}
case *ast.Ident: // defer f()
returnsFunc = checkIdent(c)
case *ast.SelectorExpr: // defer t.f()
returnsFunc = checkIdent(c.Sel)
case *ast.IndexExpr: // defer f[int](0)
if id, ok := c.X.(*ast.Ident); ok {
returnsFunc = checkIdent(id)
}
}
if returnsFunc {
report.Report(pass, def, "defered return function not called")
}
}

code.Preorder(pass, fn, (*ast.DeferStmt)(nil))
return nil, nil
}
13 changes: 13 additions & 0 deletions staticcheck/sa9009/sa9009_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

112 changes: 112 additions & 0 deletions staticcheck/sa9009/testdata/src/example.com/deferr/deferr.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,112 @@
package deferr

func x() {
var (
t t1
tt t2
varReturnNothing = func() {}
varReturnInt = func() int { return 0 }
varReturnFunc = func() func() { return func() {} }
varReturnFuncInt = func(int) func(int) int { return func(int) int { return 0 } }
varReturnMulti = func() (int, func()) { return 0, func() {} }

namedReturnNothing = named(func() {})
namedReturnFunc = namedReturn(func() func() { return func() {} })
)

// Correct.
defer returnNothing()
defer varReturnNothing()
defer namedReturnNothing()
defer t.returnNothing()
defer tt.returnNothing()
defer tt.t.returnNothing()
defer func() {}()
defer close(make(chan int))

defer returnInt()
defer varReturnInt()
defer t.returnInt()
defer tt.returnInt()
defer tt.t.returnInt()
defer func() int { return 0 }()

defer returnFunc()()
defer varReturnFunc()()
defer namedReturnFunc()()
defer t.returnFunc()()
defer tt.returnFunc()()
defer tt.t.returnFunc()()
defer func() func() { return func() {} }()()

defer returnFuncInt(0)(0)
defer tpReturnFuncInt(0)(0)
defer varReturnFuncInt(0)(0)
defer t.returnFuncInt(0)(0)
defer tt.returnFuncInt(0)(0)
defer tt.t.returnFuncInt(0)(0)
defer func(int) func(int) int { return func(int) int { return 0 } }(0)(0)

defer returnMulti()
defer varReturnMulti()
defer t.returnMulti()
defer tt.returnMulti()
defer tt.t.returnMulti()
defer func() (int, func()) { return 0, func() {} }()

// Wrong.
defer returnFunc() //@ diag(`defered return function not called`)
defer varReturnFunc() //@ diag(`defered return function not called`)
defer namedReturnFunc() //@ diag(`defered return function not called`)
defer t.returnFunc() //@ diag(`defered return function not called`)
defer tt.returnFunc() //@ diag(`defered return function not called`)
defer tt.t.returnFunc() //@ diag(`defered return function not called`)
defer func() func() { return func() {} }() //@ diag(`defered return function not called`)
defer returnFuncInt(0) //@ diag(`defered return function not called`)
defer tpReturnFuncInt[int](0) //@ diag(`defered return function not called`)
defer t.returnFuncInt(0) //@ diag(`defered return function not called`)
defer tt.returnFuncInt(0) //@ diag(`defered return function not called`)
defer tt.t.returnFuncInt(0) //@ diag(`defered return function not called`)
defer func(int) func(int) int { return func(int) int { return 0 } }(0) //@ diag(`defered return function not called`)

// Function returns a function which returns another function. This is
// getting silly and is not checked.
defer silly1()()
defer func() func() func() {
return func() func() {
return func() {}
}
}()()
}

func returnNothing() {}
func returnInt() int { return 0 }
func returnFunc() func() { return func() {} }
func returnFuncInt(int) func(int) int { return func(int) int { return 0 } }
func returnMulti() (int, func()) { return 0, func() {} }
func tpReturnFuncInt[T any](T) func(int) int { return func(int) int { return 0 } }

type (
t1 struct{}
t2 struct{ t t1 }
named func()
namedReturn func() func()
)

func (t1) returnNothing() {}
func (t1) returnInt() int { return 0 }
func (t1) returnFunc() func() { return func() {} }
func (t1) returnFuncInt(int) func(int) int { return func(int) int { return 0 } }
func (t1) returnMulti() (int, func()) { return 0, func() {} }

func (*t2) returnNothing() {}
func (*t2) returnInt() int { return 0 }
func (*t2) returnFunc() func() { return func() {} }
func (*t2) returnFuncInt(int) func(int) int { return func(int) int { return 0 } }
func (*t2) returnMulti() (int, func()) { return 0, func() {} }

func silly1() func() func() {
return func() func() {
return func() {}
}
}

0 comments on commit 83f639f

Please sign in to comment.