From 9ad9b6cb1ef3191a5105bde8ae3aaa1d800d302b Mon Sep 17 00:00:00 2001 From: Nikita Vorontsov Date: Mon, 29 Jul 2024 17:35:07 +0300 Subject: [PATCH 1/2] add introspection interfaces for some fsm errors --- errors.go | 28 ++++++++++++++++++++++++++++ 1 file changed, 28 insertions(+) diff --git a/errors.go b/errors.go index add5dbe..df49c3d 100644 --- a/errors.go +++ b/errors.go @@ -16,6 +16,7 @@ package fsm import ( "context" + "errors" ) // InvalidEventError is returned by FSM.Event() when the event cannot be called @@ -69,6 +70,15 @@ func (e NoTransitionError) Error() string { return "no transition" } +func (e NoTransitionError) Is(target error) bool { + _, ok := target.(NoTransitionError) + return ok || errors.Is(e.Err, target) +} + +func (e NoTransitionError) Unwrap() error { + return e.Err +} + // CanceledError is returned by FSM.Event() when a callback have canceled a // transition. type CanceledError struct { @@ -82,6 +92,15 @@ func (e CanceledError) Error() string { return "transition canceled" } +func (e CanceledError) Is(target error) bool { + _, ok := target.(CanceledError) + return ok || errors.Is(e.Err, target) +} + +func (e CanceledError) Unwrap() error { + return e.Err +} + // AsyncError is returned by FSM.Event() when a callback have initiated an // asynchronous state transition. type AsyncError struct { @@ -98,6 +117,15 @@ func (e AsyncError) Error() string { return "async started" } +func (e AsyncError) Is(target error) bool { + _, ok := target.(AsyncError) + return ok || errors.Is(e.Err, target) +} + +func (e AsyncError) Unwrap() error { + return e.Err +} + // InternalError is returned by FSM.Event() and should never occur. It is a // probably because of a bug. type InternalError struct{} From 52d34e4c1e18f0a32573064eb5a53c0d8732a272 Mon Sep 17 00:00:00 2001 From: Nikita Vorontsov Date: Fri, 16 Aug 2024 11:12:22 +0300 Subject: [PATCH 2/2] add tests for errors Is/Unwrap --- errors_test.go | 46 +++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 43 insertions(+), 3 deletions(-) diff --git a/errors_test.go b/errors_test.go index ba384ee..06f4213 100644 --- a/errors_test.go +++ b/errors_test.go @@ -53,35 +53,71 @@ func TestNotInTransitionError(t *testing.T) { func TestNoTransitionError(t *testing.T) { e := NoTransitionError{} + innerErr := errors.New("no transition") if e.Error() != "no transition" { t.Error("NoTransitionError string mismatch") } - e.Err = errors.New("no transition") + e.Err = innerErr if e.Error() != "no transition with error: "+e.Err.Error() { t.Error("NoTransitionError string mismatch") } + + realErr := hideErrInterfaceType(e) + if !errors.Is(realErr, NoTransitionError{}) { + t.Error("NoTransitionError 'Is' broken") + } + if !errors.Is(realErr, innerErr) { + t.Error("NoTransitionError 'Is' broken") + } + if errors.Unwrap(e) != innerErr { + t.Error("NoTransitionError 'Unwrap' broken") + } } func TestCanceledError(t *testing.T) { e := CanceledError{} + innerErr := errors.New("canceled") if e.Error() != "transition canceled" { t.Error("CanceledError string mismatch") } - e.Err = errors.New("canceled") + e.Err = innerErr if e.Error() != "transition canceled with error: "+e.Err.Error() { t.Error("CanceledError string mismatch") } + + realErr := hideErrInterfaceType(e) + if !errors.Is(realErr, CanceledError{}) { + t.Error("CanceledError 'Is' broken") + } + if !errors.Is(realErr, innerErr) { + t.Error("CanceledError 'Is' broken") + } + if errors.Unwrap(e) != innerErr { + t.Error("CanceledError 'Unwrap' broken") + } } func TestAsyncError(t *testing.T) { e := AsyncError{} + innerErr := errors.New("async") if e.Error() != "async started" { t.Error("AsyncError string mismatch") } - e.Err = errors.New("async") + e.Err = innerErr if e.Error() != "async started with error: "+e.Err.Error() { t.Error("AsyncError string mismatch") } + + realErr := hideErrInterfaceType(e) + if !errors.Is(realErr, AsyncError{}) { + t.Error("AsyncError 'Is' broken") + } + if !errors.Is(realErr, innerErr) { + t.Error("AsyncError 'Is' broken") + } + if errors.Unwrap(e) != innerErr { + t.Error("AsyncError 'Unwrap' broken") + } } func TestInternalError(t *testing.T) { @@ -90,3 +126,7 @@ func TestInternalError(t *testing.T) { t.Error("InternalError string mismatch") } } + +func hideErrInterfaceType(err error) error { + return err +}