Skip to content

Commit

Permalink
Add ScopedOauth
Browse files Browse the repository at this point in the history
  • Loading branch information
cjgajard committed Jan 22, 2024
1 parent 36ac80f commit db48975
Show file tree
Hide file tree
Showing 2 changed files with 176 additions and 69 deletions.
114 changes: 103 additions & 11 deletions pagerdutyplugin/config.go
Original file line number Diff line number Diff line change
@@ -1,9 +1,12 @@
package pagerduty

import (
"context"
"fmt"
"log"
"net/http"
"os"
"path/filepath"
"sync"

"github.com/PagerDuty/go-pagerduty"
Expand Down Expand Up @@ -40,12 +43,18 @@ type Config struct {
// Region where the server of the service is deployed
ServiceRegion string

// Parameters for fine-grained access control
AppOauthScopedToken *AppOauthScopedToken

// API wrapper
client *pagerduty.Client
}

const invalidCreds = `
type AppOauthScopedToken struct {
ClientId, ClientSecret, Subdomain string
}

const invalidCreds = `
No valid credentials found for PagerDuty provider.
Please see https://www.terraform.io/docs/providers/pagerduty/index.html
for more information on providing credentials for this provider.
Expand All @@ -61,11 +70,6 @@ func (c *Config) Client() (*pagerduty.Client, error) {
return c.client, nil
}

// Validate that the PagerDuty token is set
if c.Token == "" {
return nil, fmt.Errorf(invalidCreds)
}

httpClient := http.DefaultClient
httpClient.Transport = logging.NewTransport("PagerDuty", http.DefaultTransport)

Expand All @@ -74,13 +78,33 @@ func (c *Config) Client() (*pagerduty.Client, error) {
apiUrl = c.ApiUrlOverride
}

client := pagerduty.NewClient(c.Token, []pagerduty.ClientOptions{
pagerduty.WithAPIEndpoint(apiUrl),
clientOpts := []pagerduty.ClientOptions{
WithHTTPClient(httpClient),
pagerduty.WithAPIEndpoint(apiUrl),
pagerduty.WithTerraformProvider(c.TerraformVersion),
// TODO: c.AppOauthScopedTokenParams
// TODO: c.APITokenType
}...)
}

if c.AppOauthScopedToken != nil {
tokenFile := getTokenFilepath()
account := fmt.Sprintf("as_account-%s.%s", c.ServiceRegion, c.AppOauthScopedToken.Subdomain)
accountAndScopes := []string{account}
accountAndScopes = append(accountAndScopes, availableOauthScopes()...)
opt := pagerduty.WithScopedOAuthAppTokenSource(pagerduty.NewFileTokenSource(
context.Background(),
c.AppOauthScopedToken.ClientId,
c.AppOauthScopedToken.ClientSecret,
accountAndScopes,
tokenFile,
))
clientOpts = append(clientOpts, opt)
}

// Validate that the PagerDuty token is set
if c.Token == "" && c.AppOauthScopedToken == nil {
log.Println("[CG] Stop")
return nil, fmt.Errorf(invalidCreds)
}
client := pagerduty.NewClient(c.Token, clientOpts...)

// TODO: oauth validation
// if !c.SkipCredsValidation {
Expand All @@ -105,6 +129,74 @@ func WithHTTPClient(httpClient pagerduty.HTTPClient) pagerduty.ClientOptions {
}
}

func getTokenFilepath() string {
homeDir, err := os.UserHomeDir()
if err == nil {
homeDir = filepath.Join(homeDir, ".pagerduty")
} else {
homeDir = ""
}
return filepath.Join(homeDir, "token.json")
}

