diff --git a/completions.go b/completions.go index 368b92a99..b60f6b200 100644 --- a/completions.go +++ b/completions.go @@ -145,8 +145,13 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman return nil } -// GetFlagCompletion returns the completion function for the given flag, if available. -func GetFlagCompletion(flag *pflag.Flag) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) { +// GetFlagCompletionFunc returns the completion function for the given flag of the command, if available. +func (c *Command) GetFlagCompletionFunc(flagName string) (func(*Command, []string, string) ([]string, ShellCompDirective), bool) { + flag := c.Flag(flagName) + if flag == nil { + return nil, false + } + flagCompletionMutex.RLock() defer flagCompletionMutex.RUnlock() @@ -154,16 +159,6 @@ func GetFlagCompletion(flag *pflag.Flag) (func(cmd *Command, args []string, toCo return completionFunc, exists } -// GetFlagCompletionByName returns the completion function for the given flag in the command by name, if available. -func (c *Command) GetFlagCompletionByName(flagName string) (func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective), bool) { - flag := c.Flags().Lookup(flagName) - if flag == nil { - return nil, false - } - - return GetFlagCompletion(flag) -} - // Returns a string listing the different directive enabled in the specified parameter func (d ShellCompDirective) string() string { var directives []string diff --git a/completions_test.go b/completions_test.go index 7585f88cd..d5aee2501 100644 --- a/completions_test.go +++ b/completions_test.go @@ -3427,3 +3427,93 @@ Completion ended with directive: ShellCompDirectiveNoFileComp }) } } + +func TestGetFlagCompletion(t *testing.T) { + rootCmd := &Command{Use: "root", Run: emptyRun} + + rootCmd.Flags().String("rootflag", "", "root flag") + _ = rootCmd.RegisterFlagCompletionFunc("rootflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + return []string{"rootvalue"}, ShellCompDirectiveKeepOrder + }) + + rootCmd.PersistentFlags().String("persistentflag", "", "persistent flag") + _ = rootCmd.RegisterFlagCompletionFunc("persistentflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + return []string{"persistentvalue"}, ShellCompDirectiveDefault + }) + + childCmd := &Command{Use: "child", Run: emptyRun} + + childCmd.Flags().String("childflag", "", "child flag") + _ = childCmd.RegisterFlagCompletionFunc("childflag", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { + return []string{"childvalue"}, ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace + }) + + rootCmd.AddCommand(childCmd) + + testcases := []struct { + desc string + cmd *Command + flagName string + exists bool + comps []string + directive ShellCompDirective + }{ + { + desc: "get flag completion function for command", + cmd: rootCmd, + flagName: "rootflag", + exists: true, + comps: []string{"rootvalue"}, + directive: ShellCompDirectiveKeepOrder, + }, + { + desc: "get persistent flag completion function for command", + cmd: rootCmd, + flagName: "persistentflag", + exists: true, + comps: []string{"persistentvalue"}, + directive: ShellCompDirectiveDefault, + }, + { + desc: "get flag completion function for child command", + cmd: childCmd, + flagName: "childflag", + exists: true, + comps: []string{"childvalue"}, + directive: ShellCompDirectiveNoFileComp | ShellCompDirectiveNoSpace, + }, + { + desc: "get persistent flag completion function for child command", + cmd: childCmd, + flagName: "persistentflag", + exists: true, + comps: []string{"persistentvalue"}, + directive: ShellCompDirectiveDefault, + }, + { + desc: "cannot get flag completion function for local parent flag", + cmd: childCmd, + flagName: "rootflag", + exists: false, + }, + } + + for _, tc := range testcases { + t.Run(tc.desc, func(t *testing.T) { + compFunc, exists := tc.cmd.GetFlagCompletionFunc(tc.flagName) + if tc.exists != exists { + t.Errorf("Unexpected result looking for flag completion function") + } + + if exists { + comps, directive := compFunc(tc.cmd, []string{}, "") + if strings.Join(tc.comps, " ") != strings.Join(comps, " ") { + t.Errorf("Unexpected completions %q", comps) + } + if tc.directive != directive { + t.Errorf("Unexpected directive %q", directive) + } + } + }) + } +}