diff --git a/dingo.go b/dingo.go index bb8323e..55ac33e 100644 --- a/dingo.go +++ b/dingo.go @@ -17,6 +17,7 @@ const ( var ErrInvalidInjectReceiver = errors.New("usage of 'Inject' method with struct receiver is not allowed") var traceCircular []circularTraceEntry +var errPointersToInterface = errors.New(" Do not use pointers to interface.") // EnableCircularTracing activates dingo's trace feature to find circular dependencies // this is super expensive (memory wise), so it should only be used for debugging purposes @@ -717,7 +718,7 @@ func (injector *Injector) requestInjection(object interface{}, circularTrace []c for fieldIndex := 0; fieldIndex < ctype.NumField(); fieldIndex++ { if tag, ok := ctype.Field(fieldIndex).Tag.Lookup("inject"); ok { field := current.Field(fieldIndex) - + currentFieldName := ctype.Field(fieldIndex).Name if field.Kind() == reflect.Struct { return fmt.Errorf("can not inject into struct %#v of %#v", field, current) } @@ -743,6 +744,9 @@ func (injector *Injector) requestInjection(object interface{}, circularTrace []c if field.Kind() != reflect.Ptr && field.Kind() != reflect.Interface && instance.Kind() == reflect.Ptr { field.Set(instance.Elem()) } else { + if field.Kind() == reflect.Ptr && field.Type().Kind() == reflect.Ptr && field.Type().Elem().Kind() == reflect.Interface { + return wrapErr(fmt.Errorf("field %#v is pointer to interface. %w", currentFieldName, errPointersToInterface)) + } field.Set(instance) } } diff --git a/dingo_test.go b/dingo_test.go index 967d0a1..4de871b 100644 --- a/dingo_test.go +++ b/dingo_test.go @@ -335,3 +335,19 @@ func TestInjectStructRec(t *testing.T) { _, err = injector.GetInstance(new(TestInjectStructRecInterface)) assert.Error(t, err) } + +type someStructWithInvalidInterfacePointer struct { + A *testInterface `inject:""` +} + +func TestInjectionOfInterfacePointer(t *testing.T) { + t.Parallel() + + injector, err := NewInjector() + assert.NoError(t, err) + + injector.Bind((*testInterface)(nil)).To(interfaceImpl1{}) + + _, err = injector.GetInstance(new(someStructWithInvalidInterfacePointer)) + assert.Error(t, err, "Expected error") +}