Skip to content

Commit

Permalink
Merge pull request #1835 from dearchap/issue_1834
Browse files Browse the repository at this point in the history
Fix:(issue_1834) Add check for persistent required flags
  • Loading branch information
dearchap committed Dec 9, 2023
2 parents 2458b93 + d6eaf9a commit 5e65616
Show file tree
Hide file tree
Showing 2 changed files with 104 additions and 21 deletions.
83 changes: 62 additions & 21 deletions command.go
Expand Up @@ -611,18 +611,26 @@ func (cmd *Command) Run(ctx context.Context, osArgs []string) (deferErr error) {

if cmd.Action == nil {
cmd.Action = helpCommandAction
} else if len(cmd.Arguments) > 0 {
rargs := cmd.Args().Slice()
tracef("calling argparse with %[1]v", rargs)
for _, arg := range cmd.Arguments {
var err error
rargs, err = arg.Parse(rargs)
if err != nil {
tracef("calling with %[1]v (cmd=%[2]q)", err, cmd.Name)
return err
} else {
if err := cmd.checkPersistentRequiredFlags(); err != nil {
cmd.isInError = true
_ = ShowSubcommandHelp(cmd)
return err
}

if len(cmd.Arguments) > 0 {
rargs := cmd.Args().Slice()
tracef("calling argparse with %[1]v", rargs)
for _, arg := range cmd.Arguments {
var err error
rargs, err = arg.Parse(rargs)
if err != nil {
tracef("calling with %[1]v (cmd=%[2]q)", err, cmd.Name)
return err
}
}
cmd.parsedArgs = &stringSliceArgs{v: rargs}
}
cmd.parsedArgs = &stringSliceArgs{v: rargs}
}

if err := cmd.Action(ctx, cmd); err != nil {
Expand Down Expand Up @@ -929,26 +937,59 @@ func (cmd *Command) lookupFlagSet(name string) *flag.FlagSet {
return nil
}

func (cmd *Command) checkRequiredFlag(f Flag) (bool, string) {
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
flagPresent := false
flagName := ""

for _, key := range f.Names() {
flagName = key

if cmd.IsSet(strings.TrimSpace(key)) {
flagPresent = true
}
}

if !flagPresent && flagName != "" {
return false, flagName
}
}
return true, ""
}

func (cmd *Command) checkRequiredFlags() requiredFlagsErr {
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)

missingFlags := []string{}

for _, f := range cmd.Flags {
if rf, ok := f.(RequiredFlag); ok && rf.IsRequired() {
flagPresent := false
flagName := ""
if pf, ok := f.(PersistentFlag); !ok || !pf.IsPersistent() {
if ok, name := cmd.checkRequiredFlag(f); !ok {
missingFlags = append(missingFlags, name)
}
}
}

for _, key := range f.Names() {
flagName = key
if len(missingFlags) != 0 {
tracef("found missing required flags %[1]q (cmd=%[2]q)", missingFlags, cmd.Name)

if cmd.IsSet(strings.TrimSpace(key)) {
flagPresent = true
}
}
return &errRequiredFlags{missingFlags: missingFlags}
}

tracef("all required flags set (cmd=%[1]q)", cmd.Name)

return nil
}

func (cmd *Command) checkPersistentRequiredFlags() requiredFlagsErr {
tracef("checking for required flags (cmd=%[1]q)", cmd.Name)

missingFlags := []string{}

if !flagPresent && flagName != "" {
missingFlags = append(missingFlags, flagName)
for _, f := range cmd.appliedFlags {
if pf, ok := f.(PersistentFlag); ok && pf.IsPersistent() {
if ok, name := cmd.checkRequiredFlag(f); !ok {
missingFlags = append(missingFlags, name)
}
}
}
Expand Down
42 changes: 42 additions & 0 deletions command_test.go
Expand Up @@ -2926,6 +2926,7 @@ func TestFlagAction(t *testing.T) {
func TestPersistentFlag(t *testing.T) {
var topInt, topPersistentInt, subCommandInt, appOverrideInt int64
var appFlag string
var appRequiredFlag string
var appOverrideCmdInt int64
var appSliceFloat64 []float64
var persistentCommandSliceInt []int64
Expand Down Expand Up @@ -2957,6 +2958,12 @@ func TestPersistentFlag(t *testing.T) {
Persistent: true,
Destination: &appOverrideInt,
},
&StringFlag{
Name: "persistentRequiredCommandFlag",
Persistent: true,
Required: true,
Destination: &appRequiredFlag,
},
},
Commands: []*Command{
{
Expand Down Expand Up @@ -3005,6 +3012,7 @@ func TestPersistentFlag(t *testing.T) {
"--persistentCommandSliceFlag", "102",
"--persistentCommandFloatSliceFlag", "102.455",
"--paof", "105",
"--persistentRequiredCommandFlag", "hellor",
"subcmd",
"--cmdPersistentFlag", "20",
"--cmdFlag", "11",
Expand All @@ -3021,6 +3029,10 @@ func TestPersistentFlag(t *testing.T) {
t.Errorf("Expected 'bar' got %s", appFlag)
}

if appRequiredFlag != "hellor" {
t.Errorf("Expected 'hellor' got %s", appRequiredFlag)
}

if topInt != 12 {
t.Errorf("Expected 12 got %d", topInt)
}
Expand Down Expand Up @@ -3096,6 +3108,36 @@ func TestPersistentFlagIsSet(t *testing.T) {
r.True(resultIsSet)
}

func TestRequiredPersistentFlag(t *testing.T) {

app := &Command{
Name: "root",
Flags: []Flag{
&StringFlag{
Name: "result",
Persistent: true,
Required: true,
},
},
Commands: []*Command{
{
Name: "sub",
Action: func(ctx context.Context, c *Command) error {
return nil
},
},
},
}

r := require.New(t)

err := app.Run(context.Background(), []string{"root", "sub"})
r.Error(err)

err = app.Run(context.Background(), []string{"root", "sub", "--result", "after"})
r.NoError(err)
}

func TestFlagDuplicates(t *testing.T) {
tests := []struct {
name string
Expand Down

0 comments on commit 5e65616

Please sign in to comment.