From d6eaf9a756b2a79643711de06a6cc4e8c875edbe Mon Sep 17 00:00:00 2001 From: Naveen Gogineni Date: Fri, 8 Dec 2023 12:34:20 -0500 Subject: [PATCH] Fix:(issue_1834) Add check for persistent required flags --- command.go | 83 ++++++++++++++++++++++++++++++++++++------------- command_test.go | 42 +++++++++++++++++++++++++ 2 files changed, 104 insertions(+), 21 deletions(-) diff --git a/command.go b/command.go index 08b926f997..6e5298f682 100644 --- a/command.go +++ b/command.go @@ -522,18 +522,26 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) { if cmd.Action == nil { cmd.Action = helpCommandAction - } else if len(cmd.Arguments) > 0 { - rargs := cmd.Args().Slice() - tracef("calling argparse with %[1]v", rargs) - for _, arg := range cmd.Arguments { - var err error - rargs, err = arg.Parse(rargs) - if err != nil { - tracef("calling with %[1]v (cmd=%[2]q)", err, cmd.Name) - return err + } else { + if err := cmd.checkPersistentRequiredFlags(); err != nil { + cmd.isInError = true + _ = ShowSubcommandHelp(cmd) + return err + } + + if len(cmd.Arguments) > 0 { + rargs := cmd.Args().Slice() + tracef("calling argparse with %[1]v", rargs) + for _, arg := range cmd.Arguments { + var err error + rargs, err = arg.Parse(rargs) + if err != nil { + tracef("calling with %[1]v (cmd=%[2]q)", err, cmd.Name) + return err + } } + cmd.parsedArgs = &stringSliceArgs{v: rargs} } - cmd.parsedArgs = &stringSliceArgs{v: rargs} } if err := cmd.Action(ctx, cmd); err != nil { @@ -840,26 +848,59 @@ func (cmd *Command) lookupFlagSet(name string) *flag.FlagSet { return nil } +func (cmd *Command) checkRequiredFlag(f Flag) (bool, string) { + if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { + flagPresent := false + flagName := "" + + for _, key := range f.Names() { + flagName = key + + if cmd.IsSet(strings.TrimSpace(key)) { + flagPresent = true + } + } + + if !flagPresent && flagName != "" { + return false, flagName + } + } + return true, "" +} + func (cmd *Command) checkRequiredFlags() requiredFlagsErr { tracef("checking for required flags (cmd=%[1]q)", cmd.Name) missingFlags := []string{} for _, f := range cmd.Flags { - if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() { - flagPresent := false - flagName := "" + if pf, ok := f.(PersistentFlag); !ok || !pf.IsPersistent() { + if ok, name := cmd.checkRequiredFlag(f); !ok { + missingFlags = append(missingFlags, name) + } + } + } - for _, key := range f.Names() { - flagName = key + if len(missingFlags) != 0 { + tracef("found missing required flags %[1]q (cmd=%[2]q)", missingFlags, cmd.Name) - if cmd.IsSet(strings.TrimSpace(key)) { - flagPresent = true - } - } + return &errRequiredFlags{missingFlags: missingFlags} + } + + tracef("all required flags set (cmd=%[1]q)", cmd.Name) + + return nil +} + +func (cmd *Command) checkPersistentRequiredFlags() requiredFlagsErr { + tracef("checking for required flags (cmd=%[1]q)", cmd.Name) + + missingFlags := []string{} - if !flagPresent && flagName != "" { - missingFlags = append(missingFlags, flagName) + for _, f := range cmd.appliedFlags { + if pf, ok := f.(PersistentFlag); ok && pf.IsPersistent() { + if ok, name := cmd.checkRequiredFlag(f); !ok { + missingFlags = append(missingFlags, name) } } } diff --git a/command_test.go b/command_test.go index 82d9a4faea..bff3cd935d 100644 --- a/command_test.go +++ b/command_test.go @@ -2926,6 +2926,7 @@ func TestFlagAction(t *testing.T) { func TestPersistentFlag(t *testing.T) { var topInt, topPersistentInt, subCommandInt, appOverrideInt int64 var appFlag string + var appRequiredFlag string var appOverrideCmdInt int64 var appSliceFloat64 []float64 var persistentCommandSliceInt []int64 @@ -2957,6 +2958,12 @@ func TestPersistentFlag(t *testing.T) { Persistent: true, Destination: &appOverrideInt, }, + &StringFlag{ + Name: "persistentRequiredCommandFlag", + Persistent: true, + Required: true, + Destination: &appRequiredFlag, + }, }, Commands: []*Command{ { @@ -3005,6 +3012,7 @@ func TestPersistentFlag(t *testing.T) { "--persistentCommandSliceFlag", "102", "--persistentCommandFloatSliceFlag", "102.455", "--paof", "105", + "--persistentRequiredCommandFlag", "hellor", "subcmd", "--cmdPersistentFlag", "20", "--cmdFlag", "11", @@ -3021,6 +3029,10 @@ func TestPersistentFlag(t *testing.T) { t.Errorf("Expected 'bar' got %s", appFlag) } + if appRequiredFlag != "hellor" { + t.Errorf("Expected 'hellor' got %s", appRequiredFlag) + } + if topInt != 12 { t.Errorf("Expected 12 got %d", topInt) } @@ -3096,6 +3108,36 @@ func TestPersistentFlagIsSet(t *testing.T) { r.True(resultIsSet) } +func TestRequiredPersistentFlag(t *testing.T) { + + app := &Command{ + Name: "root", + Flags: []Flag{ + &StringFlag{ + Name: "result", + Persistent: true, + Required: true, + }, + }, + Commands: []*Command{ + { + Name: "sub", + Action: func(ctx context.Context, c *Command) error { + return nil + }, + }, + }, + } + + r := require.New(t) + + err := app.Run(context.Background(), []string{"root", "sub"}) + r.Error(err) + + err = app.Run(context.Background(), []string{"root", "sub", "--result", "after"}) + r.NoError(err) +} + func TestFlagDuplicates(t *testing.T) { tests := []struct { name string