Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add: IAM RDS auth to sql components and refactor #168

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,14 @@ require (
github.com/aws/aws-sdk-go-v2/config v1.26.6
github.com/aws/aws-sdk-go-v2/credentials v1.16.16
github.com/aws/aws-sdk-go-v2/feature/dynamodb/expression v1.6.17
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.4.23
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.15.15
github.com/aws/aws-sdk-go-v2/service/cloudwatch v1.32.2
github.com/aws/aws-sdk-go-v2/service/dynamodb v1.27.1
github.com/aws/aws-sdk-go-v2/service/firehose v1.24.0
github.com/aws/aws-sdk-go-v2/service/kinesis v1.24.7
github.com/aws/aws-sdk-go-v2/service/lambda v1.50.0
github.com/aws/aws-sdk-go-v2/service/rds v1.89.2
github.com/aws/aws-sdk-go-v2/service/s3 v1.48.1
github.com/aws/aws-sdk-go-v2/service/sns v1.27.0
github.com/aws/aws-sdk-go-v2/service/sqs v1.29.7
Expand Down Expand Up @@ -200,10 +202,10 @@ require (
github.com/aws/aws-sdk-go-v2/internal/ini v1.7.3 // indirect
github.com/aws/aws-sdk-go-v2/internal/v4a v1.2.10 // indirect
github.com/aws/aws-sdk-go-v2/service/dynamodbstreams v1.18.7 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.2.10 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.8.11 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 // indirect
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.10 // indirect
github.com/aws/aws-sdk-go-v2/service/secretsmanager v1.34.5
github.com/aws/aws-sdk-go-v2/service/sso v1.18.7 // indirect
Expand Down
12 changes: 8 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -843,6 +843,8 @@ github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.12.3/go.mod h1:uk1vhHHERfSVCUnq
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.4/go.mod h1:t4i+yGHMCcUNIX1x7YVYa6bH/Do7civ5I6cG/6PMfyA=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11 h1:c5I5iH+DZcH3xOIMlz3/tCKJDaHFwYEmxvlh2fAcFo8=
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.14.11/go.mod h1:cRrYDYAMUohBJUtUnOhydaMHtiK/1NZ0Otc9lIb6O0Y=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.4.23 h1:B2qK61ZXCQu8tkD6eG/gUiIt9Vw9tmWFD7Xo02JPdMY=
github.com/aws/aws-sdk-go-v2/feature/rds/auth v1.4.23/go.mod h1:02rz9vMZsrOX9IwUcpoGZM4jPprFNPmtD6t9Ume9ECY=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.11.3/go.mod h1:0dHuD2HZZSiwfJSy1FO5bX1hQ1TxVV1QXXjpn3XUE44=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.14.0/go.mod h1:UcgIwJ9KHquYxs6Q5skC9qXjhYMK+JASDYcXQ4X7JZE=
github.com/aws/aws-sdk-go-v2/feature/s3/manager v1.15.15 h1:2MUXyGW6dVaQz6aqycpbdLIH1NMcUI6kW6vQ0RabGYg=
Expand Down Expand Up @@ -872,8 +874,8 @@ github.com/aws/aws-sdk-go-v2/service/firehose v1.24.0 h1:U3F5oeq3Lp1jv9ebLHNr1OS
github.com/aws/aws-sdk-go-v2/service/firehose v1.24.0/go.mod h1:vHumFD15AwENJSM3SsWzcPpMK24s/7vGN1Xp5rLguz0=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.9.1/go.mod h1:GeUru+8VzrTXV/83XyMJ80KpH8xO89VPoUileyNQ+tc=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.1/go.mod h1:l9ymW25HOqymeU2m1gbUQ3rUIsTwKs8gYHXkqDQUhiI=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4 h1:/b31bi3YVNlkzkBrm9LfpaKoaYZUxIAj4sHfOTmLfqw=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.10.4/go.mod h1:2aGXHFmbInwgP9ZfpmdIfOELL79zhdNYNmReK8qDfdQ=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0 h1:TToQNkvGguu209puTojY/ozlqy2d/SFNcoLIqTFi42g=
github.com/aws/aws-sdk-go-v2/service/internal/accept-encoding v1.12.0/go.mod h1:0jp+ltwkf+SwG2fm/PKo8t4y8pJSgOCO4D8Lz3k0aHQ=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.1.3/go.mod h1:Seb8KNmD6kVTjwRjVEgOT5hPin6sq+v4C2ycJQDwuH8=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.2.3/go.mod h1:R+/S1O4TYpcktbVwddeOYg+uwUfLhADP2S/x4QwsCTM=
github.com/aws/aws-sdk-go-v2/service/internal/checksum v1.2.10 h1:L0ai8WICYHozIKK+OtPzVJBugL7culcuM4E4JOpIEm8=
Expand All @@ -882,8 +884,8 @@ github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.8.11 h1:e9AV
github.com/aws/aws-sdk-go-v2/service/internal/endpoint-discovery v1.8.11/go.mod h1:B90ZQJa36xo0ph9HsoteI1+r8owgQH/U1QNfqZQkj1Q=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.9.3/go.mod h1:wlY6SVjuwvh3TVRpTqdy4I1JpBFLX4UGeKZdWntaocw=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.3/go.mod h1:Owv1I59vaghv1Ax8zz8ELY8DN7/Y0rGS+WWAmjgi950=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10 h1:DBYTXwIGQSGs9w4jKm60F5dmCQ3EEruxdc0MFh+3EY4=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.10.10/go.mod h1:wohMUQiFdzo0NtxbBg0mSRGZ4vL3n0dKjLTINdcIino=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4 h1:tHxQi/XHPK0ctd/wdOw0t7Xrc2OxcRCnVzv8lwWPu0c=
github.com/aws/aws-sdk-go-v2/service/internal/presigned-url v1.12.4/go.mod h1:4GQbF1vJzG60poZqWatZlhP31y8PGCCVTvIGPdaaYJ0=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.13.3/go.mod h1:Bm/v2IaN6rZ+Op7zX+bOUMdL4fsrYZiD0dsjLhNKwZc=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.3/go.mod h1:KZgs2ny8HsxRIRbDwgvJcHHBZPOzQr/+NtGwnP+w2ec=
github.com/aws/aws-sdk-go-v2/service/internal/s3shared v1.16.10 h1:KOxnQeWy5sXyS37fdKEvAsGHOr9fa/qvwxfJurR/BzE=
Expand All @@ -893,6 +895,8 @@ github.com/aws/aws-sdk-go-v2/service/kinesis v1.24.7/go.mod h1:xOJOknNQF6owzT/d+
github.com/aws/aws-sdk-go-v2/service/kms v1.16.3/go.mod h1:QuiHPBqlOFCi4LqdSskYYAWpQlx3PKmohy+rE2F+o5g=
github.com/aws/aws-sdk-go-v2/service/lambda v1.50.0 h1:fBJs+X3ZOEqpmiSb7as6DBqm7K2RTkbaxYL9RBGCZyE=
github.com/aws/aws-sdk-go-v2/service/lambda v1.50.0/go.mod h1:yEO3Ejj0qBhdIDlRYQ8O9+gB5CAUKyaYYiFBkvGX8ZA=
github.com/aws/aws-sdk-go-v2/service/rds v1.89.2 h1:6Z8uAqPcfS2FkXJCAbiRv1I6ZGV9qt4U7mlkzsLHDuA=
github.com/aws/aws-sdk-go-v2/service/rds v1.89.2/go.mod h1:NVSftCz6GNgqRJrlZIlihCTih9PYcDfI1C34NImX59c=
github.com/aws/aws-sdk-go-v2/service/s3 v1.26.3/go.mod h1:g1qvDuRsJY+XghsV6zg00Z4KJ7DtFFCx8fJD2a491Ak=
github.com/aws/aws-sdk-go-v2/service/s3 v1.43.0/go.mod h1:NXRKkiRF+erX2hnybnVU660cYT5/KChRD4iUgJ97cI8=
github.com/aws/aws-sdk-go-v2/service/s3 v1.48.1 h1:5XNlsBsEvBZBMO6p82y+sqpWg8j5aBCe+5C2GBFgqBQ=
Expand Down
181 changes: 181 additions & 0 deletions internal/impl/sql/aws/aws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,181 @@
package aws

import (
"context"
"encoding/json"
"errors"
"fmt"
"net/url"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/feature/rds/auth"
"github.com/aws/aws-sdk-go-v2/service/secretsmanager"

baws "github.com/warpstreamlabs/bento/internal/impl/aws"
"github.com/warpstreamlabs/bento/internal/impl/sql"
"github.com/warpstreamlabs/bento/public/service"
)

func init() {
noop := func(dsn, driver string) (password string, err error) {
return "", nil
}

sql.AWSGetCredentialsGeneratorFn = func(conf *service.ParsedConfig) (func(dsn, driver string) (password string, err error), error) {
if !conf.Contains(sql.SqlFieldAWS) {
return noop, nil
}

aConf := conf.Namespace(sql.SqlFieldAWS)
if aConf == nil {
return noop, errors.New("field 'aws' is not present in parsed config")
}

awsConfig, err := baws.GetSession(context.TODO(), aConf)
if err != nil {
return nil, err
}

if aConf.Contains("iam_enabled") {
isIamAuth, err := aConf.FieldBool("iam_enabled")
if err != nil {
return nil, err
}

if isIamAuth {
getCredentials := func(dbEndpoint string, dbUser string) (string, error) {
return buildIamAuthToken(dbEndpoint, dbUser, awsConfig)
}

wrapDsnBuilder := func(dsn, driver string) (string, error) {
return BuildAWSDsnFromIAMCredentials(dsn, driver, getCredentials)
}

return wrapDsnBuilder, nil
}
}

if conf.Contains("secret_name") {
secret, err := aConf.FieldString("secret_name")
if err != nil {
return nil, err
}

if secret != "" {
getCredentials := func() (string, error) {
return getSecretFromAWSSecretManager(secret, awsConfig)
}

wrapDsnBuilder := func(dsn, driver string) (string, error) {
return BuildAWSDsnFromSecret(dsn, driver, getCredentials)
}

return wrapDsnBuilder, nil
}

}

return noop, nil
}
}

//------------------------------------------------------------------------------

func BuildAWSDsnFromSecret(dsn, driver string, getAWSCredentialsFromSecret func() (string, error)) (string, error) {
if driver != "postgres" {
return "", errors.New("secret_name with DSN info currently only works for postgres DSNs")
}

parsedDSN, err := url.Parse(dsn)
if err != nil {
return "", fmt.Errorf("error parsing DSN URL: %w", err)
}

username := parsedDSN.User.Username()
password, _ := parsedDSN.User.Password()
host := parsedDSN.Hostname()
port := parsedDSN.Port()
path := parsedDSN.Path
rawQuery := parsedDSN.RawQuery

secretString, err := getAWSCredentialsFromSecret()
if err != nil {
return "", fmt.Errorf("error retrieving secret: %w", err)
}

var secrets map[string]interface{}
if err := json.Unmarshal([]byte(secretString), &secrets); err != nil {
return "", fmt.Errorf("error unmarshalling secret: %w", err)
}

if val, ok := secrets["username"].(string); ok && val != "" {
username = val
}
if val, ok := secrets["password"].(string); ok && val != "" {
password = val
}

newDSN := fmt.Sprintf("%s://%s:%s@%s:%s%s", driver, url.QueryEscape(username), url.QueryEscape(password), host, port, path)
if rawQuery != "" {
newDSN = fmt.Sprintf("%s?%s", newDSN, rawQuery)
}

return newDSN, nil
}

func BuildAWSDsnFromIAMCredentials(dsn string, driver string, generateIAMAuthToken func(dbEndpoint string, dbUser string) (string, error)) (string, error) {
if driver != "postgres" && driver != "mysql" {
return "", errors.New("cannot create DSN from IAM when driver is not postgres or mysql")
}

parsedDSN, err := url.Parse(dsn)
if err != nil {
return "", fmt.Errorf("error parsing DSN URL: %w", err)
}

username := parsedDSN.User.Username()
host := parsedDSN.Hostname()
port := parsedDSN.Port()
path := parsedDSN.Path
rawQuery := parsedDSN.RawQuery
endpoint := fmt.Sprintf("%s:%s", host, port)
iamToken, err := generateIAMAuthToken(endpoint, username)
if err != nil {
return "", fmt.Errorf("error retrieving IAM token: %w", err)
}

newDSN := fmt.Sprintf("%s://%s:%s@%s:%s%s", driver, url.QueryEscape(username), url.QueryEscape(iamToken), host, port, path)
if rawQuery != "" {
newDSN = fmt.Sprintf("%s?%s", newDSN, rawQuery)
}

return newDSN, nil

}

//------------------------------------------------------------------------------

func getSecretFromAWSSecretManager(secretName string, awsConf aws.Config) (string, error) {
svc := secretsmanager.NewFromConfig(awsConf)

input := &secretsmanager.GetSecretValueInput{
SecretId: aws.String(secretName),
}
result, err := svc.GetSecretValue(context.TODO(), input)
if err != nil {
return "", err
}

return *result.SecretString, nil
}

func buildIamAuthToken(dbEndpoint, dbUser string, awsConf aws.Config) (string, error) {
authenticationToken, err := auth.BuildAuthToken(
context.TODO(), dbEndpoint, awsConf.Region, dbUser, awsConf.Credentials,
)
if err != nil {
return "", err
}

return authenticationToken, nil
}
2 changes: 1 addition & 1 deletion internal/impl/sql/cache_sql.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ func newSQLCacheFromConfig(conf *service.ParsedConfig, mgr *service.Resources) (
return nil, err
}

if s.db, err = sqlOpenWithReworks(context.Background(), s.logger, s.driver, s.dsn, connSettings, s.awsConf); err != nil {
if s.db, err = sqlOpenWithReworks(context.Background(), s.logger, s.driver, s.dsn, connSettings); err != nil {
return nil, err
}
connSettings.apply(context.Background(), s.db, s.logger)
Expand Down
23 changes: 23 additions & 0 deletions internal/impl/sql/conn_aws.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
package sql

import (
"errors"

"github.com/warpstreamlabs/bento/public/service"
)

const (
SqlFieldAWS = "aws"

Check failure on line 10 in internal/impl/sql/conn_aws.go

View workflow job for this annotation

GitHub Actions / golangci-lint

ST1003: const SqlFieldAWS should be SQLFieldAWS (stylecheck)
)

// AWSGetCredentialsGeneratorFn is populated with the child `aws` package when imported.
var AWSGetCredentialsGeneratorFn = func(c *service.ParsedConfig) (fn func(dsn, driver string) (password string, err error), err error) {
if c.Contains(SqlFieldAWS) {
return nil, errors.New("unable to configure AWS authentication as this binary does not import components/aws")
}
return
}

var BuildAwsDsn = func(dsn, driver string) (password string, err error) {
return dsn, nil
}
Loading
Loading