diff --git a/bash_completions.go b/bash_completions.go index 925e6e787b..733f4d1211 100644 --- a/bash_completions.go +++ b/bash_completions.go @@ -512,7 +512,9 @@ func writeLocalNonPersistentFlag(buf io.StringWriter, flag *pflag.Flag) { // Setup annotations for go completions for registered flags func prepareCustomAnnotationsForFlags(cmd *Command) { - for flag := range cmd.Root().flagCompletionFunctions { + flagCompletionMutex.RLock() + defer flagCompletionMutex.RUnlock() + for flag := range flagCompletionFunctions { // Make sure the completion script calls the __*_go_custom_completion function for // every registered flag. We need to do this here (and not when the flag was registered // for completion) so that we can know the root command name for the prefix diff --git a/command.go b/command.go index 9f33a461fa..2cc18891d7 100644 --- a/command.go +++ b/command.go @@ -142,9 +142,6 @@ type Command struct { // that we can use on every pflag set and children commands globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName - //flagCompletionFunctions is map of flag completion functions. - flagCompletionFunctions map[*flag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) - // usageFunc is usage func defined by user. usageFunc func(*Command) error // usageTemplate is usage template defined by user. diff --git a/completions.go b/completions.go index 4687674aa3..a1c12c01eb 100644 --- a/completions.go +++ b/completions.go @@ -4,6 +4,7 @@ import ( "fmt" "os" "strings" + "sync" "github.com/spf13/pflag" ) @@ -17,6 +18,12 @@ const ( ShellCompNoDescRequestCmd = "__completeNoDesc" ) +// Global map of flag completion functions. Make sure to use flagCompletionMutex before you try to read and write from it. +var flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){} + +// lock for reading and writing from flagCompletionFunctions +var flagCompletionMutex = &sync.RWMutex{} + // ShellCompDirective is a bit map representing the different behaviors the shell // can be instructed to have once completions have been provided. type ShellCompDirective int @@ -100,15 +107,16 @@ func (c *Command) RegisterFlagCompletionFunc(flagName string, f func(cmd *Comman if flag == nil { return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' does not exist", flagName) } + flagCompletionMutex.Lock() + defer flagCompletionMutex.Unlock() - root := c.Root() - if _, exists := root.flagCompletionFunctions[flag]; exists { + if _, exists := flagCompletionFunctions[flag]; exists { return fmt.Errorf("RegisterFlagCompletionFunc: flag '%s' already registered", flagName) } - if root.flagCompletionFunctions == nil { - root.flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){} + if flagCompletionFunctions == nil { + flagCompletionFunctions = map[*pflag.Flag]func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective){} } - root.flagCompletionFunctions[flag] = f + flagCompletionFunctions[flag] = f return nil } @@ -402,7 +410,9 @@ func (c *Command) getCompletions(args []string) (*Command, []string, ShellCompDi // Find the completion function for the flag or command var completionFn func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) if flag != nil && flagCompletion { - completionFn = c.Root().flagCompletionFunctions[flag] + flagCompletionMutex.RLock() + completionFn = flagCompletionFunctions[flag] + flagCompletionMutex.RUnlock() } else { completionFn = finalCmd.ValidArgsFunction } diff --git a/completions_test.go b/completions_test.go index aea06a241b..9d8b073b5d 100644 --- a/completions_test.go +++ b/completions_test.go @@ -1763,13 +1763,15 @@ func TestFlagCompletionWithNotInterspersedArgs(t *testing.T) { Run: emptyRun, ValidArgs: []string{"arg1", "arg2"}, } - rootCmd.AddCommand(childCmd, childCmd2) childCmd.Flags().Bool("bool", false, "test bool flag") childCmd.Flags().String("string", "", "test string flag") _ = childCmd.RegisterFlagCompletionFunc("string", func(cmd *Command, args []string, toComplete string) ([]string, ShellCompDirective) { return []string{"myval"}, ShellCompDirectiveDefault }) + // Important: only add the commands after RegisterFlagCompletionFunc was called + rootCmd.AddCommand(childCmd, childCmd2) + // Test flag completion with no argument output, err := executeCommand(rootCmd, ShellCompRequestCmd, "child", "--") if err != nil { @@ -1969,6 +1971,21 @@ func TestFlagCompletionWithNotInterspersedArgs(t *testing.T) { if output != expected { t.Errorf("expected: %q, got: %q", expected, output) } + + // Test that no flag completion works on a subcmd + output, err = executeCommand(rootCmd, ShellCompRequestCmd, "child", "--string", "") + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + expected = strings.Join([]string{ + "myval", + ":0", + "Completion ended with directive: ShellCompDirectiveDefault", ""}, "\n") + + if output != expected { + t.Errorf("expected: %q, got: %q", expected, output) + } } func TestFlagCompletionInGoWithDesc(t *testing.T) {