Skip to content

Commit

Permalink
Merge pull request #30987 from hashicorp/jbardin/get-schema-diags
Browse files Browse the repository at this point in the history
plugin: diagnostics must be checked on all schema calls
  • Loading branch information
jbardin committed May 4, 2022
2 parents f98cad3 + 8943c79 commit b97c640
Show file tree
Hide file tree
Showing 2 changed files with 132 additions and 64 deletions.
98 changes: 66 additions & 32 deletions internal/plugin/grpc_provider.go
Expand Up @@ -3,6 +3,7 @@ package plugin
import (
"context"
"errors"
"fmt"
"sync"

"github.com/zclconf/go-cty/cty"
Expand All @@ -11,6 +12,7 @@ import (
"github.com/hashicorp/terraform/internal/logging"
"github.com/hashicorp/terraform/internal/plugin/convert"
"github.com/hashicorp/terraform/internal/providers"
"github.com/hashicorp/terraform/internal/tfdiags"
proto "github.com/hashicorp/terraform/internal/tfplugin5"
ctyjson "github.com/zclconf/go-cty/cty/json"
"github.com/zclconf/go-cty/cty/msgpack"
Expand Down Expand Up @@ -63,9 +65,7 @@ type GRPCProvider struct {
schemas providers.GetProviderSchemaResponse
}

// getSchema is used internally to get the saved provider schema. The schema
// should have already been fetched from the provider, but we have to
// synchronize access to avoid being called concurrently with GetSchema.
// getSchema is used internally to get the cached provider schema
func (p *GRPCProvider) getSchema() providers.GetProviderSchemaResponse {
p.mu.Lock()
// unlock inline in case GetSchema needs to be called
Expand All @@ -75,44 +75,40 @@ func (p *GRPCProvider) getSchema() providers.GetProviderSchemaResponse {
}
p.mu.Unlock()

// the schema should have been fetched already, but give it another shot
// just in case things are being called out of order. This may happen for
// tests.
schemas := p.GetProviderSchema()
if schemas.Diagnostics.HasErrors() {
panic(schemas.Diagnostics.Err())
}

return schemas
return p.GetProviderSchema()
}

// getResourceSchema is a helper to extract the schema for a resource, and
// panics if the schema is not available.
func (p *GRPCProvider) getResourceSchema(name string) providers.Schema {
func (p *GRPCProvider) getResourceSchema(name string) (providers.Schema, tfdiags.Diagnostics) {
schema := p.getSchema()
resSchema, ok := schema.ResourceTypes[name]
if !ok {
panic("unknown resource type " + name)
schema.Diagnostics = schema.Diagnostics.Append(fmt.Errorf("unknown resource type " + name))
}
return resSchema
return resSchema, schema.Diagnostics
}

// gettDatasourceSchema is a helper to extract the schema for a datasource, and
// panics if that schema is not available.
func (p *GRPCProvider) getDatasourceSchema(name string) providers.Schema {
func (p *GRPCProvider) getDatasourceSchema(name string) (providers.Schema, tfdiags.Diagnostics) {
schema := p.getSchema()
if schema.Diagnostics.HasErrors() {
return providers.Schema{}, schema.Diagnostics
}

dataSchema, ok := schema.DataSources[name]
if !ok {
panic("unknown data source " + name)
schema.Diagnostics = schema.Diagnostics.Append(fmt.Errorf("unknown data source " + name))
}
return dataSchema
return dataSchema, schema.Diagnostics
}

// getProviderMetaSchema is a helper to extract the schema for the meta info
// defined for a provider,
func (p *GRPCProvider) getProviderMetaSchema() providers.Schema {
func (p *GRPCProvider) getProviderMetaSchema() (providers.Schema, tfdiags.Diagnostics) {
schema := p.getSchema()
return schema.ProviderMeta
return schema.ProviderMeta, schema.Diagnostics
}

func (p *GRPCProvider) GetProviderSchema() (resp providers.GetProviderSchemaResponse) {
Expand Down Expand Up @@ -201,7 +197,12 @@ func (p *GRPCProvider) ValidateProviderConfig(r providers.ValidateProviderConfig

func (p *GRPCProvider) ValidateResourceConfig(r providers.ValidateResourceConfigRequest) (resp providers.ValidateResourceConfigResponse) {
logger.Trace("GRPCProvider: ValidateResourceConfig")
resourceSchema := p.getResourceSchema(r.TypeName)

resourceSchema, diags := p.getResourceSchema(r.TypeName)
if diags.HasErrors() {
resp.Diagnostics = resp.Diagnostics.Append(diags)
return resp
}

mp, err := msgpack.Marshal(r.Config, resourceSchema.Block.ImpliedType())
if err != nil {
Expand All @@ -227,7 +228,11 @@ func (p *GRPCProvider) ValidateResourceConfig(r providers.ValidateResourceConfig
func (p *GRPCProvider) ValidateDataResourceConfig(r providers.ValidateDataResourceConfigRequest) (resp providers.ValidateDataResourceConfigResponse) {
logger.Trace("GRPCProvider: ValidateDataResourceConfig")

dataSchema := p.getDatasourceSchema(r.TypeName)
dataSchema, diags := p.getDatasourceSchema(r.TypeName)
if diags.HasErrors() {
resp.Diagnostics = resp.Diagnostics.Append(diags)
return resp
}

mp, err := msgpack.Marshal(r.Config, dataSchema.Block.ImpliedType())
if err != nil {
Expand All @@ -252,7 +257,11 @@ func (p *GRPCProvider) ValidateDataResourceConfig(r providers.ValidateDataResour
func (p *GRPCProvider) UpgradeResourceState(r providers.UpgradeResourceStateRequest) (resp providers.UpgradeResourceStateResponse) {
logger.Trace("GRPCProvider: UpgradeResourceState")

resSchema := p.getResourceSchema(r.TypeName)
resSchema, diags := p.getResourceSchema(r.TypeName)
if diags.HasErrors() {
resp.Diagnostics = resp.Diagnostics.Append(diags)
return resp
}

protoReq := &proto.UpgradeResourceState_Request{
TypeName: r.TypeName,
Expand Down Expand Up @@ -333,8 +342,13 @@ func (p *GRPCProvider) Stop() error {
func (p *GRPCProvider) ReadResource(r providers.ReadResourceRequest) (resp providers.ReadResourceResponse) {
logger.Trace("GRPCProvider: ReadResource")

resSchema := p.getResourceSchema(r.TypeName)
metaSchema := p.getProviderMetaSchema()
resSchema, diags := p.getResourceSchema(r.TypeName)
metaSchema, metaDiags := p.getProviderMetaSchema()
diags = diags.Append(metaDiags)
if diags.HasErrors() {
resp.Diagnostics = resp.Diagnostics.Append(diags)
return resp
}

mp, err := msgpack.Marshal(r.PriorState, resSchema.Block.ImpliedType())
if err != nil {
Expand Down Expand Up @@ -378,8 +392,13 @@ func (p *GRPCProvider) ReadResource(r providers.ReadResourceRequest) (resp provi
func (p *GRPCProvider) PlanResourceChange(r providers.PlanResourceChangeRequest) (resp providers.PlanResourceChangeResponse) {
logger.Trace("GRPCProvider: PlanResourceChange")

resSchema := p.getResourceSchema(r.TypeName)
metaSchema := p.getProviderMetaSchema()
resSchema, diags := p.getResourceSchema(r.TypeName)
metaSchema, metaDiags := p.getProviderMetaSchema()
diags = diags.Append(metaDiags)
if diags.HasErrors() {
resp.Diagnostics = resp.Diagnostics.Append(diags)
return resp
}

priorMP, err := msgpack.Marshal(r.PriorState, resSchema.Block.ImpliedType())
if err != nil {
Expand Down Expand Up @@ -444,8 +463,13 @@ func (p *GRPCProvider) PlanResourceChange(r providers.PlanResourceChangeRequest)
func (p *GRPCProvider) ApplyResourceChange(r providers.ApplyResourceChangeRequest) (resp providers.ApplyResourceChangeResponse) {
logger.Trace("GRPCProvider: ApplyResourceChange")

resSchema := p.getResourceSchema(r.TypeName)
metaSchema := p.getProviderMetaSchema()
resSchema, diags := p.getResourceSchema(r.TypeName)
metaSchema, metaDiags := p.getProviderMetaSchema()
diags = diags.Append(metaDiags)
if diags.HasErrors() {
resp.Diagnostics = resp.Diagnostics.Append(diags)
return resp
}

priorMP, err := msgpack.Marshal(r.PriorState, resSchema.Block.ImpliedType())
if err != nil {
Expand Down Expand Up @@ -522,7 +546,12 @@ func (p *GRPCProvider) ImportResourceState(r providers.ImportResourceStateReques
Private: imported.Private,
}

resSchema := p.getResourceSchema(resource.TypeName)
resSchema, diags := p.getResourceSchema(resource.TypeName)
if diags.HasErrors() {
resp.Diagnostics = resp.Diagnostics.Append(diags)
return resp
}

state, err := decodeDynamicValue(imported.State, resSchema.Block.ImpliedType())
if err != nil {
resp.Diagnostics = resp.Diagnostics.Append(err)
Expand All @@ -538,8 +567,13 @@ func (p *GRPCProvider) ImportResourceState(r providers.ImportResourceStateReques
func (p *GRPCProvider) ReadDataSource(r providers.ReadDataSourceRequest) (resp providers.ReadDataSourceResponse) {
logger.Trace("GRPCProvider: ReadDataSource")

dataSchema := p.getDatasourceSchema(r.TypeName)
metaSchema := p.getProviderMetaSchema()
dataSchema, diags := p.getDatasourceSchema(r.TypeName)
metaSchema, metaDiags := p.getProviderMetaSchema()
diags = diags.Append(metaDiags)
if diags.HasErrors() {
resp.Diagnostics = resp.Diagnostics.Append(diags)
return resp
}

config, err := msgpack.Marshal(r.Config, dataSchema.Block.ImpliedType())
if err != nil {
Expand Down

0 comments on commit b97c640

Please sign in to comment.