Skip to content

Commit

Permalink
Switch to AWS SDK v2 (#127)
Browse files Browse the repository at this point in the history
* Switch to AWS SDK v2

As per https://pkg.go.dev/github.com/aws/aws-sdk-go-v2/aws/retry#hdr-Standard , one of this SDKs improvements should be automatic retry in the case of throttling limit errors

* small fixes

* fix some pointer handling

* Pass-around TODO contexts correctly

* Always call cancel on failure

* Take a stab at using contexts for timeouts

Rather than using the self-implemented timeouts at all

* exported function

* small cleanups

* Use logrus exit handler
  • Loading branch information
holtwilkins authored Oct 12, 2021
1 parent a93ea93 commit 5986b76
Show file tree
Hide file tree
Showing 1,034 changed files with 371,188 additions and 206,501 deletions.
54 changes: 28 additions & 26 deletions aws/asg.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
package aws

import (
"context"
"strings"
"time"

"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go-v2/service/autoscaling"
at "github.com/aws/aws-sdk-go-v2/service/autoscaling/types"
"github.com/pkg/errors"
)

// GetAllASGs returns all ASGs visible with your client, with no filters
func (c *Clients) GetAllASGs() ([]*autoscaling.Group, error) {
func (c *Clients) GetAllASGs(ctx context.Context) ([]*at.AutoScalingGroup, error) {
var nexttoken *string
var asgs []*autoscaling.Group
var asgs []*at.AutoScalingGroup
var input *autoscaling.DescribeAutoScalingGroupsInput
var err error
var output *autoscaling.DescribeAutoScalingGroupsOutput
Expand All @@ -35,13 +37,13 @@ func (c *Clients) GetAllASGs() ([]*autoscaling.Group, error) {
NextToken: nexttoken,
}

output, err = c.ASGClient.DescribeAutoScalingGroups(input)
output, err = c.ASGClient.DescribeAutoScalingGroups(ctx, input)
if err != nil {
return nil, errors.Wrap(err, "Error describing ASGs")
}

for _, asg := range output.AutoScalingGroups {
asgs = append(asgs, asg)
asgs = append(asgs, &asg)
}
nexttoken = output.NextToken

Expand All @@ -56,34 +58,34 @@ func (c *Clients) GetAllASGs() ([]*autoscaling.Group, error) {
}

// GetASG gets the *autoscaling.Group that matches for the name given
func (c *Clients) GetASG(asgName *string) (*autoscaling.Group, error) {
var asgs []*autoscaling.Group
func (c *Clients) GetASG(ctx context.Context, asgName string) (*at.AutoScalingGroup, error) {
var asgs []*at.AutoScalingGroup

input := &autoscaling.DescribeAutoScalingGroupsInput{
AutoScalingGroupNames: []*string{asgName},
AutoScalingGroupNames: []string{asgName},
}

output, err := c.ASGClient.DescribeAutoScalingGroups(input)
output, err := c.ASGClient.DescribeAutoScalingGroups(ctx, input)
if err != nil {
return nil, errors.Wrap(err, "Error describing ASGs")
}

for _, asg := range output.AutoScalingGroups {
asgs = append(asgs, asg)
asgs = append(asgs, &asg)
}

if len(asgs) != 1 {
return nil, errors.Errorf("ASG Name '%s' matched '%v' ASGs, expecting it to match 1 (looking in region %s, have you set AWS_DEFAULT_REGION?)", *asgName, len(asgs), c.ASGClient.SigningRegion)
return nil, errors.Errorf("ASG Name '%s' matched '%v' ASGs, expecting it to match 1 (have you set AWS_DEFAULT_REGION?)", asgName, len(asgs))
}

return asgs[0], nil
}

// GetASGTagValue returns a pointer to the value for the given tag key
func GetASGTagValue(asg *autoscaling.Group, key string) *string {
func GetASGTagValue(asg *at.AutoScalingGroup, key string) *string {
for _, tag := range asg.Tags {
if tag != nil {
if strings.ToLower(*tag.Key) == strings.ToLower(key) {
if tag.Key != nil {
if strings.EqualFold(*tag.Key, key) {
return tag.Value
}
}
Expand All @@ -92,21 +94,21 @@ func GetASGTagValue(asg *autoscaling.Group, key string) *string {
}

// GetLaunchConfiguration returns the LC object of the given ASG
func (c *Clients) GetLaunchConfiguration(asg *autoscaling.Group) (*autoscaling.LaunchConfiguration, error) {
var lcs []*string
lcs = append(lcs, asg.LaunchConfigurationName)
func (c *Clients) GetLaunchConfiguration(ctx context.Context, asg *at.AutoScalingGroup) (*at.LaunchConfiguration, error) {
var lcs []string
lcs = append(lcs, *asg.LaunchConfigurationName)
input := autoscaling.DescribeLaunchConfigurationsInput{
LaunchConfigurationNames: lcs,
}
output, err := c.ASGClient.DescribeLaunchConfigurations(&input)
output, err := c.ASGClient.DescribeLaunchConfigurations(ctx, &input)
if err != nil {
return nil, errors.Wrapf(err, "Error describing launch configuration %s for ASG %s", *asg.LaunchConfigurationName, *asg.AutoScalingGroupName)
}
return output.LaunchConfigurations[0], nil
return &output.LaunchConfigurations[0], nil
}

// GetLaunchTemplateSpec returns the LT spec for a given ASG, if it has one
func (c *Clients) GetLaunchTemplateSpec(asg *autoscaling.Group) *autoscaling.LaunchTemplateSpecification {
func (c *Clients) GetLaunchTemplateSpec(asg *at.AutoScalingGroup) *at.LaunchTemplateSpecification {
// First, let's check the direct launch template spec property of the ASG
if asg.LaunchTemplate != nil {
return asg.LaunchTemplate
Expand All @@ -122,34 +124,34 @@ func (c *Clients) GetLaunchTemplateSpec(asg *autoscaling.Group) *autoscaling.Lau
}

// CompleteLifecycleAction calls https://docs.aws.amazon.com/cli/latest/reference/autoscaling/complete-lifecycle-action.html
func (c *Clients) CompleteLifecycleAction(asgName *string, instID *string, lifecycleHook *string, result *string) error {
func (c *Clients) CompleteLifecycleAction(ctx context.Context, asgName *string, instID *string, lifecycleHook *string, result *string) error {
input := autoscaling.CompleteLifecycleActionInput{
AutoScalingGroupName: asgName,
InstanceId: instID,
LifecycleActionResult: result,
LifecycleHookName: lifecycleHook,
}

_, err := c.ASGClient.CompleteLifecycleAction(&input)
_, err := c.ASGClient.CompleteLifecycleAction(ctx, &input)
return errors.Wrapf(err, "error completing lifecycle hook %s for instance %s", *lifecycleHook, *instID)
}

// TerminateInstanceInASG calls https://docs.aws.amazon.com/cli/latest/reference/autoscaling/terminate-instance-in-auto-scaling-group.html
func (c *Clients) TerminateInstanceInASG(instID *string, decrement *bool) error {
func (c *Clients) TerminateInstanceInASG(ctx context.Context, instID *string, decrement *bool) error {
input := autoscaling.TerminateInstanceInAutoScalingGroupInput{
InstanceId: instID,
ShouldDecrementDesiredCapacity: decrement,
}
_, err := c.ASGClient.TerminateInstanceInAutoScalingGroup(&input)
_, err := c.ASGClient.TerminateInstanceInAutoScalingGroup(ctx, &input)
return errors.Wrapf(err, "error terminating instance %s", *instID)
}

// SetDesiredCapacity sets the desired capacity of given ASG to given value
func (c *Clients) SetDesiredCapacity(asg *autoscaling.Group, desiredCapacity *int64) error {
func (c *Clients) SetDesiredCapacity(ctx context.Context, asg *at.AutoScalingGroup, desiredCapacity *int32) error {
input := autoscaling.SetDesiredCapacityInput{
AutoScalingGroupName: asg.AutoScalingGroupName,
DesiredCapacity: desiredCapacity,
}
_, err := c.ASGClient.SetDesiredCapacity(&input)
_, err := c.ASGClient.SetDesiredCapacity(ctx, &input)
return errors.Wrapf(err, "error setting desired capacity for %s", *asg.AutoScalingGroupName)
}
50 changes: 22 additions & 28 deletions aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,46 +15,42 @@
package aws

import (
"context"
"fmt"
"os"
"strconv"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/session"
"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go/service/ec2"
"github.com/aws/aws-sdk-go-v2/config"
"github.com/aws/aws-sdk-go-v2/service/autoscaling"
at "github.com/aws/aws-sdk-go-v2/service/autoscaling/types"
"github.com/aws/aws-sdk-go-v2/service/ec2"
et "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/pkg/errors"
)

const apiSleepTime = 200 * time.Millisecond

// Clients holds the clients for this account's invocation of the APIs we'll need
type Clients struct {
ASGClient *autoscaling.AutoScaling
EC2Client *ec2.EC2
ASGClient *autoscaling.Client
EC2Client *ec2.Client
}

// GetAWSClients returns the AWS client objects we'll need
func GetAWSClients() (*Clients, error) {
func GetAWSClients(ctx context.Context) (*Clients, error) {
region := os.Getenv("AWS_DEFAULT_REGION")
if region == "" {
region = "us-east-1"
}

awsConf := aws.Config{
Region: &region,
}

sess, err := session.NewSessionWithOptions(session.Options{
Config: awsConf,
})
cfg, err := config.LoadDefaultConfig(ctx, config.WithRegion(region))
if err != nil {
return nil, errors.Wrap(err, "Error opening AWS session")
return nil, errors.Wrap(err, "Error opening default AWS config")
}

asg := autoscaling.New(sess)
ec2 := ec2.New(sess)
asg := autoscaling.NewFromConfig(cfg)
ec2 := ec2.NewFromConfig(cfg)

ac := Clients{
ASGClient: asg,
Expand All @@ -65,13 +61,11 @@ func GetAWSClients() (*Clients, error) {
}

// ASGInstToEC2Inst converts a *autoscaling.Instance to its corresponding *ec2.Instance
func (c *Clients) ASGInstToEC2Inst(inst *autoscaling.Instance) (*ec2.Instance, error) {
var instIDs []*string
instIDs = append(instIDs, inst.InstanceId)
func (c *Clients) ASGInstToEC2Inst(ctx context.Context, inst at.Instance) (*et.Instance, error) {
input := ec2.DescribeInstancesInput{
InstanceIds: instIDs,
InstanceIds: []string{*inst.InstanceId},
}
output, err := c.EC2Client.DescribeInstances(&input)
output, err := c.EC2Client.DescribeInstances(ctx, &input)
if err != nil {
return nil, errors.Wrapf(err, "Error describing instance %s", *inst.InstanceId)
}
Expand All @@ -82,27 +76,27 @@ func (c *Clients) ASGInstToEC2Inst(inst *autoscaling.Instance) (*ec2.Instance, e
}

for _, ec2Inst := range res.Instances {
return ec2Inst, nil
return &ec2Inst, nil
}
}

return nil, errors.Errorf("No instances found for %s", *inst.InstanceId)
}

// ASGLTplVersionToEC2LTplVersion resolves ASG Template Versions to its actual *int64 ec2LaunchTemplate Version
func (c Clients) ASGLTplVersionToEC2LTplVersion(asgLaunchTemplate *autoscaling.LaunchTemplateSpecification) (*string, error) {
// ASGLTplVersionToEC2LTplVersion resolves ASG Template Versions to its actual *int32 ec2LaunchTemplate Version
func (c Clients) ASGLTplVersionToEC2LTplVersion(ctx context.Context, asgLaunchTemplate *at.LaunchTemplateSpecification) (*string, error) {
// No launch template, nothing to do here
if asgLaunchTemplate == nil {
return nil, nil
}

input := &ec2.DescribeLaunchTemplatesInput{
LaunchTemplateIds: []*string{
asgLaunchTemplate.LaunchTemplateId,
LaunchTemplateIds: []string{
*asgLaunchTemplate.LaunchTemplateId,
},
}

res, err := c.EC2Client.DescribeLaunchTemplates(input)
res, err := c.EC2Client.DescribeLaunchTemplates(ctx, input)
if err != nil {
return nil, errors.Wrapf(err, "Error describing LaunchTemplate %s", *asgLaunchTemplate.LaunchTemplateId)
}
Expand Down
19 changes: 10 additions & 9 deletions aws/ec2.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,17 +15,19 @@
package aws

import (
"context"
"strings"

"github.com/aws/aws-sdk-go/service/autoscaling"
"github.com/aws/aws-sdk-go/service/ec2"
at "github.com/aws/aws-sdk-go-v2/service/autoscaling/types"
"github.com/aws/aws-sdk-go-v2/service/ec2"
et "github.com/aws/aws-sdk-go-v2/service/ec2/types"
"github.com/pkg/errors"
)

// GetEC2TagValue returns a pointer to the value of the tag with the given key
func GetEC2TagValue(ec2 *ec2.Instance, key string) *string {
func GetEC2TagValue(ec2 *et.Instance, key string) *string {
for _, tag := range ec2.Tags {
if tag != nil {
if tag.Key != nil {
if strings.ToLower(*tag.Key) == strings.ToLower(key) {
return tag.Value
}
Expand All @@ -35,16 +37,15 @@ func GetEC2TagValue(ec2 *ec2.Instance, key string) *string {
}

// GetUserData returns a pointer to the value of the instance's userdata
func (c *Clients) GetUserData(inst *autoscaling.Instance) (*string, error) {
attr := "userData"
func (c *Clients) GetUserData(ctx context.Context, inst *at.Instance) (*string, error) {
input := ec2.DescribeInstanceAttributeInput{
Attribute: &attr,
Attribute: et.InstanceAttributeNameUserData,
InstanceId: inst.InstanceId,
}

output, err := c.EC2Client.DescribeInstanceAttribute(&input)
output, err := c.EC2Client.DescribeInstanceAttribute(ctx, &input)
if err != nil {
return nil, errors.Wrapf(err, "Error describing attribute %s for instance %s", attr, *inst.InstanceId)
return nil, errors.Wrapf(err, "Error describing userdata for instance %s", *inst.InstanceId)
}

return output.UserData.Value, nil
Expand Down
16 changes: 6 additions & 10 deletions bouncer/asg.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,36 +15,32 @@
package bouncer

import (
"context"
"time"

"github.com/aws/aws-sdk-go/service/autoscaling"
at "github.com/aws/aws-sdk-go-v2/service/autoscaling/types"
"github.com/palantir/bouncer/aws"
"github.com/pkg/errors"
)

// ASG object holds a pointer to an ASG and its Instances
type ASG struct {
ASG *autoscaling.Group
ASG *at.AutoScalingGroup
Instances []*Instance
DesiredASG *DesiredASG
}

// NewASG creates a new ASG object
func NewASG(ac *aws.Clients, desASG *DesiredASG, force bool, startTime time.Time) (*ASG, error) {
var awsAsg *autoscaling.Group

err := retry(apiRetryCount, apiRetrySleep, func() (err error) {
awsAsg, err = ac.GetASG(&desASG.AsgName)
return
})
func NewASG(ctx context.Context, ac *aws.Clients, desASG *DesiredASG, force bool, startTime time.Time) (*ASG, error) {
awsAsg, err := ac.GetASG(ctx, desASG.AsgName)
if err != nil {
return nil, errors.Wrap(err, "error getting AWS ASG object")
}

var instances []*Instance

for _, asgInst := range awsAsg.Instances {
inst, err := NewInstance(ac, awsAsg, asgInst, force, startTime, desASG.PreTerminateCmd)
inst, err := NewInstance(ctx, ac, awsAsg, asgInst, force, startTime, desASG.PreTerminateCmd)
if err != nil {
return nil, errors.Wrapf(err, "error generating bouncer.instance for %s", *asgInst.InstanceId)
}
Expand Down
Loading

0 comments on commit 5986b76

Please sign in to comment.