diff --git a/codegen/config/package.go b/codegen/config/package.go index faacd1496f..37692ece7e 100644 --- a/codegen/config/package.go +++ b/codegen/config/package.go @@ -10,9 +10,10 @@ import ( ) type PackageConfig struct { - Filename string `yaml:"filename,omitempty"` - Package string `yaml:"package,omitempty"` - Version int `yaml:"version,omitempty"` + Filename string `yaml:"filename,omitempty"` + Package string `yaml:"package,omitempty"` + Version int `yaml:"version,omitempty"` + ModelTemplate string `yaml:"model_template,omitempty"` } func (c *PackageConfig) ImportPath() string { diff --git a/docs/content/config.md b/docs/content/config.md index 86c91a47c3..1d167e2ef0 100644 --- a/docs/content/config.md +++ b/docs/content/config.md @@ -30,6 +30,8 @@ federation: model: filename: graph/model/models_gen.go package: model + # Optional: Pass in a path to a new gotpl template to use for generating the models + # model_template: [your/path/model.gotpl] # Where should the resolver implementations go? resolver: diff --git a/plugin/modelgen/models.go b/plugin/modelgen/models.go index 6b4f483752..23c2e334d9 100644 --- a/plugin/modelgen/models.go +++ b/plugin/modelgen/models.go @@ -4,6 +4,7 @@ import ( _ "embed" "fmt" "go/types" + "os" "sort" "strings" "text/template" @@ -282,6 +283,10 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error { "getInterfaceByName": getInterfaceByName, "generateGetter": generateGetter, } + newModelTemplate := modelTemplate + if cfg.Model.ModelTemplate != "" { + newModelTemplate = readModelTemplate(cfg.Model.ModelTemplate) + } err := templates.Render(templates.Options{ PackageName: cfg.Model.Package, @@ -289,7 +294,7 @@ func (m *Plugin) MutateConfig(cfg *config.Config) error { Data: b, GeneratedHeader: true, Packages: cfg.Packages, - Template: modelTemplate, + Template: newModelTemplate, Funcs: funcMap, }) if err != nil { @@ -645,3 +650,11 @@ func findAndHandleCyclicalRelationships(b *ModelBuild) { } } } + +func readModelTemplate(customModelTemplate string) string { + contentBytes, err := os.ReadFile(customModelTemplate) + if err != nil { + panic(err) + } + return string(contentBytes) +} diff --git a/plugin/modelgen/models_test.go b/plugin/modelgen/models_test.go index f4cd43beae..fa15fa010c 100644 --- a/plugin/modelgen/models_test.go +++ b/plugin/modelgen/models_test.go @@ -637,3 +637,14 @@ func Test_splitTagsBySpace(t *testing.T) { }) } } + +func TestCustomTemplate(t *testing.T) { + cfg, err := config.LoadConfig("testdata/gqlgen_custom_model_template.yml") + require.NoError(t, err) + require.NoError(t, cfg.Init()) + p := Plugin{ + MutateHook: mutateHook, + FieldHook: DefaultFieldMutateHook, + } + require.NoError(t, p.MutateConfig(cfg)) +} diff --git a/plugin/modelgen/out/generated.go b/plugin/modelgen/out/generated.go index f90f8b974a..656913e510 100644 --- a/plugin/modelgen/out/generated.go +++ b/plugin/modelgen/out/generated.go @@ -11,6 +11,8 @@ import ( "github.com/99designs/gqlgen/plugin/modelgen/internal/extrafields" ) +// Add any new functions or any additional code/template functionality here + type A interface { IsA() GetA() string diff --git a/plugin/modelgen/testdata/customModelTemplate.gotpl b/plugin/modelgen/testdata/customModelTemplate.gotpl new file mode 100644 index 0000000000..4ff192a297 --- /dev/null +++ b/plugin/modelgen/testdata/customModelTemplate.gotpl @@ -0,0 +1,106 @@ +{{ reserveImport "context" }} +{{ reserveImport "fmt" }} +{{ reserveImport "io" }} +{{ reserveImport "strconv" }} +{{ reserveImport "time" }} +{{ reserveImport "sync" }} +{{ reserveImport "errors" }} +{{ reserveImport "bytes" }} + +{{ reserveImport "github.com/vektah/gqlparser/v2" }} +{{ reserveImport "github.com/vektah/gqlparser/v2/ast" }} +{{ reserveImport "github.com/99designs/gqlgen/graphql" }} +{{ reserveImport "github.com/99designs/gqlgen/graphql/introspection" }} + +// Add any new functions or any additional code/template functionality here + +{{- range $model := .Interfaces }} + {{ with .Description }} {{.|prefixLines "// "}} {{ end }} + type {{ goModelName .Name }} interface { + {{- if not .OmitCheck }} + {{- range $impl := .Implements }} + Is{{ goModelName $impl }}() + {{- end }} + Is{{ goModelName .Name }}() + {{- end }} + {{- range $field := .Fields }} + {{- with .Description }} + {{.|prefixLines "// "}} + {{- end}} + Get{{ $field.GoName }}() {{ $field.Type | ref }} + {{- end }} + } +{{- end }} + +{{ range $model := .Models }} + {{with .Description }} {{.|prefixLines "// "}} {{end}} + type {{ goModelName .Name }} struct { + {{- range $field := .Fields }} + {{- with .Description }} + {{.|prefixLines "// "}} + {{- end}} + {{ $field.GoName }} {{$field.Type | ref}} `{{$field.Tag}}` + {{- end }} + } + + {{ range .Implements }} + func ({{ goModelName $model.Name }}) Is{{ goModelName . }}() {} + {{- with getInterfaceByName . }} + {{- range .Fields }} + {{- with .Description }} + {{.|prefixLines "// "}} + {{- end}} + {{ generateGetter $model . }} + {{- end }} + {{- end }} + {{ end }} +{{- end}} + +{{ range $enum := .Enums }} + {{ with .Description }} {{.|prefixLines "// "}} {{end}} + type {{ goModelName .Name }} string + const ( + {{- range $value := .Values}} + {{- with .Description}} + {{.|prefixLines "// "}} + {{- end}} + {{ goModelName $enum.Name .Name }} {{ goModelName $enum.Name }} = {{ .Name|quote }} + {{- end }} + ) + + var All{{ goModelName .Name }} = []{{ goModelName .Name }}{ + {{- range $value := .Values}} + {{ goModelName $enum.Name .Name }}, + {{- end }} + } + + func (e {{ goModelName .Name }}) IsValid() bool { + switch e { + case {{ range $index, $element := .Values}}{{if $index}},{{end}}{{ goModelName $enum.Name $element.Name }}{{end}}: + return true + } + return false + } + + func (e {{ goModelName .Name }}) String() string { + return string(e) + } + + func (e *{{ goModelName .Name }}) UnmarshalGQL(v interface{}) error { + str, ok := v.(string) + if !ok { + return fmt.Errorf("enums must be strings") + } + + *e = {{ goModelName .Name }}(str) + if !e.IsValid() { + return fmt.Errorf("%s is not a valid {{ .Name }}", str) + } + return nil + } + + func (e {{ goModelName .Name }}) MarshalGQL(w io.Writer) { + fmt.Fprint(w, strconv.Quote(e.String())) + } + +{{- end }} diff --git a/plugin/modelgen/testdata/gqlgen_custom_model_template.yml b/plugin/modelgen/testdata/gqlgen_custom_model_template.yml new file mode 100644 index 0000000000..ea3d96b5b7 --- /dev/null +++ b/plugin/modelgen/testdata/gqlgen_custom_model_template.yml @@ -0,0 +1,38 @@ +schema: + - "testdata/schema.graphql" + +exec: + filename: out/ignored.go +model: + filename: out/generated.go + model_template: "testdata/customModelTemplate.gotpl" + +models: + ExistingModel: + model: github.com/99designs/gqlgen/plugin/modelgen/out.ExistingModel + ExistingInput: + model: github.com/99designs/gqlgen/plugin/modelgen/out.ExistingInput + ExistingEnum: + model: github.com/99designs/gqlgen/plugin/modelgen/out.ExistingEnum + ExistingInterface: + model: github.com/99designs/gqlgen/plugin/modelgen/out.ExistingInterface + ExistingUnion: + model: github.com/99designs/gqlgen/plugin/modelgen/out.ExistingUnion + ExistingType: + model: github.com/99designs/gqlgen/plugin/modelgen/out.ExistingType + RenameFieldTest: + fields: + badName: + fieldName: GOODnaME + ExtraFieldsTest: + extraFields: + FieldInternalType: + description: "Internal field" + type: github.com/99designs/gqlgen/plugin/modelgen/internal/extrafields.Type + FieldStringPtr: + type: "*string" + FieldInt: + type: "int64" + overrideTags: 'json:"field_int_tag"' + FieldIntSlice: + type: "[]int64"