diff --git a/staticcheck/analysis.go b/staticcheck/analysis.go index cce61975..2965d318 100644 --- a/staticcheck/analysis.go +++ b/staticcheck/analysis.go @@ -97,6 +97,7 @@ import ( "honnef.co/go/tools/staticcheck/sa9006" "honnef.co/go/tools/staticcheck/sa9007" "honnef.co/go/tools/staticcheck/sa9008" + "honnef.co/go/tools/staticcheck/sa9009" ) var Analyzers = []*lint.Analyzer{ @@ -193,4 +194,5 @@ var Analyzers = []*lint.Analyzer{ sa9006.SCAnalyzer, sa9007.SCAnalyzer, sa9008.SCAnalyzer, + sa9009.SCAnalyzer, } diff --git a/staticcheck/sa9009/sa9009.go b/staticcheck/sa9009/sa9009.go new file mode 100644 index 00000000..b5c83454 --- /dev/null +++ b/staticcheck/sa9009/sa9009.go @@ -0,0 +1,104 @@ +package sa9009 + +import ( + "go/ast" + "go/types" + + "honnef.co/go/tools/analysis/code" + "honnef.co/go/tools/analysis/lint" + "honnef.co/go/tools/analysis/report" + + "golang.org/x/tools/go/analysis" + "golang.org/x/tools/go/analysis/passes/inspect" +) + +var SCAnalyzer = lint.InitializeAnalyzer(&lint.Analyzer{ + Analyzer: &analysis.Analyzer{ + Name: "SA9009", + Run: run, + Requires: []*analysis.Analyzer{inspect.Analyzer}, + }, + Doc: &lint.Documentation{ + Title: `Returned function should be called in defer`, + Text: ` +If you have a function such as: + + func f() func() { + // Do something. + return func() { + // Do something. + } + } + +Then calling that in defer: + + defer f() + +Is almost always a mistake, since you typically want to call the returned +function: + + defer f()() +`, + Since: "Unreleased", + Severity: lint.SeverityWarning, + MergeIf: lint.MergeIfAny, + }, +}) + +var Analyzer = SCAnalyzer.Analyzer + +func run(pass *analysis.Pass) (any, error) { + checkIdent := func(c *ast.Ident) bool { + var ( + obj = pass.TypesInfo.ObjectOf(c) + sig *types.Signature + ) + switch f := obj.(type) { + case *types.Builtin: + return false + case *types.Func: + sig = f.Type().(*types.Signature) + case *types.Var: + switch ff := f.Type().(type) { + case *types.Signature: + sig = ff + case *types.Named: + sig = ff.Underlying().(*types.Signature) + } + } + r := sig.Results() + if r != nil && r.Len() == 1 { + _, ok := r.At(0).Type().(*types.Signature) + return ok + } + return false + } + + fn := func(n ast.Node) { + var ( + returnsFunc bool + def = n.(*ast.DeferStmt) + ) + switch c := def.Call.Fun.(type) { + case *ast.FuncLit: // defer func() { }() + r := c.Type.Results + if r != nil && len(r.List) == 1 { + _, returnsFunc = r.List[0].Type.(*ast.FuncType) + } + case *ast.Ident: // defer f() + returnsFunc = checkIdent(c) + case *ast.SelectorExpr: // defer t.f() + returnsFunc = checkIdent(c.Sel) + case *ast.IndexExpr: // defer f[int](0) + if id, ok := c.X.(*ast.Ident); ok { + returnsFunc = checkIdent(id) + } + } + if returnsFunc { + report.Report(pass, def, "defered return function not called") + } + } + + code.Preorder(pass, fn, (*ast.DeferStmt)(nil)) + return nil, nil +} diff --git a/staticcheck/sa9009/sa9009_test.go b/staticcheck/sa9009/sa9009_test.go new file mode 100644 index 00000000..a755faf3 --- /dev/null +++ b/staticcheck/sa9009/sa9009_test.go @@ -0,0 +1,13 @@ +// Code generated by generate.go. DO NOT EDIT. + +package sa9009 + +import ( + "testing" + + "honnef.co/go/tools/analysis/lint/testutil" +) + +func TestTestdata(t *testing.T) { + testutil.Run(t, SCAnalyzer) +} diff --git a/staticcheck/sa9009/testdata/go1.0/example.com/deferr/deferr.go b/staticcheck/sa9009/testdata/go1.0/example.com/deferr/deferr.go new file mode 100644 index 00000000..a6ea3cb6 --- /dev/null +++ b/staticcheck/sa9009/testdata/go1.0/example.com/deferr/deferr.go @@ -0,0 +1,109 @@ +package deferr + +func x() { + var ( + t t1 + tt t2 + varReturnNothing = func() {} + varReturnInt = func() int { return 0 } + varReturnFunc = func() func() { return func() {} } + varReturnFuncInt = func(int) func(int) int { return func(int) int { return 0 } } + varReturnMulti = func() (int, func()) { return 0, func() {} } + + namedReturnNothing = named(func() {}) + namedReturnFunc = namedReturn(func() func() { return func() {} }) + ) + + // Correct. + defer returnNothing() + defer varReturnNothing() + defer namedReturnNothing() + defer t.returnNothing() + defer tt.returnNothing() + defer tt.t.returnNothing() + defer func() {}() + defer close(make(chan int)) + + defer returnInt() + defer varReturnInt() + defer t.returnInt() + defer tt.returnInt() + defer tt.t.returnInt() + defer func() int { return 0 }() + + defer returnFunc()() + defer varReturnFunc()() + defer namedReturnFunc()() + defer t.returnFunc()() + defer tt.returnFunc()() + defer tt.t.returnFunc()() + defer func() func() { return func() {} }()() + + defer returnFuncInt(0)(0) + defer varReturnFuncInt(0)(0) + defer t.returnFuncInt(0)(0) + defer tt.returnFuncInt(0)(0) + defer tt.t.returnFuncInt(0)(0) + defer func(int) func(int) int { return func(int) int { return 0 } }(0)(0) + + defer returnMulti() + defer varReturnMulti() + defer t.returnMulti() + defer tt.returnMulti() + defer tt.t.returnMulti() + defer func() (int, func()) { return 0, func() {} }() + + // Wrong. + defer returnFunc() //@ diag(`defered return function not called`) + defer varReturnFunc() //@ diag(`defered return function not called`) + defer namedReturnFunc() //@ diag(`defered return function not called`) + defer t.returnFunc() //@ diag(`defered return function not called`) + defer tt.returnFunc() //@ diag(`defered return function not called`) + defer tt.t.returnFunc() //@ diag(`defered return function not called`) + defer func() func() { return func() {} }() //@ diag(`defered return function not called`) + defer returnFuncInt(0) //@ diag(`defered return function not called`) + defer t.returnFuncInt(0) //@ diag(`defered return function not called`) + defer tt.returnFuncInt(0) //@ diag(`defered return function not called`) + defer tt.t.returnFuncInt(0) //@ diag(`defered return function not called`) + defer func(int) func(int) int { return func(int) int { return 0 } }(0) //@ diag(`defered return function not called`) + + // Function returns a function which returns another function. This is + // getting silly and is not checked. + defer silly1()() + defer func() func() func() { + return func() func() { + return func() {} + } + }()() +} + +func returnNothing() {} +func returnInt() int { return 0 } +func returnFunc() func() { return func() {} } +func returnFuncInt(int) func(int) int { return func(int) int { return 0 } } +func returnMulti() (int, func()) { return 0, func() {} } + +type ( + t1 struct{} + t2 struct{ t t1 } + named func() + namedReturn func() func() +) + +func (t1) returnNothing() {} +func (t1) returnInt() int { return 0 } +func (t1) returnFunc() func() { return func() {} } +func (t1) returnFuncInt(int) func(int) int { return func(int) int { return 0 } } +func (t1) returnMulti() (int, func()) { return 0, func() {} } + +func (*t2) returnNothing() {} +func (*t2) returnInt() int { return 0 } +func (*t2) returnFunc() func() { return func() {} } +func (*t2) returnFuncInt(int) func(int) int { return func(int) int { return 0 } } +func (*t2) returnMulti() (int, func()) { return 0, func() {} } + +func silly1() func() func() { + return func() func() { + return func() {} + } +} diff --git a/staticcheck/sa9009/testdata/go1.18/deferr/deferr.go b/staticcheck/sa9009/testdata/go1.18/deferr/deferr.go new file mode 100644 index 00000000..eff17f20 --- /dev/null +++ b/staticcheck/sa9009/testdata/go1.18/deferr/deferr.go @@ -0,0 +1,9 @@ +package deferr + +func x() { + + defer tpReturnFuncInt[int](0) //@ diag(`defered return function not called`) + defer tpReturnFuncInt(0)(0) +} + +func tpReturnFuncInt[T any](T) func(int) int { return func(int) int { return 0 } }