Skip to content

Commit

Permalink
Support simultanous name and group tags
Browse files Browse the repository at this point in the history
As per Dig issue:
uber-go#380

In order to support Fx feature requests

uber-go/fx#998
uber-go/fx#1036

We need to be able to drop the restriction, both in terms of options
dig.Name and dig.Group and dig.Out struct annotations on `name` and
`group` being mutually exclusive.

In a future PR, this can then be exploited to populate value group maps
where the 'name' tag becomes the key of a map[string][T]
  • Loading branch information
jquirke committed Mar 6, 2023
1 parent 7f9f0b8 commit 682c0ce
Show file tree
Hide file tree
Showing 5 changed files with 214 additions and 75 deletions.
4 changes: 2 additions & 2 deletions decorate.go
Expand Up @@ -12,7 +12,7 @@
//
// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
// FITNESS FOR A PARTICULAR PURPOSE AN NONINFRINGEMENT. IN NO EVENT SHALL THE
// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
Expand Down Expand Up @@ -288,7 +288,7 @@ func findResultKeys(r resultList) ([]key, error) {
keys = append(keys, key{t: innerResult.Type.Elem(), group: innerResult.Group})
case resultObject:
for _, f := range innerResult.Fields {
q = append(q, f.Result)
q = append(q, f.Results...)
}
case resultList:
q = append(q, innerResult.Results...)
Expand Down
134 changes: 119 additions & 15 deletions dig_test.go
Expand Up @@ -749,6 +749,53 @@ func TestEndToEndSuccess(t *testing.T) {
assert.ElementsMatch(t, actualStrs, expectedStrs, "list of strings provided must match")
})

t.Run("multiple As with Group", func(t *testing.T) {
c := digtest.New(t)
expectedNames := []string{"inst1", "inst2"}
expectedStrs := []string{"foo", "bar"}
for i, s := range expectedStrs {
s := s
c.RequireProvide(func() *bytes.Buffer {
return bytes.NewBufferString(s)
}, dig.Group("buffs"), dig.Name(expectedNames[i]),
dig.As(new(io.Reader), new(io.Writer)))
}

type in struct {
dig.In

Reader1 io.Reader `name:"inst1"`
Reader2 io.Reader `name:"inst2"`
Readers []io.Reader `group:"buffs"`
Writers []io.Writer `group:"buffs"`
}

var actualStrs []string
var actualStrsName []string

c.RequireInvoke(func(got in) {
require.Len(t, got.Readers, 2)
buf := make([]byte, 3)
for i, r := range got.Readers {
_, err := r.Read(buf)
require.NoError(t, err)
actualStrs = append(actualStrs, string(buf))
// put the text back
got.Writers[i].Write(buf)
}
_, err := got.Reader1.Read(buf)
require.NoError(t, err)
actualStrsName = append(actualStrsName, string(buf))
_, err = got.Reader2.Read(buf)
require.NoError(t, err)
actualStrsName = append(actualStrsName, string(buf))
require.Len(t, got.Writers, 2)
})

assert.ElementsMatch(t, actualStrs, expectedStrs, "list of strings provided must match")
assert.ElementsMatch(t, actualStrsName, expectedStrs, "names: list of strings provided must match")
})

t.Run("As same interface", func(t *testing.T) {
c := digtest.New(t)
c.RequireProvide(func() io.Reader {
Expand Down Expand Up @@ -1098,6 +1145,48 @@ func TestGroups(t *testing.T) {
})
})

t.Run("values are provided; coexist with name", func(t *testing.T) {
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))

type out struct {
dig.Out

Value int `group:"val"`
}

type out2 struct {
dig.Out

Value int `name:"inst1" group:"val"`
}

provide := func(i int) {
c.RequireProvide(func() out {
return out{Value: i}
})
}

provide(1)
provide(2)
provide(3)

c.RequireProvide(func() out2 {
return out2{Value: 4}
})

type in struct {
dig.In

SingleValue int `name:"inst1"`
Values []int `group:"val"`
}

c.RequireInvoke(func(i in) {
assert.Equal(t, []int{1, 2, 3, 4}, i.Values)
assert.Equal(t, 4, i.SingleValue)
})
})

t.Run("groups are provided via option", func(t *testing.T) {
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))

Expand All @@ -1122,6 +1211,36 @@ func TestGroups(t *testing.T) {
})
})

