Skip to content

Commit

Permalink
fix: don't remove flag value that matches subcommand name (#1781)
Browse files Browse the repository at this point in the history
When the command searches args to find the arg matching a
particular subcommand name, it needs to ignore flag values,
as it is possible that the value for a flag might match
the name of the sub command.

This change improves argsMinusFirstX() to ignore flag values
when it searches for the X to exclude from the result.
  • Loading branch information
brianpursley committed Nov 8, 2022
1 parent cc7e235 commit 6b0bd30
Show file tree
Hide file tree
Showing 2 changed files with 121 additions and 8 deletions.
40 changes: 32 additions & 8 deletions command.go
Expand Up @@ -655,13 +655,37 @@ Loop:

// argsMinusFirstX removes only the first x from args. Otherwise, commands that look like
// openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]).
func argsMinusFirstX(args []string, x string) []string {
for i, y := range args {
if x == y {
ret := []string{}
ret = append(ret, args[:i]...)
ret = append(ret, args[i+1:]...)
return ret
// Special care needs to be taken not to remove a flag value.
func (c *Command) argsMinusFirstX(args []string, x string) []string {
if len(args) == 0 {
return args
}
c.mergePersistentFlags()
flags := c.Flags()

Loop:
for pos := 0; pos < len(args); pos++ {
s := args[pos]
switch {
case s == "--":
// -- means we have reached the end of the parseable args. Break out of the loop now.
break Loop
case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !hasNoOptDefVal(s[2:], flags):
fallthrough
case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags):
// This is a flag without a default value, and an equal sign is not used. Increment pos in order to skip
// over the next arg, because that is the value of this flag.
pos++
continue
case !strings.HasPrefix(s, "-"):
// This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so,
// return the args, excluding the one at this position.
if s == x {
ret := []string{}
ret = append(ret, args[:pos]...)
ret = append(ret, args[pos+1:]...)
return ret
}
}
}
return args
Expand All @@ -686,7 +710,7 @@ func (c *Command) Find(args []string) (*Command, []string, error) {

cmd := c.findNext(nextSubCmd)
if cmd != nil {
return innerfind(cmd, argsMinusFirstX(innerArgs, nextSubCmd))
return innerfind(cmd, c.argsMinusFirstX(innerArgs, nextSubCmd))
}
return c, innerArgs
}
Expand Down
89 changes: 89 additions & 0 deletions command_test.go
Expand Up @@ -2603,3 +2603,92 @@ func TestHelpflagCommandExecutedWithoutVersionSet(t *testing.T) {
checkStringContains(t, output, HelpFlag)
checkStringOmits(t, output, VersionFlag)
}

func TestFind(t *testing.T) {
var foo, bar string
root := &Command{
Use: "root",
}
root.PersistentFlags().StringVarP(&foo, "foo", "f", "", "")
root.PersistentFlags().StringVarP(&bar, "bar", "b", "something", "")

child := &Command{
Use: "child",
}
root.AddCommand(child)

testCases := []struct {
args []string
expectedFoundArgs []string
}{
{
[]string{"child"},
[]string{},
},
{
[]string{"child", "child"},
[]string{"child"},
},
{
[]string{"child", "foo", "child", "bar", "child", "baz", "child"},
[]string{"foo", "child", "bar", "child", "baz", "child"},
},
{
[]string{"-f", "child", "child"},
[]string{"-f", "child"},
},
{
[]string{"child", "-f", "child"},
[]string{"-f", "child"},
},
{
[]string{"-b", "child", "child"},
[]string{"-b", "child"},
},
{
[]string{"child", "-b", "child"},
[]string{"-b", "child"},
},
{
[]string{"child", "-b"},
[]string{"-b"},
},
{
[]string{"-b", "-f", "child", "child"},
[]string{"-b", "-f", "child"},
},
{
[]string{"-f", "child", "-b", "something", "child"},
[]string{"-f", "child", "-b", "something"},
},
{
[]string{"-f", "child", "child", "-b"},
[]string{"-f", "child", "-b"},
},
{
[]string{"-f=child", "-b=something", "child"},
[]string{"-f=child", "-b=something"},
},
{
[]string{"--foo", "child", "--bar", "something", "child"},
[]string{"--foo", "child", "--bar", "something"},
},
}

for _, tc := range testCases {
t.Run(fmt.Sprintf("%v", tc.args), func(t *testing.T) {
cmd, foundArgs, err := root.Find(tc.args)
if err != nil {
t.Fatal(err)
}

if cmd != child {
t.Fatal("Expected cmd to be child, but it was not")
}

if !reflect.DeepEqual(tc.expectedFoundArgs, foundArgs) {
t.Fatalf("Wrong args\nExpected: %v\nGot: %v", tc.expectedFoundArgs, foundArgs)
}
})
}
}

0 comments on commit 6b0bd30

Please sign in to comment.