From e9482026dc5dab3e89ccab6ccbe8252ea0b4bfb9 Mon Sep 17 00:00:00 2001 From: Ilia Choly Date: Tue, 2 Aug 2022 15:12:31 -0400 Subject: [PATCH 1/3] Add App.UnknownFlagHandler --- app.go | 2 ++ context.go | 11 ++++++++++- context_test.go | 26 ++++++++++++++++++++++++++ funcs.go | 3 +++ 4 files changed, 41 insertions(+), 1 deletion(-) diff --git a/app.go b/app.go index 7e64c2d9f8..8c10803295 100644 --- a/app.go +++ b/app.go @@ -78,6 +78,8 @@ type App struct { CommandNotFound CommandNotFoundFunc // Execute this function if a usage error occurs OnUsageError OnUsageErrorFunc + // Execute this function when an unknown flag is accessed from the context + UnknownFlagHandler UnknownFlagFunc // Compilation date Compiled time.Time // List of all authors who contributed diff --git a/context.go b/context.go index 6b497ed20d..07b515dd70 100644 --- a/context.go +++ b/context.go @@ -46,6 +46,9 @@ func (cCtx *Context) NumFlags() int { // Set sets a context flag to a value. func (cCtx *Context) Set(name, value string) error { + if cCtx.flagSet.Lookup(name) == nil { + cCtx.onUnknownFlag(name) + } return cCtx.flagSet.Set(name, value) } @@ -158,7 +161,7 @@ func (cCtx *Context) lookupFlagSet(name string) *flag.FlagSet { return c.flagSet } } - + cCtx.onUnknownFlag(name) return nil } @@ -192,6 +195,12 @@ func (cCtx *Context) checkRequiredFlags(flags []Flag) requiredFlagsErr { return nil } +func (cCtx *Context) onUnknownFlag(name string) { + if cCtx.App != nil && cCtx.App.UnknownFlagHandler != nil { + cCtx.App.UnknownFlagHandler(cCtx, name) + } +} + func makeFlagNameVisitor(names *[]string) func(*flag.Flag) { return func(f *flag.Flag) { nameParts := strings.Split(f.Name, ",") diff --git a/context_test.go b/context_test.go index 84757063eb..829855d2cc 100644 --- a/context_test.go +++ b/context_test.go @@ -150,6 +150,19 @@ func TestContext_Value(t *testing.T) { expect(t, c.Value("unknown-flag"), nil) } +func TestContext_Value_UnknownFlagHandler(t *testing.T) { + set := flag.NewFlagSet("test", 0) + var flagName string + app := &App{ + UnknownFlagHandler: func(_ *Context, name string) { + flagName = name + }, + } + c := NewContext(app, set, nil) + c.Value("missing") + expect(t, flagName, "missing") +} + func TestContext_Args(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("myflag", false, "doc") @@ -258,6 +271,19 @@ func TestContext_Set(t *testing.T) { expect(t, c.IsSet("int"), true) } +func TestContext_Set_StrictLookup(t *testing.T) { + set := flag.NewFlagSet("test", 0) + var flagName string + app := &App{ + UnknownFlagHandler: func(_ *Context, name string) { + flagName = name + }, + } + c := NewContext(app, set, nil) + c.Set("missing", "") + expect(t, flagName, "missing") +} + func TestContext_LocalFlagNames(t *testing.T) { set := flag.NewFlagSet("test", 0) set.Bool("one-flag", false, "doc") diff --git a/funcs.go b/funcs.go index 0a9b22c94a..1342bd2800 100644 --- a/funcs.go +++ b/funcs.go @@ -23,6 +23,9 @@ type CommandNotFoundFunc func(*Context, string) // is displayed and the execution is interrupted. type OnUsageErrorFunc func(cCtx *Context, err error, isSubcommand bool) error +// UnknownFlagFunc is executed when an unknown flag is accessed from the context. +type UnknownFlagFunc func(*Context, string) + // ExitErrHandlerFunc is executed if provided in order to handle exitError values // returned by Actions and Before/After functions. type ExitErrHandlerFunc func(cCtx *Context, err error) From 75cb4268bf933642798503e9d570dc3c8aad86a2 Mon Sep 17 00:00:00 2001 From: Ilia Choly Date: Mon, 8 Aug 2022 13:29:01 -0400 Subject: [PATCH 2/3] Rename App.UnknownFlagHandler to App.InvalidFlagAccessHandler --- app.go | 4 ++-- context.go | 10 +++++----- context_test.go | 8 ++++---- funcs.go | 4 ++-- 4 files changed, 13 insertions(+), 13 deletions(-) diff --git a/app.go b/app.go index 8c10803295..9df014628e 100644 --- a/app.go +++ b/app.go @@ -78,8 +78,8 @@ type App struct { CommandNotFound CommandNotFoundFunc // Execute this function if a usage error occurs OnUsageError OnUsageErrorFunc - // Execute this function when an unknown flag is accessed from the context - UnknownFlagHandler UnknownFlagFunc + // Execute this function when an invalid flag is accessed from the context + InvalidFlagAccessHandler InvalidFlagAccessFunc // Compilation date Compiled time.Time // List of all authors who contributed diff --git a/context.go b/context.go index 07b515dd70..06045540b2 100644 --- a/context.go +++ b/context.go @@ -47,7 +47,7 @@ func (cCtx *Context) NumFlags() int { // Set sets a context flag to a value. func (cCtx *Context) Set(name, value string) error { if cCtx.flagSet.Lookup(name) == nil { - cCtx.onUnknownFlag(name) + cCtx.onInvalidFlag(name) } return cCtx.flagSet.Set(name, value) } @@ -161,7 +161,7 @@ func (cCtx *Context) lookupFlagSet(name string) *flag.FlagSet { return c.flagSet } } - cCtx.onUnknownFlag(name) + cCtx.onInvalidFlag(name) return nil } @@ -195,9 +195,9 @@ func (cCtx *Context) checkRequiredFlags(flags []Flag) requiredFlagsErr { return nil } -func (cCtx *Context) onUnknownFlag(name string) { - if cCtx.App != nil && cCtx.App.UnknownFlagHandler != nil { - cCtx.App.UnknownFlagHandler(cCtx, name) +func (cCtx *Context) onInvalidFlag(name string) { + if cCtx.App != nil && cCtx.App.InvalidFlagAccessHandler != nil { + cCtx.App.InvalidFlagAccessHandler(cCtx, name) } } diff --git a/context_test.go b/context_test.go index 829855d2cc..3bfb386491 100644 --- a/context_test.go +++ b/context_test.go @@ -150,11 +150,11 @@ func TestContext_Value(t *testing.T) { expect(t, c.Value("unknown-flag"), nil) } -func TestContext_Value_UnknownFlagHandler(t *testing.T) { +func TestContext_Value_InvalidFlagAccessHandler(t *testing.T) { set := flag.NewFlagSet("test", 0) var flagName string app := &App{ - UnknownFlagHandler: func(_ *Context, name string) { + InvalidFlagAccessHandler: func(_ *Context, name string) { flagName = name }, } @@ -271,11 +271,11 @@ func TestContext_Set(t *testing.T) { expect(t, c.IsSet("int"), true) } -func TestContext_Set_StrictLookup(t *testing.T) { +func TestContext_Set_InvalidFlagAccessHandler(t *testing.T) { set := flag.NewFlagSet("test", 0) var flagName string app := &App{ - UnknownFlagHandler: func(_ *Context, name string) { + InvalidFlagAccessHandler: func(_ *Context, name string) { flagName = name }, } diff --git a/funcs.go b/funcs.go index 1342bd2800..e77b0d0a10 100644 --- a/funcs.go +++ b/funcs.go @@ -23,8 +23,8 @@ type CommandNotFoundFunc func(*Context, string) // is displayed and the execution is interrupted. type OnUsageErrorFunc func(cCtx *Context, err error, isSubcommand bool) error -// UnknownFlagFunc is executed when an unknown flag is accessed from the context. -type UnknownFlagFunc func(*Context, string) +// InvalidFlagAccessFunc is executed when an invalid flag is accessed from the context. +type InvalidFlagAccessFunc func(*Context, string) // ExitErrHandlerFunc is executed if provided in order to handle exitError values // returned by Actions and Before/After functions. From 84daa4af1772ac5866fca0ef659acc16b5286005 Mon Sep 17 00:00:00 2001 From: Ilia Choly Date: Mon, 15 Aug 2022 16:46:17 -0400 Subject: [PATCH 3/3] Traverse parent contexts --- context.go | 8 ++++++-- context_test.go | 18 +++++++++++++++--- 2 files changed, 21 insertions(+), 5 deletions(-) diff --git a/context.go b/context.go index 06045540b2..a7aec11071 100644 --- a/context.go +++ b/context.go @@ -196,8 +196,12 @@ func (cCtx *Context) checkRequiredFlags(flags []Flag) requiredFlagsErr { } func (cCtx *Context) onInvalidFlag(name string) { - if cCtx.App != nil && cCtx.App.InvalidFlagAccessHandler != nil { - cCtx.App.InvalidFlagAccessHandler(cCtx, name) + for cCtx != nil { + if cCtx.App != nil && cCtx.App.InvalidFlagAccessHandler != nil { + cCtx.App.InvalidFlagAccessHandler(cCtx, name) + break + } + cCtx = cCtx.parentContext } } diff --git a/context_test.go b/context_test.go index 3bfb386491..0fbbf51f44 100644 --- a/context_test.go +++ b/context_test.go @@ -151,15 +151,27 @@ func TestContext_Value(t *testing.T) { } func TestContext_Value_InvalidFlagAccessHandler(t *testing.T) { - set := flag.NewFlagSet("test", 0) var flagName string app := &App{ InvalidFlagAccessHandler: func(_ *Context, name string) { flagName = name }, + Commands: []*Command{ + { + Name: "command", + Subcommands: []*Command{ + { + Name: "subcommand", + Action: func(ctx *Context) error { + ctx.Value("missing") + return nil + }, + }, + }, + }, + }, } - c := NewContext(app, set, nil) - c.Value("missing") + expect(t, app.Run([]string{"run", "command", "subcommand"}), nil) expect(t, flagName, "missing") }