Skip to content

Commit

Permalink
entproto/cmd/protoc-gen-entgrpc - infra for gRPC field validation (#55)
Browse files Browse the repository at this point in the history
entproto/cmd/protoc-gen-entgrpc: adding infra for validation on the gRPC layer, UUID field support (normal field, edge id, id)

Fixes ent/ent#1402
  • Loading branch information
rotemtam committed Apr 4, 2021
1 parent 22fa38e commit b090bf5
Show file tree
Hide file tree
Showing 54 changed files with 5,285 additions and 196 deletions.
10 changes: 8 additions & 2 deletions entproto/cmd/protoc-gen-entgrpc/cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@ func (g *serviceGenerator) castToProtoFunc(fld *entproto.FieldMappingDescriptor)
// TODO(rotemtam): don't wrap if the ent type == the pb type
pbd := fld.PbFieldDescriptor
switch pbd.GetType() {
case dpb.FieldDescriptorProto_TYPE_BYTES:
if fld.EntField != nil && fld.EntField.IsUUID() {
return protogen.GoImportPath("entgo.io/contrib/entproto/runtime").Ident("MustExtractUUIDBytes"), nil
}
return "[]byte", nil
case dpb.FieldDescriptorProto_TYPE_BOOL:
return "bool", nil
case dpb.FieldDescriptorProto_TYPE_INT32:
Expand Down Expand Up @@ -62,13 +67,14 @@ func (g *serviceGenerator) castToEntFunc(fd *entproto.FieldMappingDescriptor) (i
case fld.IsBool(), fld.IsBytes(), fld.IsString(), fld.Type.Numeric():
return fld.Type.String(), nil
case fld.IsTime():
return protogen.GoImportPath("entgo.io/contrib/entproto").Ident("ExtractTime"), nil
return protogen.GoImportPath("entgo.io/contrib/entproto/runtime").Ident("ExtractTime"), nil
case fld.IsEnum():
ident := g.pbEnumIdent(fd)
methodName := "toEnt" + ident.GoName
return methodName, nil
case fld.IsUUID():
return protogen.GoImportPath("entgo.io/contrib/entproto/runtime").Ident("MustBytesToUUID"), nil
// case field.TypeJSON:
// case field.TypeUUID:
// case field.TypeOther:
default:
return nil, fmt.Errorf("entproto: no mapping to ent field type %q", fld.Type.ConstName())
Expand Down
5 changes: 4 additions & 1 deletion entproto/cmd/protoc-gen-entgrpc/main.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,6 +110,9 @@ func (g *serviceGenerator) generate() error {
if err := g.generateToProtoFunc(); err != nil {
return err
}
if typeNeedsValidator(g.fieldMap) {
g.generateValidator()
}
g.P()

for _, method := range g.service.Methods {
Expand Down Expand Up @@ -176,7 +179,7 @@ func (g *serviceGenerator) pbEnumIdent(fld *entproto.FieldMappingDescriptor) pro
func (g *serviceGenerator) generateToProtoFunc() error {
// Mapper from the ent type to the proto type.
g.Tmpl(`
// toProto%(typeName) transforms the ent type to the pb type (TODO: complete implementation)
// toProto%(typeName) transforms the ent type to the pb type
func toProto%(typeName)(e *%(entTypeIdent)) *%(typeName){
return &%(typeName) {`, tmplValues{
"typeName": g.typeName,
Expand Down
17 changes: 17 additions & 0 deletions entproto/cmd/protoc-gen-entgrpc/methods.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
package main

import (
"strconv"

"entgo.io/ent/entc/gen"
"google.golang.org/protobuf/compiler/protogen"
)
Expand All @@ -29,6 +31,9 @@ func (g *serviceGenerator) generateGetMethod() error {
if err != nil {
return err
}
if fieldNeedsValidator(idField) {
g.generateIDFieldValidator(idField)
}
g.Tmpl(`get, err := svc.client.%(typeName).Get(ctx, %(cast)(req.Get%(pbIdField)()))
switch {
case err == nil:
Expand All @@ -50,6 +55,9 @@ func (g *serviceGenerator) generateDeleteMethod() error {
if err != nil {
return err
}
if fieldNeedsValidator(idField) {
g.generateIDFieldValidator(idField)
}
g.Tmpl(`err := svc.client.%(typeName).DeleteOneID(%(cast)(req.Get%(pbIdField)())).Exec(ctx)
switch {
case err == nil:
Expand Down Expand Up @@ -79,6 +87,14 @@ func (g *serviceGenerator) generateMutationMethod(op string) error {
g.Tmpl("%(reqVar) := req.Get%(typeName)()", g.withGlobals(tmplValues{
"reqVar": reqVar,
}))
if typeNeedsValidator(g.fieldMap) {
g.Tmpl(`if err := validate%(typeName)(%(reqVar), %(checkIDFlag)); err != nil {
return nil, %(statusErrf)(%(invalidArgument), "invalid argument: %s", err)
}`, g.withGlobals(tmplValues{
"reqVar": reqVar,
"checkIDFlag": strconv.FormatBool(op == "create"),
}))
}
switch op {
case "create":
g.Tmpl("res, err := svc.client.%(typeName).Create().", g.withGlobals())
Expand Down Expand Up @@ -154,6 +170,7 @@ func (g *serviceGenerator) withGlobals(additionals ...tmplValues) tmplValues {
"notFound": codes.Ident("NotFound"),
"internal": codes.Ident("Internal"),
"typeName": g.typeName,
"fmtErr": protogen.GoImportPath("fmt").Ident("Errorf"),
}
for _, additional := range additionals {
for k, v := range additional {
Expand Down
95 changes: 95 additions & 0 deletions entproto/cmd/protoc-gen-entgrpc/validation.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
// Copyright 2019-present Facebook
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package main

import (
"entgo.io/contrib/entproto"
"google.golang.org/protobuf/compiler/protogen"
)

func typeNeedsValidator(d entproto.FieldMap) bool {
for _, fld := range d {
if fieldNeedsValidator(fld) {
return true
}
}
return false
}

func fieldNeedsValidator(d *entproto.FieldMappingDescriptor) bool {
f := d.EntField
if d.IsEdgeField {
f = d.EntEdge.Type.ID
}
return f.IsUUID()
}

// generateValidator generates a validation function for the service entity, to verify that
// the gRPC input is safe to pass to ent. Ent has already rich validation functionality and
// this layer should *only* assert invariants that are expected by ent but cannot be guaranteed
// by gRPC. For instance, TypeUUID is serialized as a proto bytes field, must be 16-bytes long.
func (g *serviceGenerator) generateValidator() {
g.Tmpl(`
// validate%(typeName) validates that all fields are encoded properly and are safe to pass
// to the ent entity builder.
func validate%(typeName)(x *%(typeName), checkId bool) error {`, g.withGlobals())
for _, fld := range g.fieldMap.Fields() {
if fieldNeedsValidator(fld) {
var idCheckSuffix string
if fld.IsIDField {
idCheckSuffix = "&& checkId"
}

if fld.EntField.IsUUID() {
g.Tmpl(`if err := %(validateUUID)(x.Get%(pbField)()); err != nil %(suffix) {
return err
}`, g.withGlobals(tmplValues{
"pbField": fld.PbStructField(),
"validateUUID": protogen.GoImportPath("entgo.io/contrib/entproto/runtime").Ident("ValidateUUID"),
"suffix": idCheckSuffix,
}))
}
}
}
for _, edg := range g.fieldMap.Edges() {
if fieldNeedsValidator(edg) {
f := edg.EntEdge.Type.ID
if f.IsUUID() {
g.Tmpl(`if err := %(validateUUID)(x.Get%(pbField)().Get%(edgeIdField)()); err != nil {
return err
}`, g.withGlobals(tmplValues{
"pbField": edg.PbStructField(),
"edgeIdField": edg.EdgeIDPbStructField(),
"validateUUID": protogen.GoImportPath("entgo.io/contrib/entproto/runtime").Ident("ValidateUUID"),
}))
}
}
}
g.P("return nil")
g.P("}")
}

func (g *serviceGenerator) generateIDFieldValidator(idField *entproto.FieldMappingDescriptor) {
if idField.EntField.IsUUID() {
g.Tmpl(`if err := %(validateUUID)(req.Get%(pbField)()); err != nil {
return nil, %(statusErrf)(%(invalidArgument), "invalid argument: %s", err)
}`, g.withGlobals(tmplValues{
"pbField": idField.PbStructField(),
"validateUUID": protogen.GoImportPath("entgo.io/contrib/entproto/runtime").Ident("ValidateUUID"),
}))
return
}
panic("entproto: id field validation not implemented for " + idField.EntField.Type.String())
}
7 changes: 0 additions & 7 deletions entproto/fieldmap.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,10 @@ import (
"fmt"
"sort"
"strings"
"time"

"entgo.io/ent/entc/gen"
"github.com/go-openapi/inflect"
"github.com/jhump/protoreflect/desc"
"google.golang.org/protobuf/types/known/timestamppb"
)

// FieldMap returns a FieldMap containing descriptors of all of the mappings between the ent schema field
Expand Down Expand Up @@ -178,8 +176,3 @@ func extractEntEdgeByName(entType *gen.Type, name string) (*gen.Edge, error) {
}
return nil, fmt.Errorf("entproto: could not find find edge %q in %q", name, entType.Name)
}

// ExtractTime returns the time.Time from a proto WKT Timestamp
func ExtractTime(t *timestamppb.Timestamp) time.Time {
return t.AsTime()
}

0 comments on commit b090bf5

Please sign in to comment.