Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add PersistentPreRun to disable required flag for help and completion command #1992

Open
wants to merge 13 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
28 changes: 24 additions & 4 deletions command.go
Expand Up @@ -30,7 +30,10 @@ import (
flag "github.com/spf13/pflag"
)

const FlagSetByCobraAnnotation = "cobra_annotation_flag_set_by_cobra"
const (
FlagSetByCobraAnnotation = "cobra_annotation_flag_set_by_cobra"
trueString = "true"
)

// FParseErrWhitelist configures Flag parse errors to be ignored
type FParseErrWhitelist flag.ParseErrorsWhitelist
Expand Down Expand Up @@ -1109,7 +1112,7 @@ func (c *Command) ValidateRequiredFlags() error {
if !found {
return
}
if (requiredAnnotation[0] == "true") && !pflag.Changed {
if (requiredAnnotation[0] == trueString) && !pflag.Changed {
missingFlagNames = append(missingFlagNames, pflag.Name)
}
})
Expand Down Expand Up @@ -1146,7 +1149,7 @@ func (c *Command) InitDefaultHelpFlag() {
usage += c.Name()
}
c.Flags().BoolP("help", "h", false, usage)
_ = c.Flags().SetAnnotation("help", FlagSetByCobraAnnotation, []string{"true"})
_ = c.Flags().SetAnnotation("help", FlagSetByCobraAnnotation, []string{trueString})
}
}

Expand All @@ -1172,7 +1175,7 @@ func (c *Command) InitDefaultVersionFlag() {
} else {
c.Flags().Bool("version", false, usage)
}
_ = c.Flags().SetAnnotation("version", FlagSetByCobraAnnotation, []string{"true"})
_ = c.Flags().SetAnnotation("version", FlagSetByCobraAnnotation, []string{trueString})
}
}

Expand Down Expand Up @@ -1209,6 +1212,23 @@ Simply type ` + c.Name() + ` help [path to command] for full details.`,
}
return completions, ShellCompDirectiveNoFileComp
},
PersistentPreRunE: func(cmd *Command, args []string) error {
cmd.Flags().VisitAll(func(pflag *flag.Flag) {
requiredAnnotation, found := pflag.Annotations[BashCompOneRequiredFlag]
if found && requiredAnnotation[0] == trueString {
// Disable any persistent required flags for the help command
pflag.Annotations[BashCompOneRequiredFlag] = []string{"false"}
}
})
// Adding PersistentPreRun on sub-commands prevents root's PersistentPreRun from being called.
// So it is intentionally called here.
if cmd.Root().PersistentPreRunE != nil {
return cmd.Root().PersistentPreRunE(cmd, args)
} else if cmd.Root().PersistentPreRun != nil {
cmd.Root().PersistentPreRun(cmd, args)
}
return nil
},
Run: func(c *Command, args []string) {
cmd, _, e := c.Root().Find(args)
if cmd == nil || e != nil {
Expand Down
31 changes: 31 additions & 0 deletions command_test.go
Expand Up @@ -885,6 +885,18 @@ func TestHelpCommandExecuted(t *testing.T) {
checkStringContains(t, output, rootCmd.Long)
}

func TestHelpCommandExecutedWithPersistentRequiredFlags(t *testing.T) {
rootCmd := &Command{Use: "root", Run: emptyRun}
rootCmd.PersistentFlags().Bool("foo", false, "")
childCmd := &Command{Use: "child", Run: emptyRun}
rootCmd.AddCommand(childCmd)
assertNoErr(t, rootCmd.MarkPersistentFlagRequired("foo"))

if _, err := executeCommand(rootCmd, "help"); err != nil {
t.Errorf("unexpected error: %v", err)
}
}

func TestHelpCommandExecutedOnChild(t *testing.T) {
rootCmd := &Command{Use: "root", Run: emptyRun}
childCmd := &Command{Use: "child", Long: "Long description", Run: emptyRun}
Expand Down Expand Up @@ -1599,6 +1611,25 @@ func TestPersistentHooks(t *testing.T) {
}
}

func TestPersistentPreRunHooksForHelpCommand(t *testing.T) {
executed := false

rootCmd := &Command{
Use: "root",
PersistentPreRun: func(*Command, []string) { executed = true },
Run: emptyRun,
}
childCmd := &Command{Use: "child", Run: emptyRun}
rootCmd.AddCommand(childCmd)

if _, err := executeCommand(rootCmd, "help"); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !executed {
t.Error("Root PersistentPreRun should have been executed")
}
}

// Related to https://github.com/spf13/cobra/issues/521.
func TestGlobalNormFuncPropagation(t *testing.T) {
normFunc := func(f *pflag.FlagSet, name string) pflag.NormalizedName {
Expand Down
17 changes: 17 additions & 0 deletions completions.go
Expand Up @@ -681,6 +681,23 @@ See each sub-command's help for details on how to use the generated script.
ValidArgsFunction: NoFileCompletions,
Hidden: c.CompletionOptions.HiddenDefaultCmd,
GroupID: c.completionCommandGroupID,
PersistentPreRunE: func(cmd *Command, args []string) error {
cmd.Flags().VisitAll(func(flag *pflag.Flag) {
requiredAnnotation, found := flag.Annotations[BashCompOneRequiredFlag]
if found && requiredAnnotation[0] == "true" {
// Disable any persistent required flags for the completion command
flag.Annotations[BashCompOneRequiredFlag] = []string{"false"}
}
})
// Adding PersistentPreRun on sub-commands prevents root's PersistentPreRun from being called.
// So it is intentionally called here.
if cmd.Root().PersistentPreRunE != nil {
return cmd.Root().PersistentPreRunE(cmd, args)
} else if cmd.Root().PersistentPreRun != nil {
cmd.Root().PersistentPreRun(cmd, args)
}
return nil
},
}
c.AddCommand(completionCmd)

Expand Down
36 changes: 36 additions & 0 deletions completions_test.go
Expand Up @@ -2447,6 +2447,42 @@ func TestDefaultCompletionCmd(t *testing.T) {
rootCmd.CompletionOptions.HiddenDefaultCmd = false
// Remove completion command for the next test
removeCompCmd(rootCmd)

// Test that required flag will be ignored
rootCmd.PersistentFlags().Bool("foo", false, "")
assertNoErr(t, rootCmd.MarkPersistentFlagRequired("foo"))
for _, shell := range []string{"bash", "fish", "powershell", "zsh"} {
if _, err = executeCommand(rootCmd, compCmdName, shell); err != nil {
t.Errorf("Unexpected error: %v", err)
}
}
// Remove completion command for the next test
removeCompCmd(rootCmd)
}

func TestPersistentPreRunHooksForCompletionCommand(t *testing.T) {
executed := false

rootCmd := &Command{
Use: "root",
PersistentPreRun: func(*Command, []string) { executed = true },
Run: emptyRun,
}
subCmd := &Command{
Use: "sub",
Run: emptyRun,
}
rootCmd.AddCommand(subCmd)

for _, shell := range []string{"bash", "fish", "powershell", "zsh"} {
if _, err := executeCommand(rootCmd, compCmdName, shell); err != nil {
t.Errorf("Unexpected error: %v", err)
}
if !executed {
t.Error("Root PersistentPreRun should have been executed")
}
executed = false
}
}

func TestCompleteCompletion(t *testing.T) {
Expand Down