Skip to content

Commit

Permalink
feat: add support for flag groups (#2488)
Browse files Browse the repository at this point in the history
  • Loading branch information
knqyf263 committed Jul 10, 2022
1 parent 5b7e0a8 commit 736e3f1
Show file tree
Hide file tree
Showing 16 changed files with 491 additions and 405 deletions.
143 changes: 98 additions & 45 deletions pkg/commands/app.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,32 @@ type VersionInfo struct {
VulnerabilityDB *metadata.Metadata `json:",omitempty"`
}

const (
usageTemplate = `Usage:{{if .Runnable}}
{{.UseLine}}{{end}}{{if .HasAvailableSubCommands}}
{{.CommandPath}} [command]{{end}}{{if gt (len .Aliases) 0}}
Aliases:
{{.NameAndAliases}}{{end}}{{if .HasExample}}
Examples:
{{.Example}}{{end}}{{if .HasAvailableSubCommands}}
Available Commands:{{range .Commands}}{{if (or .IsAvailableCommand (eq .Name "help"))}}
{{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableLocalFlags}}
%s
Global Flags:
{{.InheritedFlags.FlagUsages | trimTrailingWhitespaces}}{{end}}{{if .HasHelpSubCommands}}
Additional help topics:{{range .Commands}}{{if .IsAdditionalHelpTopicCommand}}
{{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{if .HasAvailableSubCommands}}
Use "{{.CommandPath}} [command] --help" for more information about a command.{{end}}
`
)

var (
outputWriter io.Writer = os.Stdout
)
Expand Down Expand Up @@ -192,13 +218,15 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
reportFlagGroup.ReportFormat = nil // TODO: support --format summary

imageFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
ImageFlagGroup: flag.NewImageFlagGroup(), // container image specific
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}

cmd := &cobra.Command{
Expand Down Expand Up @@ -250,8 +278,9 @@ func NewImageCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
SilenceUsage: true,
}

cmd.SetFlagErrorFunc(flagErrorFunc)
imageFlags.AddFlags(cmd)
cmd.SetFlagErrorFunc(flagErrorFunc)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, imageFlags.Usages(cmd)))

return cmd
}
Expand All @@ -261,12 +290,14 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
reportFlagGroup.ReportFormat = nil // TODO: support --format summary

fsFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}

cmd := &cobra.Command{
Expand Down Expand Up @@ -300,6 +331,7 @@ func NewFilesystemCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {

cmd.SetFlagErrorFunc(flagErrorFunc)
fsFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, fsFlags.Usages(cmd)))

return cmd
}
Expand All @@ -309,11 +341,13 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
reportFlagGroup.ReportFormat = nil // TODO: support --format summary

rootfsFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}

cmd := &cobra.Command{
Expand Down Expand Up @@ -348,6 +382,7 @@ func NewRootfsCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
rootfsFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, rootfsFlags.Usages(cmd)))

return cmd
}
Expand All @@ -357,12 +392,14 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
reportFlagGroup.ReportFormat = nil // TODO: support --format summary

repoFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}

cmd := &cobra.Command{
Expand Down Expand Up @@ -392,6 +429,7 @@ func NewRepositoryCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
repoFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, repoFlags.Usages(cmd)))

return cmd
}
Expand All @@ -409,12 +447,13 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
remoteFlags.ServerAddr = &remoteAddr // disable '--server' and enable '--remote' instead.

clientFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
RemoteFlagGroup: remoteFlags,
ReportFlagGroup: flag.NewReportFlagGroup(),
ScanFlagGroup: flag.NewScanFlagGroup(),
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
RemoteFlagGroup: remoteFlags,
ReportFlagGroup: flag.NewReportFlagGroup(),
ScanFlagGroup: flag.NewScanFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}

cmd := &cobra.Command{
Expand Down Expand Up @@ -444,6 +483,7 @@ func NewClientCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
clientFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, clientFlags.Usages(cmd)))

return cmd
}
Expand All @@ -459,7 +499,13 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
Use: "server [flags]",
Aliases: []string{"s"},
Short: "Server mode",
Args: cobra.ExactArgs(0),
Example: ` # Run a server
$ trivy server
# Listen on 0.0.0.0:10000
$ trivy server --listen 0.0.0.0:10000
`,
Args: cobra.ExactArgs(0),
RunE: func(cmd *cobra.Command, args []string) error {
if err := serverFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
Expand All @@ -475,6 +521,7 @@ func NewServerCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
serverFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, serverFlags.Usages(cmd)))

return cmd
}
Expand Down Expand Up @@ -528,6 +575,7 @@ func NewConfigCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
configFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, configFlags.Usages(cmd)))

return cmd
}
Expand Down Expand Up @@ -690,12 +738,14 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
scanFlags.SecurityChecks = &securityChecks

k8sFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
K8sFlagGroup: flag.NewK8sFlagGroup(), // kubernetes-specific flags
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ReportFlagGroup: flag.NewReportFlagGroup(),
ScanFlagGroup: scanFlags,
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
K8sFlagGroup: flag.NewK8sFlagGroup(), // kubernetes-specific flags
MisconfFlagGroup: flag.NewMisconfFlagGroup(),
ReportFlagGroup: flag.NewReportFlagGroup(),
ScanFlagGroup: scanFlags,
SecretFlagGroup: flag.NewSecretFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}
cmd := &cobra.Command{
Use: "kubernetes [flags] { cluster | all | specific resources like kubectl. eg: pods, pod/NAME }",
Expand Down Expand Up @@ -736,6 +786,7 @@ func NewKubernetesCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
k8sFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, k8sFlags.Usages(cmd)))

