diff --git a/completions.go b/completions.go index 4687674aa..26c350907 100644 --- a/completions.go +++ b/completions.go @@ -101,14 +101,13 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' does not exist", flagName) } - root := c.Root() - if _, exists := root.flagCompletionFunctions[flag]; exists { + if _, exists := c.flagCompletionFunctions[flag]; exists { return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' already registered", flagName) } - if root.flagCompletionFunctions == nil { - root.flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){} + if c.flagCompletionFunctions == nil { + c.flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){} } - root.flagCompletionFunctions[flag] = f + c.flagCompletionFunctions[flag] = f return nil } @@ -402,7 +401,7 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi // Find the completion function for the flag or command var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) if flag != nil && flagCompletion { - completionFn = c.Root().flagCompletionFunctions[flag] + completionFn = finalCmd.flagCompletionFunctions[flag] } else { completionFn = finalCmd.ValidArgsFunction } diff --git a/completions_test.go b/completions_test.go index aea06a241..9d8b073b5 100644 --- a/completions_test.go +++ b/completions_test.go @@ -1763,13 +1763,15 @@ func TestFlagCompletionWithNotInterspersedArgs(t *testing.T) { Run: emptyRun, ValidArgs: []string{"arg1", "arg2"}, } - rootCmd.AddCommand(childCmd, childCmd2) childCmd.Flags().Bool("bool", false, "test bool flag") childCmd.Flags().String("string", "", "test string flag") _ = childCmd.RegisterFlagCompletionFunc("string", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { return []string{"myval"}, ShellCompDirectiveDefault }) + // Important: only add the commands after RegisterFlagCompletionFunc was called + rootCmd.AddCommand(childCmd, childCmd2) + // Test flag completion with no argument output, err := executeCommand(rootCmd, ShellCompRequestCmd, "child", "--") if err != nil { @@ -1969,6 +1971,21 @@ func TestFlagCompletionWithNotInterspersedArgs(t *testing.T) { if output != expected { t.Errorf("expected: %q, got: %q", expected, output) } + + // Test that no flag completion works on a subcmd + output, err = executeCommand(rootCmd, ShellCompRequestCmd, "child", "--string", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected = strings.Join([]string{ + "myval", + ":0", + "Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } } func TestFlagCompletionInGoWithDesc(t *testing.T) {