From a19996a59629e9dc2b32dc2fb8628040e6e38459 Mon Sep 17 00:00:00 2001 From: Glenn Lewis <6598971+gmlewis@users.noreply.github.com> Date: Fri, 7 May 2021 09:33:07 -0400 Subject: [PATCH] Support map type in go generate (#1867) Fixes: #1866. --- github/gen-accessors.go | 42 +++++++++++++++++++++++++++++---- github/gen-stringify-test.go | 5 ---- github/github-accessors.go | 8 +++++++ github/github-accessors_test.go | 10 ++++++++ github/github-stringify_test.go | 23 ++++++++---------- github/strings.go | 3 +++ github/strings_test.go | 4 ++-- 7 files changed, 70 insertions(+), 25 deletions(-) diff --git a/github/gen-accessors.go b/github/gen-accessors.go index fb0390e963..d8d5910f68 100644 --- a/github/gen-accessors.go +++ b/github/gen-accessors.go @@ -115,8 +115,7 @@ func (t *templateData) processAST(f *ast.File) error { continue } for _, field := range st.Fields.List { - se, ok := field.Type.(*ast.StarExpr) - if len(field.Names) == 0 || !ok { + if len(field.Names) == 0 { continue } @@ -132,13 +131,25 @@ func (t *templateData) processAST(f *ast.File) error { continue } + se, ok := field.Type.(*ast.StarExpr) + if !ok { + switch x := field.Type.(type) { + case *ast.MapType: + t.addMapType(x, ts.Name.String(), fieldName.String(), false) + continue + } + + logf("Skipping field type %T, fieldName=%v", field.Type, fieldName) + continue + } + switch x := se.X.(type) { case *ast.ArrayType: t.addArrayType(x, ts.Name.String(), fieldName.String()) case *ast.Ident: t.addIdent(x, ts.Name.String(), fieldName.String()) case *ast.MapType: - t.addMapType(x, ts.Name.String(), fieldName.String()) + t.addMapType(x, ts.Name.String(), fieldName.String(), true) case *ast.SelectorExpr: t.addSelectorExpr(x, ts.Name.String(), fieldName.String()) default: @@ -232,7 +243,7 @@ func (t *templateData) addIdent(x *ast.Ident, receiverType, fieldName string) { t.Getters = append(t.Getters, newGetter(receiverType, fieldName, x.String(), zeroValue, namedStruct)) } -func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string) { +func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string, isAPointer bool) { var keyType string switch key := x.Key.(type) { case *ast.Ident: @@ -253,7 +264,9 @@ func (t *templateData) addMapType(x *ast.MapType, receiverType, fieldName string fieldType := fmt.Sprintf("map[%v]%v", keyType, valueType) zeroValue := fmt.Sprintf("map[%v]%v{}", keyType, valueType) - t.Getters = append(t.Getters, newGetter(receiverType, fieldName, fieldType, zeroValue, false)) + ng := newGetter(receiverType, fieldName, fieldType, zeroValue, false) + ng.MapType = !isAPointer + t.Getters = append(t.Getters, ng) } func (t *templateData) addSelectorExpr(x *ast.SelectorExpr, receiverType, fieldName string) { @@ -300,6 +313,7 @@ type getter struct { FieldType string ZeroValue string NamedStruct bool // Getter for named struct. + MapType bool } type byName []*getter @@ -332,6 +346,14 @@ func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() *{{.FieldType}} { } return {{.ReceiverVar}}.{{.FieldName}} } +{{else if .MapType}} +// Get{{.FieldName}} returns the {{.FieldName}} map if it's non-nil, an empty map otherwise. +func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { + if {{.ReceiverVar}} == nil || {{.ReceiverVar}}.{{.FieldName}} == nil { + return {{.ZeroValue}} + } + return {{.ReceiverVar}}.{{.FieldName}} +} {{else}} // Get{{.FieldName}} returns the {{.FieldName}} field if it's non-nil, zero value otherwise. func ({{.ReceiverVar}} *{{.ReceiverType}}) Get{{.FieldName}}() {{.FieldType}} { @@ -368,6 +390,16 @@ func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { {{.ReceiverVar}} = nil {{.ReceiverVar}}.Get{{.FieldName}}() } +{{else if .MapType}} +func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { + zeroValue := {{.FieldType}}{} + {{.ReceiverVar}} := &{{.ReceiverType}}{ {{.FieldName}}: zeroValue } + {{.ReceiverVar}}.Get{{.FieldName}}() + {{.ReceiverVar}} = &{{.ReceiverType}}{} + {{.ReceiverVar}}.Get{{.FieldName}}() + {{.ReceiverVar}} = nil + {{.ReceiverVar}}.Get{{.FieldName}}() +} {{else}} func Test{{.ReceiverType}}_Get{{.FieldName}}(tt *testing.T) { var zeroValue {{.FieldType}} diff --git a/github/gen-stringify-test.go b/github/gen-stringify-test.go index 5474114bf9..cc1d153693 100644 --- a/github/gen-stringify-test.go +++ b/github/gen-stringify-test.go @@ -187,11 +187,6 @@ func (t *templateData) processAST(f *ast.File) error { continue } - if _, ok := field.Type.(*ast.MapType); ok { - t.addMapType(ts.Name.String(), fieldName.String()) - continue - } - se, ok := field.Type.(*ast.StarExpr) if !ok { logf("Ignoring type %T for Name=%q, FieldName=%q", field.Type, ts.Name.String(), fieldName.String()) diff --git a/github/github-accessors.go b/github/github-accessors.go index b0f0630640..e6f349436b 100644 --- a/github/github-accessors.go +++ b/github/github-accessors.go @@ -4028,6 +4028,14 @@ func (g *Gist) GetDescription() string { return *g.Description } +// GetFiles returns the Files map if it's non-nil, an empty map otherwise. +func (g *Gist) GetFiles() map[GistFilename]GistFile { + if g == nil || g.Files == nil { + return map[GistFilename]GistFile{} + } + return g.Files +} + // GetGitPullURL returns the GitPullURL field if it's non-nil, zero value otherwise. func (g *Gist) GetGitPullURL() string { if g == nil || g.GitPullURL == nil { diff --git a/github/github-accessors_test.go b/github/github-accessors_test.go index 12c38dd4e9..2c449f8b2d 100644 --- a/github/github-accessors_test.go +++ b/github/github-accessors_test.go @@ -4730,6 +4730,16 @@ func TestGist_GetDescription(tt *testing.T) { g.GetDescription() } +func TestGist_GetFiles(tt *testing.T) { + zeroValue := map[GistFilename]GistFile{} + g := &Gist{Files: zeroValue} + g.GetFiles() + g = &Gist{} + g.GetFiles() + g = nil + g.GetFiles() +} + func TestGist_GetGitPullURL(tt *testing.T) { var zeroValue string g := &Gist{GitPullURL: &zeroValue} diff --git a/github/github-stringify_test.go b/github/github-stringify_test.go index 71da5cc49b..b1dceb0c95 100644 --- a/github/github-stringify_test.go +++ b/github/github-stringify_test.go @@ -391,14 +391,13 @@ func TestGist_String(t *testing.T) { Description: String(""), Public: Bool(false), Owner: &User{}, - Files: nil, Comments: Int(0), HTMLURL: String(""), GitPullURL: String(""), GitPushURL: String(""), NodeID: String(""), } - want := `github.Gist{ID:"", Description:"", Public:false, Owner:github.User{}, Files:map[], Comments:0, HTMLURL:"", GitPullURL:"", GitPushURL:"", NodeID:""}` + want := `github.Gist{ID:"", Description:"", Public:false, Owner:github.User{}, Comments:0, HTMLURL:"", GitPullURL:"", GitPushURL:"", NodeID:""}` if got := v.String(); got != want { t.Errorf("Gist.String = %v, want %v", got, want) } @@ -531,17 +530,15 @@ func TestHeadCommit_String(t *testing.T) { func TestHook_String(t *testing.T) { v := Hook{ - URL: String(""), - ID: Int64(0), - Type: String(""), - Name: String(""), - TestURL: String(""), - PingURL: String(""), - LastResponse: nil, - Config: nil, - Active: Bool(false), - } - want := `github.Hook{URL:"", ID:0, Type:"", Name:"", TestURL:"", PingURL:"", LastResponse:map[], Config:map[], Active:false}` + URL: String(""), + ID: Int64(0), + Type: String(""), + Name: String(""), + TestURL: String(""), + PingURL: String(""), + Active: Bool(false), + } + want := `github.Hook{URL:"", ID:0, Type:"", Name:"", TestURL:"", PingURL:"", Active:false}` if got := v.String(); got != want { t.Errorf("Hook.String = %v, want %v", got, want) } diff --git a/github/strings.go b/github/strings.go index 431e1cc6c1..5611b96a88 100644 --- a/github/strings.go +++ b/github/strings.go @@ -72,6 +72,9 @@ func stringifyValue(w io.Writer, val reflect.Value) { if fv.Kind() == reflect.Slice && fv.IsNil() { continue } + if fv.Kind() == reflect.Map && fv.IsNil() { + continue + } if sep { w.Write([]byte(", ")) diff --git a/github/strings_test.go b/github/strings_test.go index 013542df51..83cb231bbd 100644 --- a/github/strings_test.go +++ b/github/strings_test.go @@ -98,10 +98,10 @@ func TestString(t *testing.T) { {Event{ID: String("1")}, `github.Event{ID:"1"}`}, {GistComment{ID: Int64(1)}, `github.GistComment{ID:1}`}, {GistFile{Size: Int(1)}, `github.GistFile{Size:1}`}, - {Gist{ID: String("1")}, `github.Gist{ID:"1", Files:map[]}`}, + {Gist{ID: String("1")}, `github.Gist{ID:"1"}`}, {GitObject{SHA: String("s")}, `github.GitObject{SHA:"s"}`}, {Gitignore{Name: String("n")}, `github.Gitignore{Name:"n"}`}, - {Hook{ID: Int64(1)}, `github.Hook{ID:1, LastResponse:map[], Config:map[]}`}, + {Hook{ID: Int64(1)}, `github.Hook{ID:1}`}, {IssueComment{ID: Int64(1)}, `github.IssueComment{ID:1}`}, {Issue{Number: Int(1)}, `github.Issue{Number:1}`}, {Key{ID: Int64(1)}, `github.Key{ID:1}`},