From bdbdf9524eef3b4bbe48bc6679602e8efa96f31b Mon Sep 17 00:00:00 2001 From: Brian Pursley Date: Fri, 19 Aug 2022 18:34:09 -0400 Subject: [PATCH] fix: don't remove flag value that matches subcommand name When the command searches args to find the arg matching a particular subcommand name, it needs to ignore flag values, as it is possible that the value for a flag might match the name of the sub command. This change improves argsMinusFirstX() to ignore flag values when it searches for the X to exclude from the result. --- command.go | 40 +++++++++++++++++----- command_test.go | 89 +++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 121 insertions(+), 8 deletions(-) diff --git a/command.go b/command.go index 675bb1340..212285969 100644 --- a/command.go +++ b/command.go @@ -613,13 +613,37 @@ Loop: // argsMinusFirstX removes only the first x from args. Otherwise, commands that look like // openshift admin policy add-role-to-user admin my-user, lose the admin argument (arg[4]). -func argsMinusFirstX(args []string, x string) []string { - for i, y := range args { - if x == y { - ret := []string{} - ret = append(ret, args[:i]...) - ret = append(ret, args[i+1:]...) - return ret +// Special care needs to be taken not to remove a flag value. +func argsMinusFirstX(args []string, x string, c *Command) []string { + if len(args) == 0 { + return args + } + c.mergePersistentFlags() + flags := c.Flags() + +Loop: + for pos := 0; pos < len(args); pos++ { + s := args[pos] + switch { + case s == "--": + // -- means we have reached the end of the parseable args. Break out of the loop now. + break Loop + case strings.HasPrefix(s, "--") && !strings.Contains(s, "=") && !hasNoOptDefVal(s[2:], flags): + fallthrough + case strings.HasPrefix(s, "-") && !strings.Contains(s, "=") && len(s) == 2 && !shortHasNoOptDefVal(s[1:], flags): + // This is a flag without a default value, and an equal sign is not used. Increment pos in order to skip + // over the next arg, because that is the value of this flag. + pos++ + continue + case !strings.HasPrefix(s, "-"): + // This is not a flag or a flag value. Check to see if it matches what we're looking for, and if so, + // return the args, excluding the one at this position. + if s == x { + ret := []string{} + ret = append(ret, args[:pos]...) + ret = append(ret, args[pos+1:]...) + return ret + } } } return args @@ -644,7 +668,7 @@ func (c *Command) Find(args []string) (*Command, []string, error) { cmd := c.findNext(nextSubCmd) if cmd != nil { - return innerfind(cmd, argsMinusFirstX(innerArgs, nextSubCmd)) + return innerfind(cmd, argsMinusFirstX(innerArgs, nextSubCmd, c)) } return c, innerArgs } diff --git a/command_test.go b/command_test.go index 0446e3c1d..8151f6fae 100644 --- a/command_test.go +++ b/command_test.go @@ -2193,3 +2193,92 @@ func TestSetContextPersistentPreRun(t *testing.T) { t.Error(err) } } + +func TestFind(t *testing.T) { + var foo, bar string + root := &Command{ + Use: "root", + } + root.PersistentFlags().StringVarP(&foo, "foo", "f", "", "") + root.PersistentFlags().StringVarP(&bar, "bar", "b", "something", "") + + child := &Command{ + Use: "child", + } + root.AddCommand(child) + + testCases := []struct { + args []string + expectedFlags []string + }{ + { + []string{"child"}, + []string{}, + }, + { + []string{"child", "child"}, + []string{"child"}, + }, + { + []string{"child", "foo", "child", "bar", "child", "baz", "child"}, + []string{"foo", "child", "bar", "child", "baz", "child"}, + }, + { + []string{"-f", "child", "child"}, + []string{"-f", "child"}, + }, + { + []string{"child", "-f", "child"}, + []string{"-f", "child"}, + }, + { + []string{"-b", "child", "child"}, + []string{"-b", "child"}, + }, + { + []string{"child", "-b", "child"}, + []string{"-b", "child"}, + }, + { + []string{"child", "-b"}, + []string{"-b"}, + }, + { + []string{"-b", "-f", "child", "child"}, + []string{"-b", "-f", "child"}, + }, + { + []string{"-f", "child", "-b", "something", "child"}, + []string{"-f", "child", "-b", "something"}, + }, + { + []string{"-f", "child", "child", "-b"}, + []string{"-f", "child", "-b"}, + }, + { + []string{"-f=child", "-b=something", "child"}, + []string{"-f=child", "-b=something"}, + }, + { + []string{"--foo", "child", "--bar", "something", "child"}, + []string{"--foo", "child", "--bar", "something"}, + }, + } + + for _, tc := range testCases { + t.Run(fmt.Sprintf("%v", tc.args), func(t *testing.T) { + cmd, flags, err := root.Find(tc.args) + if err != nil { + t.Fatal(err) + } + + if cmd != child { + t.Fatal("Expected cmd to be child, but it was not") + } + + if !reflect.DeepEqual(tc.expectedFlags, flags) { + t.Fatalf("Wrong flags\nExpected: %v\nGot: %v", tc.expectedFlags, flags) + } + }) + } +}