diff --git a/README.md b/README.md index a935b31..f015981 100644 --- a/README.md +++ b/README.md @@ -102,8 +102,9 @@ func (c *Container) MustResolve(target interface{}) - [x] Topological sorting and circular dependency detection - [x] Thread-safe container operations - [x] Interface binding support -- [ ] Error handling +- [x] Error handling - [ ] Named dependency resolution +- [ ] Dependency graph visualization ## 📊 Benefits diff --git a/container.go b/container.go index 6643a77..5819b13 100644 --- a/container.go +++ b/container.go @@ -145,16 +145,25 @@ func (c *Container) analyzeFunction(fnType reflect.Type) (fnSignature, error) { } // Analyze return values - // todo: add here error handling + // Support either: (*T) or (*T, error) + if fnType.NumOut() == 0 || fnType.NumOut() > 2 { + return fnSignature{}, fmt.Errorf("constructor must return (*T) or (*T, error)") + } - if fnType.NumOut() != 1 { - return fnSignature{}, fmt.Errorf("constructor must have one return value") + firstOut := fnType.Out(0) + if firstOut.Kind() != reflect.Pointer { + return fnSignature{}, fmt.Errorf("constructor must return pointer value as first result") } - if fnType.Out(0).Kind() != reflect.Pointer { - return fnSignature{}, fmt.Errorf("constructor must return pointer value") + + if fnType.NumOut() == 2 { + secondOut := fnType.Out(1) + errorType := reflect.TypeOf((*error)(nil)).Elem() + if !secondOut.Implements(errorType) { + return fnSignature{}, fmt.Errorf("second return value must be error") + } } - return fnSignature{args, fnType.Out(0)}, nil + return fnSignature{args, firstOut}, nil } // Resolve resolves and returns an instance of the requested type. @@ -367,10 +376,11 @@ func (c *Container) resolveInstance(typ reflect.Type) error { // Call constructor results := constructorValue.Call(args) - // Handle error return - if len(results) > 0 { + // Handle optional error return (when present and non-nil) + if len(results) > 1 { lastResult := results[len(results)-1] - if lastResult.Type().String() == "error" && !lastResult.IsNil() { + errorType := reflect.TypeOf((*error)(nil)).Elem() + if lastResult.Type().Implements(errorType) && !lastResult.IsNil() { return lastResult.Interface().(error) } } diff --git a/container_test.go b/container_test.go index 7baa90e..f58d12e 100644 --- a/container_test.go +++ b/container_test.go @@ -220,14 +220,14 @@ var _ = Describe("Container", func() { }) Describe("Error Handling", func() { - PIt("should handle constructor errors", func() { + It("should handle constructor errors", func() { Expect(container.Provide(NewErrorDatabase)).To(Succeed()) var db *Database Expect(container.Resolve(&db)).To(MatchError(ContainSubstring("database connection failed"))) }) - PIt("should handle successful constructors with error return", func() { + It("should handle successful constructors with error return", func() { Expect(container.Provide(NewSuccessfulDatabase)).To(Succeed()) var db *Database