Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor(sdkv3): migrate ssm client (WIP) #6137

Draft
wants to merge 18 commits into
base: feature/sdkv3
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1,879 changes: 1,409 additions & 470 deletions package-lock.json

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions package.json
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@
"webpack-merge": "^5.10.0"
},
"dependencies": {
"@aws-sdk/client-ssm": "^3.699.0",
"@aws-sdk/protocol-http": "^3.370.0",
"@types/node": "^22.7.5",
"vscode-nls": "^5.2.0",
"vscode-nls-dev": "^4.0.4"
Expand Down
32 changes: 16 additions & 16 deletions packages/core/src/awsService/ec2/model.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,12 @@
* SPDX-License-Identifier: Apache-2.0
*/
import * as vscode from 'vscode'
import { Session } from 'aws-sdk/clients/ssm'
import { EC2, IAM, SSM } from 'aws-sdk'
import { EC2, IAM } from 'aws-sdk'
import { Ec2Selection } from './prompter'
import { getOrInstallCli } from '../../shared/utilities/cliUtils'
import { isCloud9 } from '../../shared/extensionUtilities'
import { ToolkitError } from '../../shared/errors'
import { SsmClient } from '../../shared/clients/ssmClient'
import { SSMWrapper } from '../../shared/clients/ssm'
import { Ec2Client } from '../../shared/clients/ec2Client'
import {
VscodeRemoteConnection,
Expand All @@ -29,17 +28,18 @@ import { SshConfig } from '../../shared/sshConfig'
import { SshKeyPair } from './sshKeyPair'
import { Ec2SessionTracker } from './remoteSessionManager'
import { getEc2SsmEnv } from './utils'
import { Session, StartSessionCommandOutput } from '@aws-sdk/client-ssm'

export type Ec2ConnectErrorCode = 'EC2SSMStatus' | 'EC2SSMPermission' | 'EC2SSMConnect' | 'EC2SSMAgentStatus'

export interface Ec2RemoteEnv extends VscodeRemoteConnection {
selection: Ec2Selection
keyPair: SshKeyPair
ssmSession: SSM.StartSessionResponse
ssmSession: StartSessionCommandOutput
}

export class Ec2Connecter implements vscode.Disposable {
protected ssmClient: SsmClient
protected ssm: SSMWrapper
protected ec2Client: Ec2Client
protected iamClient: DefaultIamClient
protected sessionManager: Ec2SessionTracker
Expand All @@ -53,14 +53,14 @@ export class Ec2Connecter implements vscode.Disposable {
)

public constructor(readonly regionCode: string) {
this.ssmClient = this.createSsmSdkClient()
this.ssm = this.createSsmSdkClient()
this.ec2Client = this.createEc2SdkClient()
this.iamClient = this.createIamSdkClient()
this.sessionManager = new Ec2SessionTracker(regionCode, this.ssmClient)
this.sessionManager = new Ec2SessionTracker(regionCode, this.ssm)
}

protected createSsmSdkClient(): SsmClient {
return new SsmClient(this.regionCode)
protected createSsmSdkClient(): SSMWrapper {
return new SSMWrapper(this.regionCode)
}

protected createEc2SdkClient(): Ec2Client {
Expand All @@ -71,7 +71,7 @@ export class Ec2Connecter implements vscode.Disposable {
return new DefaultIamClient(this.regionCode)
}

public async addActiveSession(sessionId: SSM.SessionId, instanceId: EC2.InstanceId): Promise<void> {
public async addActiveSession(sessionId: string, instanceId: EC2.InstanceId): Promise<void> {
await this.sessionManager.addSession(instanceId, sessionId)
}

Expand Down Expand Up @@ -139,7 +139,7 @@ export class Ec2Connecter implements vscode.Disposable {
}

private async checkForInstanceSsmError(selection: Ec2Selection): Promise<void> {
const isSsmAgentRunning = (await this.ssmClient.getInstanceAgentPingStatus(selection.instanceId)) === 'Online'
const isSsmAgentRunning = (await this.ssm.getInstanceAgentPingStatus(selection.instanceId)) === 'Online'

if (!isSsmAgentRunning) {
this.throwConnectionError('Is SSM Agent running on the target instance?', selection, {
Expand Down Expand Up @@ -173,15 +173,15 @@ export class Ec2Connecter implements vscode.Disposable {
shellArgs: shellArgs,
}

await openRemoteTerminal(terminalOptions, () => this.ssmClient.terminateSession(session)).catch((err) => {
await openRemoteTerminal(terminalOptions, () => this.ssm.terminateSession(session)).catch((err) => {
throw ToolkitError.chain(err, 'Failed to open ec2 instance.')
})
}

public async attemptToOpenEc2Terminal(selection: Ec2Selection): Promise<void> {
await this.checkForStartSessionError(selection)
try {
const response = await this.ssmClient.startSession(selection.instanceId)
const response = await this.ssm.startSession(selection.instanceId)
await this.openSessionInTerminal(response, selection)
} catch (err: unknown) {
this.throwGeneralConnectionError(selection, err as Error)
Expand Down Expand Up @@ -222,7 +222,7 @@ export class Ec2Connecter implements vscode.Disposable {

throw err
}
const ssmSession = await this.ssmClient.startSession(selection.instanceId, 'AWS-StartSSHSession')
const ssmSession = await this.ssm.startSession(selection.instanceId, 'AWS-StartSSHSession')
await this.addActiveSession(selection.instanceId, ssmSession.SessionId!)

const vars = getEc2SsmEnv(selection, ssm, ssmSession)
Expand Down Expand Up @@ -270,13 +270,13 @@ export class Ec2Connecter implements vscode.Disposable {
const command = `echo "${sshPubKey}" > ${remoteAuthorizedKeysPaths}`
const documentName = 'AWS-RunShellScript'

await this.ssmClient.sendCommandAndWait(selection.instanceId, documentName, {
await this.ssm.sendCommandAndWait(selection.instanceId, documentName, {
commands: [command],
})
}

public async getRemoteUser(instanceId: string) {
const osName = await this.ssmClient.getTargetPlatformName(instanceId)
const osName = await this.ssm.getTargetPlatformName(instanceId)
if (osName === 'Amazon Linux') {
return 'ec2-user'
}
Expand Down
14 changes: 7 additions & 7 deletions packages/core/src/awsService/ec2/remoteSessionManager.ts
Original file line number Diff line number Diff line change
Expand Up @@ -3,30 +3,30 @@
* SPDX-License-Identifier: Apache-2.0
*/

import { EC2, SSM } from 'aws-sdk'
import { SsmClient } from '../../shared/clients/ssmClient'
import { EC2 } from 'aws-sdk'
import { SSMWrapper } from '../../shared/clients/ssm'
import { Disposable } from 'vscode'

export class Ec2SessionTracker extends Map<EC2.InstanceId, SSM.SessionId> implements Disposable {
export class Ec2SessionTracker extends Map<EC2.InstanceId, string> implements Disposable {
public constructor(
readonly regionCode: string,
protected ssmClient: SsmClient
protected ssm: SSMWrapper
) {
super()
}

public async addSession(instanceId: EC2.InstanceId, sessionId: SSM.SessionId): Promise<void> {
public async addSession(instanceId: EC2.InstanceId, sessionId: string): Promise<void> {
if (this.isConnectedTo(instanceId)) {
const existingSessionId = this.get(instanceId)!
await this.ssmClient.terminateSessionFromId(existingSessionId)
await this.ssm.terminateSessionFromId(existingSessionId)
this.set(instanceId, sessionId)
} else {
this.set(instanceId, sessionId)
}
}

private async disconnectEnv(instanceId: EC2.InstanceId): Promise<void> {
await this.ssmClient.terminateSessionFromId(this.get(instanceId)!)
await this.ssm.terminateSessionFromId(this.get(instanceId)!)
this.delete(instanceId)
}

Expand Down
2 changes: 2 additions & 0 deletions packages/core/src/extension.ts
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ import { registerCommands } from './commands'
import endpoints from '../resources/endpoints.json'
import { getLogger, maybeShowMinVscodeWarning, setupUninstallHandler } from './shared'
import { showViewLogsMessage } from './shared/utilities/messages'
import { AWSClientBuilderV3 } from './shared/awsClientBuilderV3'

disableAwsSdkWarning()

Expand Down Expand Up @@ -116,6 +117,7 @@ export async function activateCommon(
globals.machineId = await getMachineId()
globals.awsContext = new DefaultAwsContext()
globals.sdkClientBuilder = new DefaultAWSClientBuilder(globals.awsContext)
globals.sdkClientBuilderV3 = new AWSClientBuilderV3(globals.awsContext)
globals.loginManager = new LoginManager(globals.awsContext, new CredentialsStore())

// order matters here
Expand Down
143 changes: 143 additions & 0 deletions packages/core/src/shared/awsClientBuilderV3.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
/*!
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0
*/

import { CredentialsShim } from '../auth/deprecated/loginManager'
import { AwsContext } from './awsContext'
import { AwsCredentialIdentityProvider, RetryStrategyV2 } from '@smithy/types'
import { getUserAgent } from './telemetry/util'
import { DevSettings } from './settings'
import {
DeserializeHandler,
DeserializeHandlerOptions,
DeserializeMiddleware,
HandlerExecutionContext,
Provider,
RetryStrategy,
UserAgent,
} from '@aws-sdk/types'
import { HttpResponse } from '@aws-sdk/protocol-http'
import { ConfiguredRetryStrategy } from '@smithy/util-retry'
import { telemetry } from './telemetry'
import { getRequestId, getTelemetryReason, getTelemetryReasonDesc, getTelemetryResult } from './errors'
import { extensionVersion } from '.'
import { getLogger } from './logger'
import { omitIfPresent } from './utilities/tsUtils'

export type AwsClientConstructor<C> = new (o: AwsClientOptions) => C

export interface AwsClient {
middlewareStack: any // Ideally this would extends MiddlewareStack<Input, Output>, but this causes issues on client construction.
send: any
destroy: () => void
}

interface AwsConfigOptions {
credentials: AwsCredentialIdentityProvider
region: string | Provider<string>
customUserAgent: UserAgent
requestHandler: any
apiVersion: string
endpoint: string
retryStrategy: RetryStrategy | RetryStrategyV2
}
export type AwsClientOptions = AwsConfigOptions

export class AWSClientBuilderV3 {
public constructor(private readonly context: AwsContext) {}

private getShim(): CredentialsShim {
const shim = this.context.credentialsShim
if (!shim) {
throw new Error('Toolkit is not logged-in.')
}
return shim
}

public async createAwsService<C extends AwsClient>(
type: AwsClientConstructor<C>,
options?: Partial<AwsClientOptions>,
region?: string,
userAgent: boolean = true,
settings?: DevSettings
): Promise<C> {
const shim = this.getShim()
const opt = (options ?? {}) as AwsClientOptions

if (!opt.region && region) {
opt.region = region
}

if (!opt.customUserAgent && userAgent) {
opt.customUserAgent = [[getUserAgent({ includePlatform: true, includeClientId: true }), extensionVersion]]
}

if (!opt.retryStrategy) {
// Simple exponential backoff strategy as default.
opt.retryStrategy = new ConfiguredRetryStrategy(5, (attempt: number) => 1000 * 2 ** attempt)
}
// TODO: add tests for refresh logic.
opt.credentials = async () => {
const creds = await shim.get()
if (creds.expiration && creds.expiration.getTime() < Date.now()) {
return shim.refresh()
}
return creds
}

const service = new type(opt)
// TODO: add middleware for logging, telemetry, endpoints.
service.middlewareStack.add(telemetryMiddleware, { step: 'deserialize' } as DeserializeHandlerOptions)
return service
}
}

export function getServiceId(context: { clientName?: string; commandName?: string }): string {
return context.clientName?.toLowerCase().replace(/client$/, '') ?? 'unknown-service'
}

/**
* Record request IDs to the current context, potentially overriding the field if
* multiple API calls are made in the same context. We only do failures as successes are generally uninteresting and noisy.
*/
export function recordErrorTelemetry(err: Error, serviceName?: string) {
telemetry.record({
requestId: getRequestId(err),
requestServiceType: serviceName,
reasonDesc: getTelemetryReasonDesc(err),
reason: getTelemetryReason(err),
result: getTelemetryResult(err),
})
}

function logAndThrow(e: any, serviceId: string, errorMessageAppend: string): never {
if (e instanceof Error) {
recordErrorTelemetry(e, serviceId)
const err = { ...e }
delete err['stack']
getLogger().error('API Response %s: %O', errorMessageAppend, err)
}
throw e
}
/**
* Telemetry logic to be added to all created clients. Adds logging and emitting metric on errors.
*/

const telemetryMiddleware: DeserializeMiddleware<any, any> =
(next: DeserializeHandler<any, any>, context: HandlerExecutionContext) => async (args: any) => {
if (!HttpResponse.isInstance(args.request)) {
return next(args)
}
const serviceId = getServiceId(context as object)
const { hostname, path } = args.request
const logTail = `(${hostname} ${path})`
const result = await next(args).catch((e: any) => logAndThrow(e, serviceId, logTail))
if (HttpResponse.isInstance(result.response)) {
// TODO: omit credentials / sensitive info from the logs / telemetry.
const output = omitIfPresent(result.output, [])
getLogger().debug('API Response %s: %O', logTail, output)
}

return result
}
Loading
Loading