func availableOauthScopes() []string {
return []string{
"abilities.read",
"addons.read",
"addons.write",
"analytics.read",
"audit_records.read",
"change_events.read",
"change_events.write",
"custom_fields.read",
"custom_fields.write",
"escalation_policies.read",
"escalation_policies.write",
"event_orchestrations.read",
"event_orchestrations.write",
"event_rules.read",
"event_rules.write",
"extension_schemas.read",
"extensions.read",
"extensions.write",
"incident_workflows.read",
"incident_workflows.write",
"incident_workflows:instances.write",
"incidents.read",
"incidents.write",
"licenses.read",
"notifications.read",
"oncalls.read",
"priorities.read",
"response_plays.read",
"response_plays.write",
"schedules.read",
"schedules.write",
"services.read",
"services.write",
"standards.read",
"standards.write",
"status_dashboards.read",
"status_pages.read",
"status_pages.write",
"subscribers.read",
"subscribers.write",
"tags.read",
"tags.write",
"teams.read",
"teams.write",
"templates.read",
"templates.write",
"users.read",
"users.write",
"users:contact_methods.read",
"users:contact_methods.write",
"users:sessions.read",
"users:sessions.write",
"vendors.read",
}
}

// ConfigurePagerdutyClient sets a pagerduty API client in a pointer to the
// property of any data source struct from the general configuration of the
// provider.
Expand Down
131 changes: 73 additions & 58 deletions pagerdutyplugin/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,28 +15,20 @@ import (

type Provider struct{}

func (p *Provider) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) {
config, diags := ReadConfig(ctx, req)
if len(diags) > 0 {
resp.Diagnostics.Append(diags...)
return
}

client, err := config.Client()
if err != nil {
resp.Diagnostics.Append(diag.NewErrorDiagnostic(
"Cannot obtain plugin client",
err.Error(),
))
}
resp.DataSourceData = client
}

func (p *Provider) Metadata(ctx context.Context, req provider.MetadataRequest, resp *provider.MetadataResponse) {
resp.TypeName = "pagerduty"
}