t.Run("groups are provided via option; coexist with name", func(t *testing.T) {
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))

provide := func(i int) {
c.RequireProvide(func() int {
return i
}, dig.Group("val"))
}

provide(1)
provide(2)
provide(3)

c.RequireProvide(func() int {
return 4
}, dig.Group("val"), dig.Name("inst1"))

type in struct {
dig.In

SingleValue int `name:"inst1"`
Values []int `group:"val"`
}

c.RequireInvoke(func(i in) {
assert.Equal(t, []int{1, 2, 3, 4}, i.Values)
assert.Equal(t, 4, i.SingleValue)
})
})

t.Run("different types may be grouped", func(t *testing.T) {
c := digtest.New(t, dig.SetRand(rand.New(rand.NewSource(0))))

Expand Down Expand Up @@ -1998,21 +2117,6 @@ func TestAsExpectingOriginalType(t *testing.T) {
})
}

func TestProvideIncompatibleOptions(t *testing.T) {
t.Parallel()

t.Run("group and name", func(t *testing.T) {
c := digtest.New(t)
err := c.Provide(func() io.Reader {
t.Fatal("this function must not be called")
return nil
}, dig.Group("foo"), dig.Name("bar"))
require.Error(t, err)
assert.Contains(t, err.Error(), "cannot use named values with value groups: "+
`name:"bar" provided with group:"foo"`)
})
}

type testStruct struct{}

func (testStruct) TestMethod(x int) float64 { return float64(x) }
Expand Down
6 changes: 0 additions & 6 deletions provide.go
Expand Up @@ -46,12 +46,6 @@ type provideOptions struct {
}

