Skip to content

Commit

Permalink
Fix --version behaviour
Browse files Browse the repository at this point in the history
  • Loading branch information
hoshsadiq committed Aug 24, 2022
1 parent 5c899f9 commit a4ce13a
Show file tree
Hide file tree
Showing 2 changed files with 59 additions and 23 deletions.
59 changes: 37 additions & 22 deletions command.go
Expand Up @@ -19,6 +19,7 @@ import (
"bytes"
"context"
_ "embed"
"errors"
"fmt"
"io"
"os"
Expand All @@ -36,6 +37,9 @@ var defaultUsageTemplate string
// FParseErrAllowList configures Flag parse errors to be ignored
type FParseErrAllowList zflag.ParseErrorsAllowlist

// ErrVersion is the error returned if the flag -version is invoked.
var ErrVersion = errors.New("zulu: version requested")

type HookFuncE func(cmd *Command, args []string) error
type HookFunc func(cmd *Command, args []string)

Expand All @@ -51,7 +55,7 @@ type Group struct {
// definition to ensure usability.
type Command struct {
// Use is the one-line usage message.
// Recommended syntax is as follow:
// Recommended syntax is:
// [ ] identifies an optional argument. Arguments that are not enclosed in brackets are required.
// ... indicates that you can specify multiple values for the previous argument.
// | indicates mutually exclusive information. You can use the argument to the left of the separator or the
Expand Down Expand Up @@ -170,7 +174,7 @@ type Command struct {
persistentFinalizeHooks []HookFuncE

// groups for commands
commandGroups []*Group
commandGroups []Group

// args is actual args parsed from flags.
args []string
Expand Down Expand Up @@ -847,6 +851,11 @@ func (c *Command) execute(a []string) (err error) {
}
}()

for p := c; p != nil; p = p.Parent() {
prependHooks(&hooks, p.persistentInitializeHooks, p.PersistentInitializeE)
}
prependHooks(&hooks, c.initializeHooks, c.InitializeE)

// initialize help and version flag at the last point possible to allow for user
// overriding
hooks = append(hooks, func(cmd *Command, args []string) error {
Expand All @@ -856,11 +865,6 @@ func (c *Command) execute(a []string) (err error) {
return nil
})

for p := c; p != nil; p = p.Parent() {
prependHooks(&hooks, p.persistentInitializeHooks, p.PersistentInitializeE)
}
prependHooks(&hooks, c.initializeHooks, c.InitializeE)

hooks = append(hooks, func(cmd *Command, args []string) error {
err := c.ParseFlags(a)
if err != nil {
Expand Down Expand Up @@ -889,8 +893,8 @@ func (c *Command) execute(a []string) (err error) {
})

// for back-compat, only add version flag behavior if version is defined
if c.Version != "" {
hooks = append(hooks, func(cmd *Command, args []string) error {
hooks = append(hooks, func(cmd *Command, args []string) error {
if c.Version != "" {
versionVal, err := c.Flags().GetBool("version")
if err != nil {
c.Println(`"version" flag declared as non-bool. Please correct your code`)
Expand All @@ -900,13 +904,14 @@ func (c *Command) execute(a []string) (err error) {
err := tmpl(c.OutOrStdout(), c.VersionTemplate(), c)
if err != nil {
c.Println(err)
return err
}
return err
}

return nil
})
}
return ErrVersion
}
}
return nil
})

hooks = append(hooks, func(cmd *Command, args []string) error {
argWoFlags = c.Flags().Args()
Expand Down Expand Up @@ -1105,6 +1110,12 @@ func (c *Command) ExecuteC() (cmd *Command, err error) {

err = cmd.execute(flags)
if err != nil {
// Exit without errors when version requested. At this point the
// version has already been printed.
if err == ErrVersion {
return cmd, nil
}

// Always show help if requested, even if SilenceErrors is in
// effect
if err == zflag.ErrHelp {
Expand Down Expand Up @@ -1137,9 +1148,11 @@ func (c *Command) ValidateArgs(args []string) error {
if err := validateArgs(c, args); err != nil {
return err
}

if c.Args == nil {
return nil
}

return c.Args(c, args)
}

Expand All @@ -1149,20 +1162,21 @@ func (c *Command) validateRequiredFlags() error {
}

flags := c.Flags()
missingFlagNames := []string{}
var missingFlagNames []string
flags.VisitAll(func(pflag *zflag.Flag) {
requiredAnnotation, found := pflag.Annotations[BashCompOneRequiredFlag]
if !found {
return
}
if (requiredAnnotation[0] == "true") && !pflag.Changed {
if requiredAnnotation[0] == "true" && !pflag.Changed {
missingFlagNames = append(missingFlagNames, pflag.Name)
}
})

if len(missingFlagNames) > 0 {
return fmt.Errorf(`required flag(s) "%s" not set`, strings.Join(missingFlagNames, `", "`))
}

return nil
}

Expand Down Expand Up @@ -1199,11 +1213,12 @@ func (c *Command) InitDefaultVersionFlag() {
} else {
usage += c.Name()
}

var opts []zflag.Opt
if c.Flags().ShorthandLookup('v') == nil {
c.Flags().Bool("version", false, usage, zflag.OptShorthand('v'))
} else {
c.Flags().Bool("version", false, usage)
opts = append(opts, zflag.OptShorthand('v'))
}
c.Flags().Bool("version", false, usage, opts...)
}
}

Expand Down Expand Up @@ -1293,7 +1308,7 @@ func (c *Command) AddCommand(cmds ...*Command) {
cmds[i].parent = c
// if Group is not defined generate a new one with same title
if x.Group != "" && !c.ContainsGroup(x.Group) {
c.AddGroup(&Group{Group: x.Group, Title: x.Group})
c.AddGroup(Group{Group: x.Group, Title: x.Group})
}
// update max lengths
usageLen := len(x.Use)
Expand All @@ -1318,7 +1333,7 @@ func (c *Command) AddCommand(cmds ...*Command) {
}

// Groups returns a slice of child command groups.
func (c *Command) Groups() []*Group {
func (c *Command) Groups() []Group {
return c.commandGroups
}

Expand All @@ -1333,7 +1348,7 @@ func (c *Command) ContainsGroup(group string) bool {
}

// AddGroup adds one or more command groups to this parent command.
func (c *Command) AddGroup(groups ...*Group) {
func (c *Command) AddGroup(groups ...Group) {
c.commandGroups = append(c.commandGroups, groups...)
}

Expand Down
23 changes: 22 additions & 1 deletion command_test.go
Expand Up @@ -1369,6 +1369,27 @@ func TestHooks(t *testing.T) {
}
}

func TestHooksVersionFlagAddedWhenVersionSetOnInitialize(t *testing.T) {
c := &zulu.Command{
Use: "c",
InitializeE: func(c *zulu.Command, _ []string) error {
c.Version = "(devel)"
return nil
},
RunE: func(_ *zulu.Command, _ []string) error {
return nil
},
}

output, err := executeCommand(c, "--version")
if err != nil {
t.Fatalf("Unexpected error: %v", err)
}
if output != "c version (devel)\n" {
t.Errorf("Unexpected output: %v", output)
}
}

func TestPersistentHooks(t *testing.T) {
hooksArgs := map[string]string{}

Expand Down Expand Up @@ -1693,7 +1714,7 @@ func TestUsageHelpGroup(t *testing.T) {
func TestAddGroup(t *testing.T) {
var rootCmd = &zulu.Command{Use: "root", Short: "test", RunE: emptyRun}

rootCmd.AddGroup(&zulu.Group{Group: "group", Title: "Test group"})
rootCmd.AddGroup(zulu.Group{Group: "group", Title: "Test group"})
rootCmd.AddCommand(&zulu.Command{Use: "cmd", Group: "group", RunE: emptyRun})

output, err := executeCommand(rootCmd, "--help")
Expand Down

0 comments on commit a4ce13a

Please sign in to comment.