return cmd
}
Expand All @@ -748,12 +799,13 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
scanFlags.SecurityChecks = nil // disable '--security-checks' as it always scans for vulnerabilities

sbomFlags := &flag.Flags{
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SBOMFlagGroup: flag.NewSBOMFlagGroup(),
CacheFlagGroup: flag.NewCacheFlagGroup(),
DBFlagGroup: flag.NewDBFlagGroup(),
RemoteFlagGroup: flag.NewClientFlags(), // for client/server mode
ReportFlagGroup: reportFlagGroup,
ScanFlagGroup: flag.NewScanFlagGroup(),
SBOMFlagGroup: flag.NewSBOMFlagGroup(),
VulnerabilityFlagGroup: flag.NewVulnerabilityFlagGroup(),
}

cmd := &cobra.Command{
Expand All @@ -766,7 +818,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
$ trivy sbom --format cyclonedx /path/to/report.cdx
`,
PreRunE: func(cmd *cobra.Command, args []string) error {
if err := scanFlags.Bind(cmd); err != nil {
if err := sbomFlags.Bind(cmd); err != nil {
return xerrors.Errorf("flag bind error: %w", err)
}
return validateArgs(cmd, args)
Expand All @@ -790,6 +842,7 @@ func NewSBOMCommand(globalFlags *flag.GlobalFlagGroup) *cobra.Command {
}
cmd.SetFlagErrorFunc(flagErrorFunc)
sbomFlags.AddFlags(cmd)
cmd.SetUsageTemplate(fmt.Sprintf(usageTemplate, sbomFlags.Usages(cmd)))

return cmd
}
Expand Down
46 changes: 17 additions & 29 deletions pkg/flag/cache_flags.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"time"

"github.com/samber/lo"
"github.com/spf13/cobra"
"golang.org/x/xerrors"
)

Expand Down Expand Up @@ -85,40 +84,29 @@ type RedisOptions struct {
// NewCacheFlagGroup returns a default CacheFlagGroup
func NewCacheFlagGroup() *CacheFlagGroup {
return &CacheFlagGroup{
ClearCache: lo.ToPtr(ClearCacheFlag),
CacheBackend: lo.ToPtr(CacheBackendFlag),
CacheTTL: lo.ToPtr(CacheTTLFlag),
RedisCACert: lo.ToPtr(RedisCACertFlag),
RedisCert: lo.ToPtr(RedisCertFlag),
RedisKey: lo.ToPtr(RedisKeyFlag),
ClearCache: &ClearCacheFlag,
CacheBackend: &CacheBackendFlag,
CacheTTL: &CacheTTLFlag,
RedisCACert: &RedisCACertFlag,
RedisCert: &RedisCertFlag,
RedisKey: &RedisKeyFlag,
}
}

func (f *CacheFlagGroup) flags() []*Flag {
return []*Flag{f.ClearCache, f.CacheBackend, f.CacheTTL, f.RedisCACert, f.RedisCert, f.RedisKey}
func (fg *CacheFlagGroup) Name() string {
return "Cache"
}

func (f *CacheFlagGroup) AddFlags(cmd *cobra.Command) {
for _, flag := range f.flags() {
addFlag(cmd, flag)
}
}

func (f *CacheFlagGroup) Bind(cmd *cobra.Command) error {
for _, flag := range f.flags() {
if err := bind(cmd, flag); err != nil {
return err
}
}
return nil
func (fg *CacheFlagGroup) Flags() []*Flag {
return []*Flag{fg.ClearCache, fg.CacheBackend, fg.CacheTTL, fg.RedisCACert, fg.RedisCert, fg.RedisKey}
}

func (f *CacheFlagGroup) ToOptions() (CacheOptions, error) {
cacheBackend := getString(f.CacheBackend)
func (fg *CacheFlagGroup) ToOptions() (CacheOptions, error) {
cacheBackend := getString(fg.CacheBackend)
redisOptions := RedisOptions{
RedisCACert: getString(f.RedisCACert),
RedisCert: getString(f.RedisCert),
RedisKey: getString(f.RedisKey),
RedisCACert: getString(fg.RedisCACert),
RedisCert: getString(fg.RedisCert),
RedisKey: getString(fg.RedisKey),
}

// "redis://" or "fs" are allowed for now
Expand All @@ -135,9 +123,9 @@ func (f *CacheFlagGroup) ToOptions() (CacheOptions, error) {
}

return CacheOptions{
ClearCache: getBool(f.ClearCache),
ClearCache: getBool(fg.ClearCache),
CacheBackend: cacheBackend,
CacheTTL: getDuration(f.CacheTTL),
CacheTTL: getDuration(fg.CacheTTL),
RedisOptions: redisOptions,
}, nil
}
Expand Down

0 comments on commit 736e3f1

Please sign in to comment.