func (o *provideOptions) Validate() error {
if len(o.Group) > 0 {
if len(o.Name) > 0 {
return newErrInvalidInput(
fmt.Sprintf("cannot use named values with value groups: name:%q provided with group:%q", o.Name, o.Group), nil)
}
}

// Names must be representable inside a backquoted string. The only
// limitation for raw string literals as per
Expand Down
91 changes: 62 additions & 29 deletions result.go
Expand Up @@ -66,7 +66,7 @@ type resultOptions struct {
}

// newResult builds a result from the given type.
func newResult(t reflect.Type, opts resultOptions) (result, error) {
func newResult(t reflect.Type, opts resultOptions, noGroup bool) (result, error) {
switch {
case IsIn(t) || (t.Kind() == reflect.Ptr && IsIn(t.Elem())) || embedsType(t, _inPtrType):
return nil, newErrInvalidInput(fmt.Sprintf(
Expand All @@ -81,7 +81,7 @@ func newResult(t reflect.Type, opts resultOptions) (result, error) {
case t.Kind() == reflect.Ptr && IsOut(t.Elem()):
return nil, newErrInvalidInput(fmt.Sprintf(
"cannot return a pointer to a result object, use a value instead: %v is a pointer to a struct that embeds dig.Out", t), nil)
case len(opts.Group) > 0:
case len(opts.Group) > 0 && !noGroup:
g, err := parseGroupString(opts.Group)
if err != nil {
return nil, newErrInvalidInput(
Expand Down Expand Up @@ -176,7 +176,9 @@ func walkResult(r result, v resultVisitor) {
w := v
for _, f := range res.Fields {
if v := w.AnnotateWithField(f); v != nil {
walkResult(f.Result, v)
for _, r := range f.Results {
walkResult(r, v)
}
}
}
case resultList:
Expand All @@ -200,7 +202,7 @@ type resultList struct {
// For each item at index i returned by the constructor, resultIndexes[i]
// is the index in .Results for the corresponding result object.
// resultIndexes[i] is -1 for errors returned by constructors.
resultIndexes []int
resultIndexes [][]int
}

func (rl resultList) DotResult() []*dot.Result {
Expand All @@ -216,25 +218,47 @@ func newResultList(ctype reflect.Type, opts resultOptions) (resultList, error) {
rl := resultList{
ctype: ctype,
Results: make([]result, 0, numOut),
resultIndexes: make([]int, numOut),
resultIndexes: make([][]int, numOut),
}

resultIdx := 0
for i := 0; i < numOut; i++ {
t := ctype.Out(i)
if isError(t) {
rl.resultIndexes[i] = -1
rl.resultIndexes[i] = append(rl.resultIndexes[i], -1)
continue
}

r, err := newResult(t, opts)
if err != nil {
return rl, newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err)
addResult := func(nogroup bool) error {
r, err := newResult(t, opts, nogroup)
if err != nil {
return newErrInvalidInput(fmt.Sprintf("bad result %d", i+1), err)
}

rl.Results = append(rl.Results, r)
rl.resultIndexes[i] = append(rl.resultIndexes[i], resultIdx)
resultIdx++
return nil
}

// special case, its added as a group and a name using options alone
if len(opts.Name) > 0 && len(opts.Group) > 0 && !IsOut(t) {
// add as a group
if err := addResult(false); err != nil {
return rl, err
}
// add as single
if err := addResult(true); err != nil {
return rl, err
}
return rl, nil
}

// add as normal
if err := addResult(false); err != nil {
return rl, err
}

rl.Results = append(rl.Results, r)
rl.resultIndexes[i] = resultIdx
resultIdx++
}

return rl, nil
Expand All @@ -246,8 +270,10 @@ func (resultList) Extract(containerWriter, bool, reflect.Value) {

func (rl resultList) ExtractList(cw containerWriter, decorated bool, values []reflect.Value) error {
for i, v := range values {
if resultIdx := rl.resultIndexes[i]; resultIdx >= 0 {
rl.Results[resultIdx].Extract(cw, decorated, v)
for _, resultIdx := range rl.resultIndexes[i] {
if resultIdx >= 0 {
rl.Results[resultIdx].Extract(cw, decorated, v)
}
continue
}

Expand Down Expand Up @@ -384,7 +410,9 @@ func newResultObject(t reflect.Type, opts resultOptions) (resultObject, error) {

func (ro resultObject) Extract(cw containerWriter, decorated bool, v reflect.Value) {
for _, f := range ro.Fields {
f.Result.Extract(cw, decorated, v.Field(f.FieldIndex))
for _, r := range f.Results {
r.Extract(cw, decorated, v.Field(f.FieldIndex))
}
}
}

Expand All @@ -399,12 +427,16 @@ type resultObjectField struct {
// map to results.
FieldIndex int

// Result produced by this field.
Result result
// Results produced by this field.
Results []result
}

func (rof resultObjectField) DotResult() []*dot.Result {
return rof.Result.DotResult()
results := make([]*dot.Result, 0, len(rof.Results))
for _, r := range rof.Results {
results = append(results, r.DotResult()...)
}
return results
}

// newResultObjectField(i, f, opts) builds a resultObjectField from the field
Expand All @@ -414,7 +446,11 @@ func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (r
FieldName: f.Name,
FieldIndex: idx,
}

name := f.Tag.Get(_nameTag)
if len(name) > 0 {
// can modify in-place because options are passed-by-value.
opts.Name = name
}
var r result
switch {
case f.PkgPath != "":
Expand All @@ -427,20 +463,21 @@ func newResultObjectField(idx int, f reflect.StructField, opts resultOptions) (r
if err != nil {
return rof, err
}
rof.Results = append(rof.Results, r)
if len(name) == 0 {
break
}
fallthrough

default:
var err error
if name := f.Tag.Get(_nameTag); len(name) > 0 {
// can modify in-place because options are passed-by-value.
opts.Name = name
}
r, err = newResult(f.Type, opts)
r, err = newResult(f.Type, opts, false)
if err != nil {
return rof, err
}
rof.Results = append(rof.Results, r)
}

rof.Result = r
return rof, nil
}

Expand Down Expand Up @@ -493,7 +530,6 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) {
Flatten: g.Flatten,
Type: f.Type,
}
name := f.Tag.Get(_nameTag)
optional, _ := isFieldOptional(f)
switch {
case g.Flatten && f.Type.Kind() != reflect.Slice:
Expand All @@ -502,9 +538,6 @@ func newResultGrouped(f reflect.StructField) (resultGrouped, error) {
case g.Soft:
return rg, newErrInvalidInput(fmt.Sprintf(
"cannot use soft with result value groups: soft was used with group %q", rg.Group), nil)
case name != "":
return rg, newErrInvalidInput(fmt.Sprintf(
"cannot use named values with value groups: name:%q provided with group:%q", name, rg.Group), nil)
case optional:
return rg, newErrInvalidInput("value groups cannot be optional", nil)
}
Expand Down

0 comments on commit 682c0ce

Please sign in to comment.