diff --git a/internal/server/arg.go b/internal/server/arg.go index 97627c5..784a37f 100644 --- a/internal/server/arg.go +++ b/internal/server/arg.go @@ -7,10 +7,12 @@ import ( "github.com/hashicorp/go-argmapper" ) -func callResourceFactory(f *argmapper.Func, provider Provider, typeName string, converter *argmapper.Func) argmapper.Result { - opts := []argmapper.Arg{ - argmapper.Named("typeName", typeName), +// TypeName represents the resource / data source type name as passed by Terraform +type TypeName string +func callResourceFactory(f *argmapper.Func, provider Provider, typeName TypeName, converter *argmapper.Func) argmapper.Result { + opts := []argmapper.Arg{ + argmapper.Typed(typeName), argmapper.Typed(provider), argmapper.ConverterFunc(converter), diff --git a/internal/server/server.go b/internal/server/server.go index 48ab810..74611bb 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -28,8 +28,8 @@ func MustNew(providerFactoryFunc interface{}) *Server { func New(providerFactoryFunc interface{}) (*Server, error) { s := &Server{ - dsf: map[string]*argmapper.Func{}, - rf: map[string]*argmapper.Func{}, + dsf: map[TypeName]*argmapper.Func{}, + rf: map[TypeName]*argmapper.Func{}, } f, err := argmapper.NewFunc(func(p Provider) { @@ -54,8 +54,8 @@ func New(providerFactoryFunc interface{}) (*Server, error) { type Server struct { p Provider - dsf map[string]*argmapper.Func - rf map[string]*argmapper.Func + dsf map[TypeName]*argmapper.Func + rf map[TypeName]*argmapper.Func } func assertValidFactory(fn *argmapper.Func, target reflect.Type) error { @@ -72,14 +72,14 @@ func assertValidFactory(fn *argmapper.Func, target reflect.Type) error { return nil } -func (s *Server) MustRegisterDataSource(typeName string, factory interface{}) { +func (s *Server) MustRegisterDataSource(typeName TypeName, factory interface{}) { err := s.RegisterDataSource(typeName, factory) if err != nil { panic(err) } } -func (s *Server) RegisterDataSource(typeName string, factory interface{}) error { +func (s *Server) RegisterDataSource(typeName TypeName, factory interface{}) error { f, err := argmapper.NewFunc( factory, ) @@ -103,7 +103,7 @@ func (s *Server) RegisterDataSource(typeName string, factory interface{}) error return nil } -func (s *Server) dataSource(typeName string) (DataSource, error) { +func (s *Server) dataSource(typeName TypeName) (DataSource, error) { conv, ok := s.dsf[typeName] if !ok { return nil, fmt.Errorf("unable to find %q", typeName) @@ -124,14 +124,14 @@ func (s *Server) dataSource(typeName string) (DataSource, error) { return ds, nil } -func (s *Server) MustRegisterResource(typeName string, fn interface{}) { +func (s *Server) MustRegisterResource(typeName TypeName, fn interface{}) { err := s.RegisterResource(typeName, fn) if err != nil { panic(err) } } -func (s *Server) RegisterResource(typeName string, fn interface{}) error { +func (s *Server) RegisterResource(typeName TypeName, fn interface{}) error { f, err := argmapper.NewFunc(fn) if err != nil { return err @@ -153,7 +153,7 @@ func (s *Server) RegisterResource(typeName string, fn interface{}) error { return nil } -func (s *Server) resource(typeName string) (Resource, error) { +func (s *Server) resource(typeName TypeName) (Resource, error) { conv, ok := s.rf[typeName] if !ok { return nil, fmt.Errorf("unable to find %q", typeName) @@ -185,7 +185,7 @@ func (s *Server) GetProviderSchema(ctx context.Context, req *tfprotov5.GetProvid if err != nil { return nil, err } - resp.DataSourceSchemas[typeName] = ds.Schema(ctx) + resp.DataSourceSchemas[string(typeName)] = ds.Schema(ctx) } for typeName := range s.rf { @@ -193,7 +193,7 @@ func (s *Server) GetProviderSchema(ctx context.Context, req *tfprotov5.GetProvid if err != nil { return nil, err } - resp.ResourceSchemas[typeName] = r.Schema(ctx) + resp.ResourceSchemas[string(typeName)] = r.Schema(ctx) } return resp, nil @@ -265,7 +265,7 @@ func (s *Server) StopProvider(ctx context.Context, req *tfprotov5.StopProviderRe // ResourceServer methods func (s *Server) ValidateResourceTypeConfig(ctx context.Context, req *tfprotov5.ValidateResourceTypeConfigRequest) (*tfprotov5.ValidateResourceTypeConfigResponse, error) { - r, err := s.resource(req.TypeName) + r, err := s.resource(TypeName(req.TypeName)) if err != nil { return nil, err } @@ -287,7 +287,7 @@ func (s *Server) ValidateResourceTypeConfig(ctx context.Context, req *tfprotov5. } func (s *Server) UpgradeResourceState(ctx context.Context, req *tfprotov5.UpgradeResourceStateRequest) (*tfprotov5.UpgradeResourceStateResponse, error) { - r, err := s.resource(req.TypeName) + r, err := s.resource(TypeName(req.TypeName)) if err != nil { return nil, err } @@ -310,7 +310,7 @@ func (s *Server) UpgradeResourceState(ctx context.Context, req *tfprotov5.Upgrad } func (s *Server) ReadResource(ctx context.Context, req *tfprotov5.ReadResourceRequest) (*tfprotov5.ReadResourceResponse, error) { - r, err := s.resource(req.TypeName) + r, err := s.resource(TypeName(req.TypeName)) if err != nil { return nil, err } @@ -345,7 +345,7 @@ func (s *Server) ReadResource(ctx context.Context, req *tfprotov5.ReadResourceRe } func (s *Server) PlanResourceChange(ctx context.Context, req *tfprotov5.PlanResourceChangeRequest) (*tfprotov5.PlanResourceChangeResponse, error) { - r, err := s.resource(req.TypeName) + r, err := s.resource(TypeName(req.TypeName)) if err != nil { return nil, err } @@ -424,7 +424,7 @@ func (s *Server) PlanResourceChange(ctx context.Context, req *tfprotov5.PlanReso } func (s *Server) ApplyResourceChange(ctx context.Context, req *tfprotov5.ApplyResourceChangeRequest) (*tfprotov5.ApplyResourceChangeResponse, error) { - r, err := s.resource(req.TypeName) + r, err := s.resource(TypeName(req.TypeName)) if err != nil { return nil, err } @@ -522,7 +522,7 @@ func (s *Server) ImportResourceState(ctx context.Context, req *tfprotov5.ImportR // DataSourceServer methods func (s *Server) ValidateDataSourceConfig(ctx context.Context, req *tfprotov5.ValidateDataSourceConfigRequest) (*tfprotov5.ValidateDataSourceConfigResponse, error) { - ds, err := s.dataSource(req.TypeName) + ds, err := s.dataSource(TypeName(req.TypeName)) if err != nil { return nil, err } @@ -545,7 +545,7 @@ func (s *Server) ValidateDataSourceConfig(ctx context.Context, req *tfprotov5.Va } func (s *Server) ReadDataSource(ctx context.Context, req *tfprotov5.ReadDataSourceRequest) (*tfprotov5.ReadDataSourceResponse, error) { - ds, err := s.dataSource(req.TypeName) + ds, err := s.dataSource(TypeName(req.TypeName)) if err != nil { return nil, err }