diff --git a/di.go b/di.go index 2fc540a..b3b3054 100644 --- a/di.go +++ b/di.go @@ -17,6 +17,7 @@ package di import ( "context" "errors" + "fmt" "io" "reflect" "strconv" @@ -66,6 +67,25 @@ var singletonInstances = make(map[string]interface{}) var userCreatedInstances = make(map[string]bool) var beanPostprocessors = make(map[reflect.Type][]func(bean interface{}) error) +var ErrContainerIsAlreadyInitialized = errors.New("container is already initialized") + +var ErrCannotRegisterBeanPostProcessor = fmt.Errorf("%w: can't register bean postprocessor", ErrContainerIsAlreadyInitialized) +var ErrCannotRegisterNewBeanFactory = fmt.Errorf("%w: can't register new bean factory", ErrContainerIsAlreadyInitialized) +var ErrCannotRegisterNewBean = fmt.Errorf("%w: can't register new bean", ErrContainerIsAlreadyInitialized) +var ErrReinitializationNotSupported = fmt.Errorf("%w: reinitialization is not supported", ErrContainerIsAlreadyInitialized) +var ErrUnsupportedDependencyType = errors.New(unsupportedDependencyType) +var ErrRequestScopedBeansCantBeInjected = errors.New(requestScopedBeansCantBeInjected) +var ErrBeanInstanceMustBePointer = errors.New("bean instance must be a pointer") +var ErrBeanTypeMustBePointer = errors.New("bean type must be a pointer") +var ErrUnsupportedScope = errors.New("unsupported scope") +var ErrNoCandidatesFound = errors.New("no candidates found for the injection") +var ErrNoDependencyFound = errors.New("no dependency found") +var ErrInvalidOptionalValue = errors.New("invalid di.optional value") +var ErrBeanFactoryMustReturnPointer = errors.New("bean factory must return pointer") +var ErrCannotFindSetContextMethod = errors.New("unexpected behavior: can't find method SetContext()") +var ErrContainerIsNotInitialized = errors.New("container is not initialized: can't lookup instances of beans yet") +var ErrMoreThenOneCandidateFound = errors.New("more then one candidate found for the injection") + // InitializingBean is an interface marking beans that need to be additionally initialized after the container is ready. type InitializingBean interface { // PostConstruct method will be called on a bean after the container is initialized. @@ -89,7 +109,7 @@ func RegisterBeanPostprocessor(beanType reflect.Type, postprocessor func(bean in initializeShutdownLock.Lock() defer initializeShutdownLock.Unlock() if atomic.CompareAndSwapInt32(&containerInitialized, 1, 1) { - return errors.New("container is already initialized: can't register bean postprocessor") + return ErrCannotRegisterBeanPostProcessor } beanPostprocessors[beanType] = append(beanPostprocessors[beanType], postprocessor) return nil @@ -100,7 +120,7 @@ func InitializeContainer() error { initializeShutdownLock.Lock() defer initializeShutdownLock.Unlock() if atomic.CompareAndSwapInt32(&containerInitialized, 1, 1) { - return errors.New("container is already initialized: reinitialization is not supported") + return ErrReinitializationNotSupported } err := createSingletonInstances() if err != nil { @@ -126,10 +146,10 @@ func RegisterBean(beanID string, beanType reflect.Type) (overwritten bool, err e initializeShutdownLock.Lock() defer initializeShutdownLock.Unlock() if atomic.CompareAndSwapInt32(&containerInitialized, 1, 1) { - return false, errors.New("container is already initialized: can't register new bean") + return false, ErrCannotRegisterNewBean } if beanType.Kind() != reflect.Ptr { - return false, errors.New("bean type must be a pointer") + return false, ErrBeanTypeMustBePointer } var existingBeanType reflect.Type var ok bool @@ -152,7 +172,7 @@ func RegisterBean(beanID string, beanType reflect.Type) (overwritten bool, err e } if field.Type.Kind() != reflect.Ptr && field.Type.Kind() != reflect.Interface && field.Type.Kind() != reflect.Slice && field.Type.Kind() != reflect.Map { - return false, errors.New(unsupportedDependencyType) + return false, ErrUnsupportedDependencyType } } beans[beanID] = beanType @@ -167,11 +187,11 @@ func RegisterBeanInstance(beanID string, beanInstance interface{}) (overwritten initializeShutdownLock.Lock() defer initializeShutdownLock.Unlock() if atomic.CompareAndSwapInt32(&containerInitialized, 1, 1) { - return false, errors.New("container is already initialized: can't register new bean") + return false, ErrCannotRegisterNewBean } beanType := reflect.TypeOf(beanInstance) if beanType.Kind() != reflect.Ptr { - return false, errors.New("bean instance must be a pointer") + return false, ErrBeanInstanceMustBePointer } var existingBeanType reflect.Type var ok bool @@ -197,7 +217,7 @@ func RegisterBeanFactory(beanID string, beanScope Scope, beanFactory func(ctx co initializeShutdownLock.Lock() defer initializeShutdownLock.Unlock() if atomic.CompareAndSwapInt32(&containerInitialized, 1, 1) { - return false, errors.New("container is already initialized: can't register new bean factory") + return false, ErrCannotRegisterNewBeanFactory } var existingBeanType reflect.Type var ok bool @@ -237,7 +257,8 @@ func getScope(bean reflect.Type) (*Scope, error) { case string(Request): return &request, nil } - return nil, errors.New("unsupported scope: " + beanScope) + + return nil, fmt.Errorf("%w:%s", ErrUnsupportedScope, beanScope) } func injectSingletonDependencies() error { @@ -280,10 +301,10 @@ func injectDependencies(beanID string, instance interface{}, chain map[string]bo if optionalDependency { continue } - return errors.New("no candidates found for the injection") + return ErrNoCandidatesFound } if len(candidates) > 1 { - return errors.New("more then one candidate found for the injection") + return ErrMoreThenOneCandidateFound } beanToInject = candidates[0] } @@ -295,7 +316,7 @@ func injectDependencies(beanID string, instance interface{}, chain map[string]bo logrus.Trace("no dependency found, injecting nil since the dependency marked as optional") continue } - return errors.New("no dependency found") + return fmt.Errorf("%w: on bean [%s] miss dependency [%s]", ErrNoDependencyFound, beanID, beanToInject) } if beanScope == Request { return errors.New(requestScopedBeansCantBeInjected) @@ -307,7 +328,7 @@ func injectDependencies(beanID string, instance interface{}, chain map[string]bo fieldToInject.Set(reflect.ValueOf(instanceToInject)) case reflect.Slice: if fieldToInject.Type().Elem().Kind() != reflect.Ptr && fieldToInject.Type().Elem().Kind() != reflect.Interface { - return errors.New(unsupportedDependencyType) + return ErrUnsupportedDependencyType } candidates := findInjectionCandidates(fieldToInject.Type().Elem()) if len(candidates) < 1 { @@ -331,7 +352,7 @@ func injectDependencies(beanID string, instance interface{}, chain map[string]bo } case reflect.Map: if fieldToInject.Type().Elem().Kind() != reflect.Ptr && fieldToInject.Type().Elem().Kind() != reflect.Interface { - return errors.New(unsupportedDependencyType) + return ErrUnsupportedDependencyType } candidates := findInjectionCandidates(fieldToInject.Type().Elem()) if len(candidates) < 1 { @@ -345,7 +366,7 @@ func injectDependencies(beanID string, instance interface{}, chain map[string]bo beanToInjectType := beans[beanToInject] logInjection(beanID, instanceElement, beanToInject, beanToInjectType) if scopes[beanToInject] == Request { - return errors.New(requestScopedBeansCantBeInjected) + return ErrRequestScopedBeansCantBeInjected } instanceToInject, err := getInstance(context.Background(), beanToInject, chain) if err != nil { @@ -354,7 +375,7 @@ func injectDependencies(beanID string, instance interface{}, chain map[string]bo fieldToInject.SetMapIndex(reflect.ValueOf(beanToInject), reflect.ValueOf(instanceToInject)) } default: - return errors.New(unsupportedDependencyType) + return ErrUnsupportedDependencyType } } return nil @@ -373,7 +394,7 @@ func isOptional(field reflect.StructField) (bool, error) { optionalTag := field.Tag.Get(string(optional)) value, err := strconv.ParseBool(optionalTag) if optionalTag != "" && err != nil { - return false, errors.New("invalid di.optional value: " + optionalTag) + return false, fmt.Errorf("%w:%s", ErrInvalidOptionalValue, optionalTag) } return value, nil } @@ -415,7 +436,7 @@ func createSingletonInstances() error { return err } if reflect.TypeOf(beanInstance).Kind() != reflect.Ptr { - return errors.New("bean factory must return pointer") + return ErrBeanFactoryMustReturnPointer } singletonInstances[beanID] = beanInstance logrus.WithFields(logrus.Fields{ @@ -435,7 +456,7 @@ func createInstance(ctx context.Context, beanID string) (interface{}, error) { return nil, err } if reflect.TypeOf(beanInstance).Kind() != reflect.Ptr { - return nil, errors.New("bean factory must return pointer") + return nil, ErrBeanFactoryMustReturnPointer } return beanInstance, nil } @@ -482,7 +503,7 @@ func setContext(ctx context.Context, beanID string, instance interface{}) error if bean.Implements(contextAwareBean) { setContextMethod, ok := bean.MethodByName(contextAwareBean.Method(0).Name) if !ok { - return errors.New("unexpected behavior: can't find method SetContext() in bean " + bean.String()) + return fmt.Errorf("%w: in bean %s", ErrCannotFindSetContextMethod, bean.String()) } logrus.WithField("beanID", beanID).WithField("context", ctx).Trace("setting context to bean") setContextMethod.Func.Call([]reflect.Value{reflect.ValueOf(instance), reflect.ValueOf(ctx)}) @@ -504,7 +525,7 @@ func GetInstance(beanID string) interface{} { // instead. func GetInstanceSafe(beanID string) (interface{}, error) { if atomic.CompareAndSwapInt32(&containerInitialized, 0, 0) { - return nil, errors.New("container is not initialized: can't lookup instances of beans yet") + return nil, ErrContainerIsNotInitialized } if scopes[beanID] == Request { return nil, errors.New("request-scoped beans can't be retrieved directly from the container: they can only be retrieved from the web-context") @@ -514,7 +535,7 @@ func GetInstanceSafe(beanID string) (interface{}, error) { func getRequestBeanInstance(ctx context.Context, beanID string) interface{} { if atomic.CompareAndSwapInt32(&containerInitialized, 0, 0) { - panic("container is not initialized: can't lookup instances of beans yet") + panic(ErrContainerIsNotInitialized) } beanInstance, err := getInstance(ctx, beanID, make(map[string]bool)) if err != nil { diff --git a/di_test.go b/di_test.go index 45f8c86..a0ff2e2 100644 --- a/di_test.go +++ b/di_test.go @@ -42,10 +42,10 @@ func TestDITestSuite(t *testing.T) { func (suite *TestSuite) TestInitializeContainerTwice() { err := InitializeContainer() assert.NoError(suite.T(), err) - expectedError := errors.New("container is already initialized: reinitialization is not supported") err = InitializeContainer() if assert.Error(suite.T(), err) { - assert.Equal(suite.T(), expectedError, err) + assert.ErrorIs(suite.T(), err, ErrReinitializationNotSupported) + assert.ErrorIs(suite.T(), err, ErrContainerIsAlreadyInitialized) } } @@ -69,16 +69,17 @@ func (suite *TestSuite) TestGetInstanceBeforeContainerInitialization() { func (suite *TestSuite) TestRegisterBeanAfterContainerInitialization() { err := InitializeContainer() assert.NoError(suite.T(), err) - expectedError := errors.New("container is already initialized: can't register new bean") overwritten, err := RegisterBean("", nil) assert.False(suite.T(), overwritten) if assert.Error(suite.T(), err) { - assert.Equal(suite.T(), expectedError, err) + assert.ErrorIs(suite.T(), err, ErrCannotRegisterNewBean) + assert.ErrorIs(suite.T(), err, ErrContainerIsAlreadyInitialized) } overwritten, err = RegisterBeanInstance("", nil) assert.False(suite.T(), overwritten) if assert.Error(suite.T(), err) { - assert.Equal(suite.T(), expectedError, err) + assert.ErrorIs(suite.T(), err, ErrCannotRegisterNewBean) + assert.ErrorIs(suite.T(), err, ErrContainerIsAlreadyInitialized) } } @@ -114,10 +115,10 @@ func (suite *TestSuite) TestBeanFactoryCalledOnce() { func (suite *TestSuite) TestRegisterBeanPostprocessorAfterContainerInitialization() { err := InitializeContainer() assert.NoError(suite.T(), err) - expectedError := errors.New("container is already initialized: can't register bean postprocessor") err = RegisterBeanPostprocessor(reflect.TypeOf((*string)(nil)), nil) if assert.Error(suite.T(), err) { - assert.Equal(suite.T(), expectedError, err) + assert.ErrorIs(suite.T(), err, ErrCannotRegisterBeanPostProcessor) + assert.ErrorIs(suite.T(), err, ErrContainerIsAlreadyInitialized) } } @@ -172,11 +173,11 @@ func (suite *TestSuite) TestRegisterSingletonBeanUnsupportedScope() { type SingletonBean struct { Scope Scope `di.scope:"invalid"` } - expectedError := errors.New("unsupported scope: invalid") overwritten, err := RegisterBean("", reflect.TypeOf((*SingletonBean)(nil))) assert.False(suite.T(), overwritten) if assert.Error(suite.T(), err) { - assert.Equal(suite.T(), expectedError, err) + assert.ErrorIs(suite.T(), err, ErrUnsupportedScope) + assert.ErrorContains(suite.T(), err, "invalid") } } @@ -199,10 +200,10 @@ func (suite *TestSuite) TestRegisterSingletonBeanWrongOptionalValue() { overwritten, err := RegisterBean("singletonBean", reflect.TypeOf((*SingletonBean)(nil))) assert.False(suite.T(), overwritten) assert.NoError(suite.T(), err) - expectedError := errors.New("invalid di.optional value: fls") err = InitializeContainer() if assert.Error(suite.T(), err) { - assert.Equal(suite.T(), expectedError, err) + assert.ErrorIs(suite.T(), err, ErrInvalidOptionalValue) + assert.ErrorContains(suite.T(), err, "fls") } } @@ -213,10 +214,10 @@ func (suite *TestSuite) TestRegisterSingletonBeanMissingImplicitlyRequiredDepend overwritten, err := RegisterBean("singletonBean", reflect.TypeOf((*SingletonBean)(nil))) assert.False(suite.T(), overwritten) assert.NoError(suite.T(), err) - expectedError := errors.New("no dependency found") err = InitializeContainer() if assert.Error(suite.T(), err) { - assert.Equal(suite.T(), expectedError, err) + assert.ErrorIs(suite.T(), err, ErrNoDependencyFound) + assert.ErrorContains(suite.T(), err, "someOtherBean") } } @@ -227,10 +228,10 @@ func (suite *TestSuite) TestRegisterSingletonBeanMissingExplicitlyRequiredDepend overwritten, err := RegisterBean("singletonBean", reflect.TypeOf((*SingletonBean)(nil))) assert.False(suite.T(), overwritten) assert.NoError(suite.T(), err) - expectedError := errors.New("no dependency found") err = InitializeContainer() if assert.Error(suite.T(), err) { - assert.Equal(suite.T(), expectedError, err) + assert.ErrorIs(suite.T(), err, ErrNoDependencyFound) + assert.ErrorContains(suite.T(), err, "someOtherBean") } } @@ -361,11 +362,11 @@ func (suite *TestSuite) TestRegisterSingletonBeanFactory() { func (suite *TestSuite) TestRegisterBeanFactoryAfterContainerInitialization() { err := InitializeContainer() assert.NoError(suite.T(), err) - expectedFactoryError := errors.New("container is already initialized: can't register new bean factory") overwritten, err := RegisterBeanFactory("", Singleton, nil) assert.False(suite.T(), overwritten) assert.Error(suite.T(), err) - assert.Equal(suite.T(), expectedFactoryError, err) + assert.ErrorIs(suite.T(), err, ErrCannotRegisterNewBeanFactory) + assert.ErrorIs(suite.T(), err, ErrContainerIsAlreadyInitialized) } func (suite *TestSuite) TestRegisterPrototypeBean() {