Skip to content

Commit

Permalink
Check for optional/ Ptr types within Unions [go/programgen]
Browse files Browse the repository at this point in the history
  • Loading branch information
aq17 committed Dec 12, 2022
1 parent 56be165 commit fabeda6
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 6 deletions.
@@ -0,0 +1,4 @@
changes:
- type: fix
scope: programgen/go
description: Check for optional/ Ptr types within Union types. This fixes a bug in Go programgen where optional outputs are not returned as pointers.
25 changes: 25 additions & 0 deletions pkg/codegen/go/gen_program_expressions.go
Expand Up @@ -889,8 +889,18 @@ func (g *generator) argumentTypeName(expr model.Expression, destType model.Type,
return g.argumentTypeName(expr, destType.ElementType, isInput)
case *model.UnionType:
for _, ut := range destType.ElementTypes {
isOptional := false
// check if the union contains none, which indicates this is an optional value
for _, ut := range destType.ElementTypes {
if ut.Equals(model.NoneType) {
isOptional = true
}
}
switch ut := ut.(type) {
case *model.OpaqueType:
if isOptional {
return g.argumentTypeNamePtr(expr, ut, isInput)
}
return g.argumentTypeName(expr, ut, isInput)
case *model.ConstType:
return g.argumentTypeName(expr, ut.Type, isInput)
Expand All @@ -907,6 +917,11 @@ func (g *generator) argumentTypeName(expr model.Expression, destType model.Type,
return ""
}

func (g *generator) argumentTypeNamePtr(expr model.Expression, destType model.Type, isInput bool) (result string) {
res := g.argumentTypeName(expr, destType, isInput)
return "*" + res
}

func (g *generator) genRelativeTraversal(w io.Writer,
traversal hcl.Traversal, parts []model.Traversable, isRootResource bool) {

Expand Down Expand Up @@ -1012,13 +1027,23 @@ func (g *generator) genApply(w io.Writer, expr *model.FunctionCallExpression) {
if retType == "[]string" {
typeAssertion = ".(pulumi.StringArrayOutput)"
} else {
if strings.HasPrefix(retType, "*") {
retType = Title(strings.TrimPrefix(retType, "*")) + "Ptr"
switch then.Body.(type) {
case *model.ScopeTraversalExpression:
traversal := then.Body.(*model.ScopeTraversalExpression)
traversal.RootName = "&" + traversal.RootName
then.Body = traversal
}
}
typeAssertion = fmt.Sprintf(".(%sOutput)", retType)
if !strings.Contains(retType, ".") {
typeAssertion = fmt.Sprintf(".(pulumi.%sOutput)", Title(retType))
}
}

if len(applyArgs) == 1 {
then.Signature.ReturnType = model.NewOptionalType(then.Signature.ReturnType)
// If we only have a single output, just generate a normal `.Apply`
g.Fgenf(w, "%.v.ApplyT(%.v)%s", applyArgs[0], then, typeAssertion)
} else {
Expand Down
Expand Up @@ -21,9 +21,9 @@ func main() {
if err != nil {
return err
}
ctx.Export("targetBucket", bucket.Loggings.ApplyT(func(loggings []s3.BucketLogging) (string, error) {
return loggings[0].TargetBucket, nil
}).(pulumi.StringOutput))
ctx.Export("targetBucket", bucket.Loggings.ApplyT(func(loggings []s3.BucketLogging) (*string, error) {
return &loggings[0].TargetBucket, nil
}).(pulumi.StringPtrOutput))
return nil
})
}
Expand Up @@ -16,9 +16,9 @@ func main() {
ctx.Export("foo", rt.Res1.ApplyT(func(res1 *resourceproperties.Res1) (resourceproperties.Obj2, error) {
return res1.Obj1.Res2.Obj2, nil
}).(resourceproperties.Obj2Output))
ctx.Export("complex", rt.Res1.ApplyT(func(res1 *resourceproperties.Res1) (float64, error) {
return res1.Obj1.Res2.Obj2.Answer, nil
}).(pulumi.Float64Output))
ctx.Export("complex", rt.Res1.ApplyT(func(res1 *resourceproperties.Res1) (*float64, error) {
return &res1.Obj1.Res2.Obj2.Answer, nil
}).(pulumi.Float64PtrOutput))
return nil
})
}

0 comments on commit fabeda6

Please sign in to comment.