func (p *Provider) Schema(ctx context.Context, req provider.SchemaRequest, resp *provider.SchemaResponse) {
useAppOauthScopedTokenBlock := schema.ListNestedBlock{
NestedObject: schema.NestedBlockObject{
Attributes: map[string]schema.Attribute{
"pd_client_id": schema.StringAttribute{Optional: true},
"pd_client_secret": schema.StringAttribute{Optional: true},
"pd_subdomain": schema.StringAttribute{Optional: true},
},
},
}
resp.Schema = schema.Schema{
Attributes: map[string]schema.Attribute{
"api_url_override": schema.StringAttribute{Optional: true},
Expand All @@ -45,26 +37,8 @@ func (p *Provider) Schema(ctx context.Context, req provider.SchemaRequest, resp
"token": schema.StringAttribute{Optional: true},
"user_token": schema.StringAttribute{Optional: true},
},

Blocks: map[string]schema.Block{
"use_app_oauth_scoped_token": schema.ListNestedBlock{
NestedObject: schema.NestedBlockObject{
Attributes: map[string]schema.Attribute{
"pd_client_id": schema.StringAttribute{
Required: true,
// DefaultFunc: schema.EnvDefaultFunc("PAGERDUTY_CLIENT_ID", nil),
},
"pd_client_secret": schema.StringAttribute{
Required: true,
// DefaultFunc: schema.EnvDefaultFunc("PAGERDUTY_CLIENT_SECRET", nil),
},
"pd_subdomain": schema.StringAttribute{
Required: true,
// DefaultFunc: schema.EnvDefaultFunc("PAGERDUTY_SUBDOMAIN", nil),
},
},
},
},
"use_app_oauth_scoped_token": useAppOauthScopedTokenBlock,
},
}
}
Expand All @@ -83,27 +57,20 @@ func New() provider.Provider {
return &Provider{}
}

type providerArguments struct {
Token types.String `tfsdk:"token"`
UserToken types.String `tfsdk:"user_token"`
SkipCredentialsValidation types.Bool `tfsdk:"skip_credentials_validation"`
ServiceRegion types.String `tfsdk:"service_region"`
ApiUrlOverride types.String `tfsdk:"api_url_override"`
UseAppOauthScopedToken *struct {
PdClientId types.String `tfsdk:"pd_client_id"`
PdClientSecret types.String `tfsdk:"pd_client_secret"`
PdDomain types.String `tfsdk:"pd_domain"`
} `tfsdk:"use_app_oauth_scoped_token"`
}

func ReadConfig(ctx context.Context, req provider.ConfigureRequest) (*Config, diag.Diagnostics) {
var diags diag.Diagnostics
func (p *Provider) Configure(ctx context.Context, req provider.ConfigureRequest, resp *provider.ConfigureResponse) {
var args providerArguments
diags.Append(req.Config.Get(ctx, &args)...)
resp.Diagnostics.Append(req.Config.Get(ctx, &args)...)
if resp.Diagnostics.HasError() {
return
}

serviceRegion := args.ServiceRegion.ValueString()
if serviceRegion == "" {
serviceRegion = "us"
}

var regionApiUrl string
if serviceRegion == "us" || serviceRegion == "" {
if serviceRegion == "us" {
regionApiUrl = ""
} else {
regionApiUrl = serviceRegion + "."
Expand All @@ -122,13 +89,61 @@ func ReadConfig(ctx context.Context, req provider.ConfigureRequest) (*Config, di
ServiceRegion: serviceRegion,
}

if config.Token == "" {
config.Token = os.Getenv("PAGERDUTY_TOKEN")
if !args.UseAppOauthScopedToken.IsNull() {
blockList := []UseAppOauthScopedToken{}
resp.Diagnostics.Append(args.UseAppOauthScopedToken.ElementsAs(ctx, &blockList, false)...)
if resp.Diagnostics.HasError() {
return
}
config.AppOauthScopedToken = &AppOauthScopedToken{
ClientId: blockList[0].PdClientId.ValueString(),
ClientSecret: blockList[0].PdClientSecret.ValueString(),
Subdomain: blockList[0].PdSubdomain.ValueString(),
}
}
if config.UserToken == "" {
config.UserToken = os.Getenv("PAGERDUTY_USER_TOKEN")

if args.UseAppOauthScopedToken.IsNull() {
if config.Token == "" {
config.Token = os.Getenv("PAGERDUTY_TOKEN")
}
if config.UserToken == "" {
config.UserToken = os.Getenv("PAGERDUTY_USER_TOKEN")
}
} else {
if config.AppOauthScopedToken.ClientId == "" {
config.AppOauthScopedToken.ClientId = os.Getenv("PAGERDUTY_CLIENT_ID")
}
if config.AppOauthScopedToken.ClientSecret == "" {
config.AppOauthScopedToken.ClientSecret = os.Getenv("PAGERDUTY_CLIENT_SECRET")
}
if config.AppOauthScopedToken.Subdomain == "" {
config.AppOauthScopedToken.Subdomain = os.Getenv("PAGERDUTY_SUBDOMAIN")
}
}

log.Println("[INFO] Initializing PagerDuty plugin client")
return &config, diags

client, err := config.Client()
if err != nil {
resp.Diagnostics.Append(diag.NewErrorDiagnostic(
"Cannot obtain plugin client",
err.Error(),
))
}
resp.DataSourceData = client
}

type UseAppOauthScopedToken struct {
PdClientId types.String `tfsdk:"pd_client_id"`
PdClientSecret types.String `tfsdk:"pd_client_secret"`
PdSubdomain types.String `tfsdk:"pd_subdomain"`
}

type providerArguments struct {
Token types.String `tfsdk:"token"`
UserToken types.String `tfsdk:"user_token"`
SkipCredentialsValidation types.Bool `tfsdk:"skip_credentials_validation"`
ServiceRegion types.String `tfsdk:"service_region"`
ApiUrlOverride types.String `tfsdk:"api_url_override"`
UseAppOauthScopedToken types.List `tfsdk:"use_app_oauth_scoped_token"`
}

0 comments on commit db48975

Please sign in to comment.