diff --git a/common/lib/AllowedAndBlockedHosts.ts b/common/lib/AllowedAndBlockedHosts.ts new file mode 100644 index 00000000..3f1e4597 --- /dev/null +++ b/common/lib/AllowedAndBlockedHosts.ts @@ -0,0 +1,33 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +export class AllowedAndBlockedHosts { + private readonly allowedHostIds: Set<string>; + private readonly blockedHostIds: Set<string>; + + constructor(allowedHostIds: Set<string>, blockedHostIds: Set<string>) { + this.allowedHostIds = allowedHostIds; + this.blockedHostIds = blockedHostIds; + } + + getAllowedHostIds() { + return this.allowedHostIds; + } + + getBlockedHostIds() { + return this.blockedHostIds; + } +} diff --git a/common/lib/authentication/iam_authentication_plugin.ts b/common/lib/authentication/iam_authentication_plugin.ts index 8f4dec4e..852a69a8 100644 --- a/common/lib/authentication/iam_authentication_plugin.ts +++ b/common/lib/authentication/iam_authentication_plugin.ts @@ -15,16 +15,16 @@ */ import { PluginService } from "../plugin_service"; -import { RdsUtils } from "../utils/rds_utils"; import { Messages } from "../utils/messages"; import { logger } from "../../logutils"; import { AwsWrapperError } from "../utils/errors"; import { HostInfo } from "../host_info"; import { AwsCredentialsManager } from "./aws_credentials_manager"; import { AbstractConnectionPlugin } from "../abstract_connection_plugin"; -import { WrapperProperties } from "../wrapper_property"; +import { WrapperProperties, WrapperProperty } from "../wrapper_property"; import { IamAuthUtils, TokenInfo } from "../utils/iam_auth_utils"; import { ClientWrapper } from "../client_wrapper"; +import { RegionUtils } from "../utils/region_utils"; export class IamAuthenticationPlugin extends AbstractConnectionPlugin { private static readonly SUBSCRIBED_METHODS = new Set(["connect", "forceConnect"]); @@ -32,7 +32,6 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin { private readonly telemetryFactory; private readonly fetchTokenCounter; private pluginService: PluginService; - rdsUtil: RdsUtils = new RdsUtils(); constructor(pluginService: PluginService) { super(); @@ -75,7 +74,7 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin { } const host = IamAuthUtils.getIamHost(props, hostInfo); - const region: string = IamAuthUtils.getRdsRegion(host, this.rdsUtil, props); + const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host); const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort); const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props); if (tokenExpirationSec < 0) { diff --git a/common/lib/connection_plugin_chain_builder.ts b/common/lib/connection_plugin_chain_builder.ts index d099797c..68cbd292 100644 --- a/common/lib/connection_plugin_chain_builder.ts +++ b/common/lib/connection_plugin_chain_builder.ts @@ -39,6 +39,7 @@ import { DeveloperConnectionPluginFactory } from 
"./plugins/dev/developer_connec import { ConnectionPluginFactory } from "./plugin_factory"; import { LimitlessConnectionPluginFactory } from "./plugins/limitless/limitless_connection_plugin_factory"; import { FastestResponseStrategyPluginFactory } from "./plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory"; +import { CustomEndpointPluginFactory } from "./plugins/custom_endpoint/custom_endpoint_plugin_factory"; import { ConfigurationProfile } from "./profile/configuration_profile"; /* @@ -54,6 +55,7 @@ export class ConnectionPluginChainBuilder { static readonly WEIGHT_RELATIVE_TO_PRIOR_PLUGIN = -1; static readonly PLUGIN_FACTORIES = new Map([ + ["customEndpoint", { factory: CustomEndpointPluginFactory, weight: 380 }], ["initialConnection", { factory: AuroraInitialConnectionStrategyFactory, weight: 390 }], ["auroraConnectionTracker", { factory: AuroraConnectionTrackerPluginFactory, weight: 400 }], ["staleDns", { factory: StaleDnsPluginFactory, weight: 500 }], diff --git a/common/lib/database_dialect/database_dialect.ts b/common/lib/database_dialect/database_dialect.ts index 7eae8676..8f65d878 100644 --- a/common/lib/database_dialect/database_dialect.ts +++ b/common/lib/database_dialect/database_dialect.ts @@ -21,6 +21,7 @@ import { FailoverRestriction } from "../plugins/failover/failover_restriction"; import { ErrorHandler } from "../error_handler"; import { SessionState } from "../session_state"; import { TransactionIsolationLevel } from "../utils/transaction_isolation_level"; +import { HostRole } from "../host_role"; export enum DatabaseType { MYSQL, @@ -39,6 +40,7 @@ export interface DatabaseDialect { getSetSchemaQuery(schema: string): string; getDialectUpdateCandidates(): string[]; getErrorHandler(): ErrorHandler; + getHostRole(targetClient: ClientWrapper): Promise; isDialect(targetClient: ClientWrapper): Promise; getHostListProvider(props: Map, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider; isClientValid(targetClient: ClientWrapper): Promise; diff --git a/common/lib/host_list_provider/monitoring/monitoring_host_list_provider.ts b/common/lib/host_list_provider/monitoring/monitoring_host_list_provider.ts index fb705e5f..0a6fe443 100644 --- a/common/lib/host_list_provider/monitoring/monitoring_host_list_provider.ts +++ b/common/lib/host_list_provider/monitoring/monitoring_host_list_provider.ts @@ -27,6 +27,7 @@ import { Messages } from "../../utils/messages"; import { WrapperProperties } from "../../wrapper_property"; import { BlockingHostListProvider } from "../host_list_provider"; import { logger } from "../../../logutils"; +import { isDialectTopologyAware } from "../../utils/utils"; export class MonitoringRdsHostListProvider extends RdsHostListProvider implements BlockingHostListProvider { static readonly CACHE_CLEANUP_NANOS: bigint = BigInt(60_000_000_000); // 1 minute. 
@@ -76,7 +77,7 @@ export class MonitoringRdsHostListProvider extends RdsHostListProvider implement async sqlQueryForTopology(targetClient: ClientWrapper): Promise { const dialect: DatabaseDialect = this.hostListProviderService.getDialect(); - if (!this.isTopologyAwareDatabaseDialect(dialect)) { + if (!isDialectTopologyAware(dialect)) { throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect")); } return await dialect.queryForTopology(targetClient, this).then((res: any) => this.processQueryResults(res)); diff --git a/common/lib/host_list_provider/rds_host_list_provider.ts b/common/lib/host_list_provider/rds_host_list_provider.ts index 63f6a323..fd7353bb 100644 --- a/common/lib/host_list_provider/rds_host_list_provider.ts +++ b/common/lib/host_list_provider/rds_host_list_provider.ts @@ -27,8 +27,7 @@ import { WrapperProperties } from "../wrapper_property"; import { logger } from "../../logutils"; import { HostAvailability } from "../host_availability/host_availability"; import { CacheMap } from "../utils/cache_map"; -import { logTopology } from "../utils/utils"; -import { TopologyAwareDatabaseDialect } from "../topology_aware_database_dialect"; +import { isDialectTopologyAware, logTopology } from "../utils/utils"; import { DatabaseDialect } from "../database_dialect/database_dialect"; import { ClientWrapper } from "../client_wrapper"; @@ -137,7 +136,7 @@ export class RdsHostListProvider implements DynamicHostListProvider { } async getHostRole(client: ClientWrapper, dialect: DatabaseDialect): Promise { - if (!this.isTopologyAwareDatabaseDialect(dialect)) { + if (!isDialectTopologyAware(dialect)) { throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect")); } @@ -150,7 +149,7 @@ export class RdsHostListProvider implements DynamicHostListProvider { async getWriterId(client: ClientWrapper): Promise { const dialect = this.hostListProviderService.getDialect(); - if (!this.isTopologyAwareDatabaseDialect(dialect)) { + if (!isDialectTopologyAware(dialect)) { throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect")); } @@ -162,7 +161,7 @@ export class RdsHostListProvider implements DynamicHostListProvider { } async identifyConnection(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise { - if (!this.isTopologyAwareDatabaseDialect(dialect)) { + if (!isDialectTopologyAware(dialect)) { throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect")); } const instanceName = await dialect.identifyConnection(targetClient); @@ -276,12 +275,8 @@ export class RdsHostListProvider implements DynamicHostListProvider { } } - protected isTopologyAwareDatabaseDialect(arg: any): arg is TopologyAwareDatabaseDialect { - return arg; - } - async queryForTopology(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise { - if (!this.isTopologyAwareDatabaseDialect(dialect)) { + if (!isDialectTopologyAware(dialect)) { throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect")); } diff --git a/common/lib/host_list_provider_service.ts b/common/lib/host_list_provider_service.ts index bcf12493..5c2910e4 100644 --- a/common/lib/host_list_provider_service.ts +++ b/common/lib/host_list_provider_service.ts @@ -21,6 +21,7 @@ import { DatabaseDialect } from "./database_dialect/database_dialect"; import { HostInfoBuilder } from "./host_info_builder"; import { ConnectionUrlParser } from "./utils/connection_url_parser"; import { TelemetryFactory } from "./utils/telemetry/telemetry_factory"; +import { AllowedAndBlockedHosts } from 
"./AllowedAndBlockedHosts"; export interface HostListProviderService { getHostListProvider(): HostListProvider | null; @@ -50,4 +51,6 @@ export interface HostListProviderService { isClientValid(targetClient: any): Promise; getTelemetryFactory(): TelemetryFactory; + + setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts): void; } diff --git a/common/lib/plugin_service.ts b/common/lib/plugin_service.ts index f93cdfee..80b872f4 100644 --- a/common/lib/plugin_service.ts +++ b/common/lib/plugin_service.ts @@ -40,11 +40,13 @@ import { ClientWrapper } from "./client_wrapper"; import { logger } from "../logutils"; import { Messages } from "./utils/messages"; import { DatabaseDialectCodes } from "./database_dialect/database_dialect_codes"; -import { getWriter } from "./utils/utils"; +import { getWriter, logTopology } from "./utils/utils"; import { TelemetryFactory } from "./utils/telemetry/telemetry_factory"; import { DriverDialect } from "./driver_dialect/driver_dialect"; +import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts"; export class PluginService implements ErrorHandler, HostListProviderService { + private static readonly DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO = 5 * 60_000_000_000; // 5 minutes private readonly _currentClient: AwsClient; private _currentHostInfo?: HostInfo; private _hostListProvider?: HostListProvider; @@ -59,6 +61,7 @@ export class PluginService implements ErrorHandler, HostListProviderService { protected readonly sessionStateService: SessionStateService; protected static readonly hostAvailabilityExpiringCache: CacheMap = new CacheMap(); readonly props: Map; + private allowedAndBlockedHosts: AllowedAndBlockedHosts | null = null; constructor( container: PluginServiceManagerContainer, @@ -114,17 +117,29 @@ export class PluginService implements ErrorHandler, HostListProviderService { this._currentHostInfo = this._initialConnectionHostInfo; if (!this._currentHostInfo) { - if (this.getHosts().length === 0) { + if (this.getAllHosts().length === 0) { throw new AwsWrapperError(Messages.get("PluginService.hostListEmpty")); } - const writerHost = getWriter(this.getHosts()); + const writerHost = getWriter(this.getAllHosts()); + if (!this.getHosts().some((hostInfo: HostInfo) => hostInfo.host === writerHost?.host)) { + throw new AwsWrapperError( + Messages.get( + "PluginService.currentHostNotAllowed", + this._currentHostInfo ? 
this._currentHostInfo.host : "", + logTopology(this.hosts, "[PluginService.currentHostNotAllowed] ") + ) + ); + } + if (writerHost) { this._currentHostInfo = writerHost; } else { this._currentHostInfo = this.getHosts()[0]; } } + + logger.debug(`Set current host to: ${this._currentHostInfo.host}`); } return this._currentHostInfo; @@ -286,11 +301,64 @@ export class PluginService implements ErrorHandler, HostListProviderService { } } - getHosts(): HostInfo[] { + getAllHosts(): HostInfo[] { return this.hosts; } - setAvailability(hostAliases: Set<string>, availability: HostAvailability) {} + getHosts(): HostInfo[] { + const hostPermissions = this.allowedAndBlockedHosts; + if (!hostPermissions) { + return this.hosts; + } + + let hosts = this.hosts; + const allowedHostIds = hostPermissions.getAllowedHostIds(); + const blockedHostIds = hostPermissions.getBlockedHostIds(); + + if (allowedHostIds && allowedHostIds.size > 0) { + hosts = hosts.filter((host: HostInfo) => allowedHostIds.has(host.hostId)); + } + + if (blockedHostIds && blockedHostIds.size > 0) { + hosts = hosts.filter((host: HostInfo) => !blockedHostIds.has(host.hostId)); + } + + return hosts; + } + + setAvailability(hostAliases: Set<string>, availability: HostAvailability) { + if (hostAliases.size === 0) { + return; + } + + const hostsToChange = [ + ...new Set( + this.getAllHosts().filter( + (host: HostInfo) => hostAliases.has(host.asAlias) || [...host.aliases].some((hostAlias: string) => hostAliases.has(hostAlias)) + ) + ) + ]; + + if (hostsToChange.length === 0) { + logger.debug(Messages.get("PluginService.hostsChangeListEmpty")); + return; + } + + const changes = new Map<string, Set<HostChangeOptions>>(); + for (const host of hostsToChange) { + const currentAvailability = host.getAvailability(); + PluginService.hostAvailabilityExpiringCache.put(host.url, availability, PluginService.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO); + if (currentAvailability !== availability) { + let hostChanges = new Set<HostChangeOptions>(); + if (availability === HostAvailability.AVAILABLE) { + hostChanges = new Set([HostChangeOptions.WENT_UP, HostChangeOptions.HOST_CHANGED]); + } else { + hostChanges = new Set([HostChangeOptions.WENT_DOWN, HostChangeOptions.HOST_CHANGED]); + } + changes.set(host.url, hostChanges); + } + } + } updateConfigWithProperties(props: Map<string, any>) { this._currentClient.config = Object.fromEntries(props.entries()); @@ -527,4 +595,8 @@ export class PluginService implements ErrorHandler, HostListProviderService { attachNoOpErrorListener(clientWrapper: ClientWrapper | undefined): void { this.getDialect().getErrorHandler().attachNoOpErrorListener(clientWrapper); } + + setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts) { + this.allowedAndBlockedHosts = allowedAndBlockedHosts; + } } diff --git a/common/lib/plugins/aurora_initial_connection_strategy_plugin.ts b/common/lib/plugins/aurora_initial_connection_strategy_plugin.ts index 4720574c..cea1f52b 100644 --- a/common/lib/plugins/aurora_initial_connection_strategy_plugin.ts +++ b/common/lib/plugins/aurora_initial_connection_strategy_plugin.ts @@ -261,7 +261,7 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu } private getWriter(): HostInfo | null { - return this.pluginService.getHosts().find((x) => x.role === HostRole.WRITER) ?? null; + return this.pluginService.getAllHosts().find((x) => x.role === HostRole.WRITER) ?? 
null; } private getReader(props: Map): HostInfo | undefined { @@ -278,6 +278,6 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu } private hasNoReaders(): boolean { - return this.pluginService.getHosts().find((x) => x.role === HostRole.READER) !== undefined; + return this.pluginService.getAllHosts().find((x) => x.role === HostRole.READER) !== undefined; } } diff --git a/common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin.ts b/common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin.ts index 85c42d94..f2a0965d 100644 --- a/common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin.ts +++ b/common/lib/plugins/connection_tracker/aurora_connection_tracker_plugin.ts @@ -99,7 +99,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl } private async checkWriterChanged(): Promise { - const hostInfoAfterFailover = this.getWriter(this.pluginService.getHosts()); + const hostInfoAfterFailover = this.getWriter(this.pluginService.getAllHosts()); if (this.currentWriter === null) { this.currentWriter = hostInfoAfterFailover; this.needUpdateCurrentWriter = false; @@ -114,7 +114,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl private rememberWriter(): void { if (this.currentWriter === null || this.needUpdateCurrentWriter) { - this.currentWriter = this.getWriter(this.pluginService.getHosts()); + this.currentWriter = this.getWriter(this.pluginService.getAllHosts()); this.needUpdateCurrentWriter = false; } } diff --git a/common/lib/plugins/custom_endpoint/custom_endpoint_info.ts b/common/lib/plugins/custom_endpoint/custom_endpoint_info.ts new file mode 100644 index 00000000..2d7bbfbf --- /dev/null +++ b/common/lib/plugins/custom_endpoint/custom_endpoint_info.ts @@ -0,0 +1,106 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { CustomEndpointRoleType, customEndpointRoleTypeFromValue } from "./custom_endpoint_role_type"; +import { MemberListType } from "./member_list_type"; +import { DBClusterEndpoint } from "@aws-sdk/client-rds"; + +export class CustomEndpointInfo { + private readonly endpointIdentifier: string; // ID portion of the custom endpoint URL. + private readonly clusterIdentifier: string; // ID of the cluster that the custom endpoint belongs to. + private readonly url: string; + private readonly roleType: CustomEndpointRoleType; + + // A given custom endpoint will either specify a static list or an exclusion list, as indicated by `memberListType`. + // If the list is a static list, 'members' specifies instances included in the custom endpoint, and new cluster + // instances will not be automatically added to the custom endpoint. If it is an exclusion list, 'members' specifies + // instances excluded by the custom endpoint, and new cluster instances will be added to the custom endpoint. 
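+ // For illustration (hypothetical instance IDs): a static-list endpoint whose members are ["instance-1", "instance-2"] is later mapped by the monitor to new AllowedAndBlockedHosts(new Set(["instance-1", "instance-2"]), null), + // whereas an exclusion-list endpoint with the same members is mapped to new AllowedAndBlockedHosts(null, new Set(["instance-1", "instance-2"])).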
+ private readonly memberListType: MemberListType; + private readonly members: Set; + + constructor( + endpointIdentifier: string, + clusterIdentifier: string, + url: string, + roleType: CustomEndpointRoleType, + members: Set, + memberListType: MemberListType + ) { + this.endpointIdentifier = endpointIdentifier; + this.clusterIdentifier = clusterIdentifier; + this.url = url; + this.roleType = roleType; + this.members = members; + this.memberListType = memberListType; + } + + getMemberListType(): MemberListType { + return this.memberListType; + } + + static fromDbClusterEndpoint(responseEndpointInfo: DBClusterEndpoint): CustomEndpointInfo { + let members: Set; + let memberListType: MemberListType; + + if (responseEndpointInfo.StaticMembers) { + members = new Set(responseEndpointInfo.StaticMembers); + memberListType = MemberListType.STATIC_LIST; + } else { + members = new Set(responseEndpointInfo.ExcludedMembers); + memberListType = MemberListType.EXCLUSION_LIST; + } + + return new CustomEndpointInfo( + responseEndpointInfo.DBClusterEndpointIdentifier, + responseEndpointInfo.DBClusterIdentifier, + responseEndpointInfo.Endpoint, + customEndpointRoleTypeFromValue(responseEndpointInfo.CustomEndpointType), + members, + memberListType + ); + } + + getStaticMembers(): Set { + return this.memberListType === MemberListType.STATIC_LIST ? this.members : new Set(); + } + + getExcludedMembers(): Set { + return this.memberListType === MemberListType.EXCLUSION_LIST ? this.members : new Set(); + } + + equals(info: CustomEndpointInfo): boolean { + if (!info) { + return false; + } + + if (info === this) { + return true; + } + + return ( + this.endpointIdentifier === info.endpointIdentifier && + this.clusterIdentifier === info.clusterIdentifier && + this.url === info.url && + this.roleType === info.roleType && + this.members === info.members && + this.memberListType === info.memberListType + ); + } + + toString(): string { + return `CustomEndpointInfo[url=${this.url}, clusterIdentifier=${this.clusterIdentifier}, customEndpointType=${CustomEndpointRoleType[this.roleType]}, memberListType=${MemberListType[this.memberListType]}, members={${[...this.members].join(", ")}}]`; + } +} diff --git a/common/lib/plugins/custom_endpoint/custom_endpoint_monitor_impl.ts b/common/lib/plugins/custom_endpoint/custom_endpoint_monitor_impl.ts new file mode 100644 index 00000000..d99840d9 --- /dev/null +++ b/common/lib/plugins/custom_endpoint/custom_endpoint_monitor_impl.ts @@ -0,0 +1,184 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +import { DescribeDBClusterEndpointsCommand, RDSClient } from "@aws-sdk/client-rds"; +import { HostInfo } from "../../host_info"; +import { PluginService } from "../../plugin_service"; +import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; +import { logger } from "../../../logutils"; +import { CustomEndpointInfo } from "./custom_endpoint_info"; +import { AllowedAndBlockedHosts } from "../../AllowedAndBlockedHosts"; +import { CacheMap } from "../../utils/cache_map"; +import { MemberListType } from "./member_list_type"; +import { Messages } from "../../utils/messages"; +import { clearTimeout } from "node:timers"; +import { AwsWrapperError } from "../../utils/errors"; + +export interface CustomEndpointMonitor { + shouldDispose(): boolean; + hasCustomEndpointInfo(): boolean; + close(): void; +} + +export class CustomEndpointMonitorImpl implements CustomEndpointMonitor { + private static readonly TELEMETRY_ENDPOINT_INFO_CHANGED = "customEndpoint.infoChanged.counter"; + + // Keys are custom endpoint URLs, values are information objects for the associated custom endpoint. + private static readonly CUSTOM_ENDPOINT_INFO_EXPIRATION_NANO = 5 * 60_000_000_000; // 5 minutes + + protected static customEndpointInfoCache: CacheMap = new CacheMap(); + + private rdsClient: RDSClient; + private customEndpointHostInfo: HostInfo; + private readonly endpointIdentifier: string; + private readonly region: string; + private readonly refreshRateMs: number; + + private pluginService: PluginService; + private infoChangedCounter: TelemetryCounter; + + protected stop = false; + private timers: NodeJS.Timeout[] = []; + + constructor( + pluginService: PluginService, + customEndpointHostInfo: HostInfo, + endpointIdentifier: string, + region: string, + refreshRateMs: number, + rdsClientFunc: (hostInfo: HostInfo, region: string) => RDSClient + ) { + this.pluginService = pluginService; + this.customEndpointHostInfo = customEndpointHostInfo; + this.endpointIdentifier = endpointIdentifier; + this.region = region; + this.refreshRateMs = refreshRateMs; + this.rdsClient = rdsClientFunc(customEndpointHostInfo, this.region); + + const telemetryFactory = this.pluginService.getTelemetryFactory(); + this.infoChangedCounter = telemetryFactory.createCounter(CustomEndpointMonitorImpl.TELEMETRY_ENDPOINT_INFO_CHANGED); + + this.run(); + } + + async run(): Promise { + logger.verbose(Messages.get("CustomEndpointMonitorImpl.startingMonitor", this.customEndpointHostInfo.host)); + + while (!this.stop) { + try { + const start = Date.now(); + + const input = { + DBClusterEndpointIdentifier: this.endpointIdentifier, + Filters: [ + { + Name: "db-cluster-endpoint-type", + Values: ["custom"] + } + ] + }; + const command = new DescribeDBClusterEndpointsCommand(input); + const result = await this.rdsClient.send(command); + + const endpoints = result.DBClusterEndpoints; + + if (endpoints.length === 0) { + throw new AwsWrapperError(Messages.get("CustomEndpointMonitorImpl.noEndpoints")); + } + + if (endpoints.length !== 1) { + let endpointUrls = ""; + endpoints.forEach((endpoint) => { + endpointUrls += `\n\t${endpoint.Endpoint}`; + }); + logger.warn( + Messages.get( + "CustomEndpointMonitorImpl.unexpectedNumberOfEndpoints", + this.endpointIdentifier, + this.region, + String(endpoints.length), + endpointUrls + ) + ); + await new Promise((resolve) => { + this.timers.push(setTimeout(resolve, this.refreshRateMs)); + }); + continue; + } + + const endpointInfo = CustomEndpointInfo.fromDbClusterEndpoint(endpoints[0]); + const 
cachedEndpointInfo = CustomEndpointMonitorImpl.customEndpointInfoCache.get(this.customEndpointHostInfo.host); + + if (cachedEndpointInfo && cachedEndpointInfo.equals(endpointInfo)) { + const elapsedTime = Date.now() - start; + const sleepDuration = Math.max(0, this.refreshRateMs - elapsedTime); + await new Promise((resolve) => { + this.timers.push(setTimeout(resolve, sleepDuration)); + }); + continue; + } + + logger.verbose( + Messages.get("CustomEndpointMonitorImpl.detectedChangeInCustomEndpointInfo", this.customEndpointHostInfo.host, endpointInfo.toString()) + ); + + // The custom endpoint info has changed, so we need to update the set of allowed/blocked hosts. + let allowedAndBlockedHosts: AllowedAndBlockedHosts; + if (endpointInfo.getMemberListType() === MemberListType.STATIC_LIST) { + allowedAndBlockedHosts = new AllowedAndBlockedHosts(endpointInfo.getStaticMembers(), null); + } else { + allowedAndBlockedHosts = new AllowedAndBlockedHosts(null, endpointInfo.getExcludedMembers()); + } + + this.pluginService.setAllowedAndBlockedHosts(allowedAndBlockedHosts); + CustomEndpointMonitorImpl.customEndpointInfoCache.put( + this.customEndpointHostInfo.host, + endpointInfo, + CustomEndpointMonitorImpl.CUSTOM_ENDPOINT_INFO_EXPIRATION_NANO + ); + this.infoChangedCounter.inc(); + + const elapsedTime = Date.now() - start; + const sleepDuration = Math.max(0, this.refreshRateMs - elapsedTime); + await new Promise((resolve) => { + this.timers.push(setTimeout(resolve, sleepDuration)); + }); + } catch (e: any) { + logger.error(Messages.get("CustomEndpointMonitorImpl.error", this.customEndpointHostInfo.host, e.message)); + throw e; + } + } + } + + hasCustomEndpointInfo(): boolean { + return CustomEndpointMonitorImpl.customEndpointInfoCache.get(this.customEndpointHostInfo.host) != null; + } + + shouldDispose(): boolean { + return true; + } + + close(): void { + logger.verbose(Messages.get("CustomEndpointMonitorImpl.stoppingMonitor", this.customEndpointHostInfo.host)); + this.stop = true; + for (const timer of this.timers) { + clearTimeout(timer); + } + CustomEndpointMonitorImpl.customEndpointInfoCache.delete(this.customEndpointHostInfo.host); + this.rdsClient.destroy(); + logger.verbose(Messages.get("CustomEndpointMonitorImpl.stoppedMonitor", this.customEndpointHostInfo.host)); + } +} diff --git a/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts b/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts new file mode 100644 index 00000000..39b12c26 --- /dev/null +++ b/common/lib/plugins/custom_endpoint/custom_endpoint_plugin.ts @@ -0,0 +1,192 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +import { AbstractConnectionPlugin } from "../../abstract_connection_plugin"; +import { PluginService } from "../../plugin_service"; +import { RDSClient } from "@aws-sdk/client-rds"; +import { HostInfo } from "../../host_info"; +import { WrapperProperties } from "../../wrapper_property"; +import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; +import { ClientWrapper } from "../../client_wrapper"; +import { RdsUtils } from "../../utils/rds_utils"; +import { logger } from "../../../logutils"; +import { Messages } from "../../utils/messages"; +import { AwsWrapperError } from "../../utils/errors"; +import { RegionUtils } from "../../utils/region_utils"; +import { SlidingExpirationCache } from "../../utils/sliding_expiration_cache"; +import { sleep } from "../../utils/utils"; +import { CustomEndpointMonitor, CustomEndpointMonitorImpl } from "./custom_endpoint_monitor_impl"; +import { SubscribedMethodHelper } from "../../utils/subscribed_method_helper"; +import { CanReleaseResources } from "../../can_release_resources"; + +export class CustomEndpointPlugin extends AbstractConnectionPlugin implements CanReleaseResources { + private static readonly TELEMETRY_WAIT_FOR_INFO_COUNTER = "customEndpoint.waitForInfo.counter"; + private static SUBSCRIBED_METHODS: Set = new Set(SubscribedMethodHelper.NETWORK_BOUND_METHODS); + private static readonly CACHE_CLEANUP_NANOS = BigInt(60_000_000_000); + + private static readonly rdsUtils = new RdsUtils(); + protected static readonly monitors: SlidingExpirationCache = new SlidingExpirationCache( + CustomEndpointPlugin.CACHE_CLEANUP_NANOS, + (monitor: CustomEndpointMonitor) => monitor.shouldDispose(), + (monitor: CustomEndpointMonitor) => { + try { + monitor.close(); + } catch (e) { + // ignore + } + } + ); + + private readonly pluginService: PluginService; + private readonly props: Map; + private readonly rdsClientFunc: (hostInfo: HostInfo, region: string) => RDSClient; + + private readonly shouldWaitForInfo: boolean; + private readonly waitOnCachedInfoDurationMs: number; + private readonly idleMonitorExpirationMs: number; + private customEndpointHostInfo: HostInfo; + private customEndpointId: string; + private region: string; + + private waitForInfoCounter: TelemetryCounter; + + constructor(pluginService: PluginService, props: Map, rdsClientFunc?: (hostInfo: HostInfo, region: string) => RDSClient) { + super(); + this.pluginService = pluginService; + this.props = props; + + if (rdsClientFunc) { + this.rdsClientFunc = rdsClientFunc; + } else { + this.rdsClientFunc = (hostInfo: HostInfo, region: string) => { + return new RDSClient({ region: region }); + }; + } + + this.shouldWaitForInfo = WrapperProperties.WAIT_FOR_CUSTOM_ENDPOINT_INFO.get(this.props); + this.waitOnCachedInfoDurationMs = WrapperProperties.WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.get(this.props); + this.idleMonitorExpirationMs = WrapperProperties.CUSTOM_ENDPOINT_MONITOR_IDLE_EXPIRATION_MS.get(this.props); + + const telemetryFactory = this.pluginService.getTelemetryFactory(); + this.waitForInfoCounter = telemetryFactory.createCounter(CustomEndpointPlugin.TELEMETRY_WAIT_FOR_INFO_COUNTER); + } + + getSubscribedMethods(): Set { + return CustomEndpointPlugin.SUBSCRIBED_METHODS; + } + + async connect( + hostInfo: HostInfo, + props: Map, + isInitialConnection: boolean, + connectFunc: () => Promise + ): Promise { + if (!CustomEndpointPlugin.rdsUtils.isRdsCustomClusterDns(hostInfo.host)) { + return await connectFunc(); + } + + this.customEndpointHostInfo = hostInfo; + 
logger.debug(Messages.get("CustomEndpointPlugin.connectionRequestToCustomEndpoint", hostInfo.host)); + + this.customEndpointId = CustomEndpointPlugin.rdsUtils.getRdsInstanceId(hostInfo.host); + if (!this.customEndpointId) { + throw new AwsWrapperError(Messages.get("CustomEndpointPlugin.errorParsingEndpointIdentifier", this.customEndpointHostInfo.host)); + } + + this.region = RegionUtils.getRegion(props.get(WrapperProperties.CUSTOM_ENDPOINT_REGION.name), this.customEndpointHostInfo.host); + if (!this.region) { + throw new AwsWrapperError(Messages.get("CustomEndpointPlugin.unableToDetermineRegion", WrapperProperties.CUSTOM_ENDPOINT_REGION.name)); + } + + const monitor: CustomEndpointMonitor = this.createMonitorIfAbsent(props); + + if (this.shouldWaitForInfo) { + // If needed, wait a short time for custom endpoint info to be discovered. + await this.waitForCustomEndpointInfo(monitor); + } + + return await connectFunc(); + } + + createMonitorIfAbsent(props: Map): CustomEndpointMonitor { + return CustomEndpointPlugin.monitors.computeIfAbsent( + this.customEndpointHostInfo.host, + (customEndpoint: string) => + new CustomEndpointMonitorImpl( + this.pluginService, + this.customEndpointHostInfo, + this.customEndpointId, + this.region, + WrapperProperties.CUSTOM_ENDPOINT_INFO_REFRESH_RATE.get(this.props), + this.rdsClientFunc + ), + BigInt(this.idleMonitorExpirationMs * 1000000) + ); + } + + async waitForCustomEndpointInfo(monitor: CustomEndpointMonitor): Promise { + let hasCustomEndpointInfo = monitor.hasCustomEndpointInfo(); + + if (!hasCustomEndpointInfo) { + // Wait for the monitor to place the custom endpoint info in the cache. This ensures other plugins get accurate + // custom endpoint info. + this.waitForInfoCounter.inc(); + logger.verbose( + Messages.get("CustomEndpointPlugin.waitingForCustomEndpointInfo", this.customEndpointHostInfo.host, String(this.waitOnCachedInfoDurationMs)) + ); + + const waitForEndpointInfoTimeoutMs = Date.now() + this.waitOnCachedInfoDurationMs; + while (!hasCustomEndpointInfo && Date.now() < waitForEndpointInfoTimeoutMs) { + await sleep(100); + hasCustomEndpointInfo = monitor.hasCustomEndpointInfo(); + } + + if (!hasCustomEndpointInfo) { + throw new AwsWrapperError( + Messages.get( + "CustomEndpointPlugin.timedOutWaitingForCustomEndpointInfo", + String(this.waitOnCachedInfoDurationMs), + this.customEndpointHostInfo.host + ) + ); + } + } + } + + async execute(methodName: string, methodFunc: () => Promise, methodArgs: any[]): Promise { + if (!this.customEndpointHostInfo) { + return await methodFunc(); + } + + const monitor = this.createMonitorIfAbsent(this.props); + if (this.shouldWaitForInfo) { + // If needed, wait a short time for custom endpoint info to be discovered. + await this.waitForCustomEndpointInfo(monitor); + } + + return await methodFunc(); + } + + static closeMonitors() { + logger.info(Messages.get("CustomEndpointPlugin.closeMonitors")); + // The clear call automatically calls close() on all monitors. + CustomEndpointPlugin.monitors.clear(); + } + + async releaseResources(): Promise { + CustomEndpointPlugin.closeMonitors(); + } +} diff --git a/common/lib/plugins/custom_endpoint/custom_endpoint_plugin_factory.ts b/common/lib/plugins/custom_endpoint/custom_endpoint_plugin_factory.ts new file mode 100644 index 00000000..0e945c2e --- /dev/null +++ b/common/lib/plugins/custom_endpoint/custom_endpoint_plugin_factory.ts @@ -0,0 +1,35 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. 
+ + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { ConnectionPluginFactory } from "../../plugin_factory"; +import { PluginService } from "../../plugin_service"; +import { AwsWrapperError } from "../../utils/errors"; +import { Messages } from "../../utils/messages"; + +export class CustomEndpointPluginFactory extends ConnectionPluginFactory { + private static customEndpointPlugin: any; + + async getInstance(pluginService: PluginService, props: Map) { + try { + if (!CustomEndpointPluginFactory.customEndpointPlugin) { + CustomEndpointPluginFactory.customEndpointPlugin = await import("./custom_endpoint_plugin"); + } + return new CustomEndpointPluginFactory.customEndpointPlugin.CustomEndpointPlugin(pluginService, props); + } catch (error: any) { + throw new AwsWrapperError(Messages.get("ConnectionPluginChainBuilder.errorImportingPlugin", error.message, "CustomEndpointPlugin")); + } + } +} diff --git a/common/lib/plugins/custom_endpoint/custom_endpoint_role_type.ts b/common/lib/plugins/custom_endpoint/custom_endpoint_role_type.ts new file mode 100644 index 00000000..b4e2abec --- /dev/null +++ b/common/lib/plugins/custom_endpoint/custom_endpoint_role_type.ts @@ -0,0 +1,31 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +export enum CustomEndpointRoleType { + ANY, + WRITER, + READER +} + +const nameToValue = new Map([ + ["ANY", CustomEndpointRoleType.ANY], + ["WRITER", CustomEndpointRoleType.WRITER], + ["READER", CustomEndpointRoleType.READER] +]); + +export function customEndpointRoleTypeFromValue(name: string): CustomEndpointRoleType { + return nameToValue.get(name.toUpperCase()) ?? CustomEndpointRoleType.ANY; +} diff --git a/common/lib/plugins/custom_endpoint/member_list_type.ts b/common/lib/plugins/custom_endpoint/member_list_type.ts new file mode 100644 index 00000000..cbe49b22 --- /dev/null +++ b/common/lib/plugins/custom_endpoint/member_list_type.ts @@ -0,0 +1,20 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ See the License for the specific language governing permissions and + limitations under the License. +*/ + +export enum MemberListType { + STATIC_LIST, + EXCLUSION_LIST +} diff --git a/common/lib/plugins/efm/host_monitoring_connection_plugin.ts b/common/lib/plugins/efm/host_monitoring_connection_plugin.ts index d27fb26e..97a1f444 100644 --- a/common/lib/plugins/efm/host_monitoring_connection_plugin.ts +++ b/common/lib/plugins/efm/host_monitoring_connection_plugin.ts @@ -84,7 +84,7 @@ export class HostMonitoringConnectionPlugin extends AbstractConnectionPlugin imp } async execute(methodName: string, methodFunc: () => Promise, methodArgs: any): Promise { - const isEnabled: boolean = WrapperProperties.FAILURE_DETECTION_ENABLED.get(this.properties) as boolean; + const isEnabled: boolean = WrapperProperties.FAILURE_DETECTION_ENABLED.get(this.properties); if (!isEnabled || !SubscribedMethodHelper.NETWORK_BOUND_METHODS.includes(methodName)) { return methodFunc(); diff --git a/common/lib/plugins/failover/failover_plugin.ts b/common/lib/plugins/failover/failover_plugin.ts index cc734472..964fcf96 100644 --- a/common/lib/plugins/failover/failover_plugin.ts +++ b/common/lib/plugins/failover/failover_plugin.ts @@ -40,7 +40,7 @@ import { RdsUrlType } from "../../utils/rds_url_type"; import { RdsUtils } from "../../utils/rds_utils"; import { Messages } from "../../utils/messages"; import { ClientWrapper } from "../../client_wrapper"; -import { getWriter } from "../../utils/utils"; +import { getWriter, logTopology } from "../../utils/utils"; import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; import { TelemetryTraceLevel } from "../../utils/telemetry/telemetry_trace_level"; @@ -148,15 +148,6 @@ export class FailoverPlugin extends AbstractConnectionPlugin { } initHostProviderFunc(); - - this.failoverMode = failoverModeFromValue(WrapperProperties.FAILOVER_MODE.get(this._properties)); - this._rdsUrlType = this._rdsHelper.identifyRdsType(hostInfo.host); - - if (this.failoverMode === FailoverMode.UNKNOWN) { - this.failoverMode = this._rdsUrlType === RdsUrlType.RDS_READER_CLUSTER ? 
FailoverMode.READER_OR_WRITER : FailoverMode.STRICT_WRITER; - } - - logger.debug(Messages.get("Failover.parameterValue", "failoverMode", FailoverMode[this.failoverMode])); } override notifyConnectionChanged(changes: Set): Promise { @@ -213,8 +204,8 @@ export class FailoverPlugin extends AbstractConnectionPlugin { return ( this.enableFailoverSetting && this._rdsUrlType !== RdsUrlType.RDS_PROXY && - this.pluginService.getHosts() && - this.pluginService.getHosts().length > 0 + this.pluginService.getAllHosts() && + this.pluginService.getAllHosts().length > 0 ); } @@ -228,7 +219,7 @@ export class FailoverPlugin extends AbstractConnectionPlugin { } private getCurrentWriter(): HostInfo | null { - const topology = this.pluginService.getHosts(); + const topology = this.pluginService.getAllHosts(); if (topology.length == 0) { return null; } @@ -282,6 +273,7 @@ export class FailoverPlugin extends AbstractConnectionPlugin { isInitialConnection: boolean, connectFunc: () => Promise ): Promise { + this.initFailoverMode(); return await this._staleDnsHelper.getVerifiedConnection(hostInfo.host, isInitialConnection, this.hostListProviderService!, props, connectFunc); } @@ -403,7 +395,7 @@ export class FailoverPlugin extends AbstractConnectionPlugin { try { await telemetryContext.start(async () => { try { - const result = await this._writerFailoverHandler.failover(this.pluginService.getHosts()); + const result = await this._writerFailoverHandler.failover(this.pluginService.getAllHosts()); if (result) { const error = result.error; @@ -423,13 +415,27 @@ export class FailoverPlugin extends AbstractConnectionPlugin { throw new AwsWrapperError(Messages.get("Failover.unableToDetermineWriter")); } + await this.pluginService.refreshHostList(); + const allowedHosts = this.pluginService.getHosts(); + if (!allowedHosts.some((hostInfo: HostInfo) => hostInfo.host === writerHostInfo?.host)) { + const failoverErrorMessage = Messages.get( + "Failover.newWriterNotAllowed", + writerHostInfo ? writerHostInfo.host : "", + logTopology(allowedHosts, "[Failover.newWriterNotAllowed] ") + ); + logger.error(failoverErrorMessage); + await this.pluginService.abortTargetClient(result.client); + throw new FailoverFailedError(failoverErrorMessage); + } + await this.pluginService.abortCurrentClient(); await this.pluginService.setCurrentClient(result.client, writerHostInfo); logger.debug(Messages.get("Failover.establishedConnection", this.pluginService.getCurrentHostInfo()?.host ?? "")); - await this.pluginService.refreshHostList(); this.throwFailoverSuccessError(); - this.failoverWriterSuccessCounter.inc(); } catch (error: any) { + if (error instanceof FailoverSuccessError) { + this.failoverWriterSuccessCounter.inc(); + } this.failoverWriterFailedCounter.inc(); throw error; } @@ -548,4 +554,20 @@ export class FailoverPlugin extends AbstractConnectionPlugin { return false; } + + initFailoverMode(): void { + if (!this._rdsUrlType) { + this.failoverMode = failoverModeFromValue(WrapperProperties.FAILOVER_MODE.get(this._properties)); + const initialHostInfo = this.hostListProviderService.getInitialConnectionHostInfo(); + this._rdsUrlType = this._rdsHelper.identifyRdsType(initialHostInfo.host); + + if (this.failoverMode === FailoverMode.UNKNOWN) { + this.failoverMode = this._rdsUrlType === RdsUrlType.RDS_READER_CLUSTER ? 
FailoverMode.READER_OR_WRITER : FailoverMode.STRICT_WRITER; + } + + this._readerFailoverHandler.setEnableFailoverStrictReader(this.failoverMode === FailoverMode.STRICT_READER); + + logger.debug(Messages.get("Failover.parameterValue", "failoverMode", FailoverMode[this.failoverMode])); + } + } } diff --git a/common/lib/plugins/failover/reader_failover_handler.ts b/common/lib/plugins/failover/reader_failover_handler.ts index c751be7d..87a4b8de 100644 --- a/common/lib/plugins/failover/reader_failover_handler.ts +++ b/common/lib/plugins/failover/reader_failover_handler.ts @@ -17,7 +17,7 @@ import { HostInfo } from "../../host_info"; import { PluginService } from "../../plugin_service"; import { ReaderFailoverResult } from "./reader_failover_result"; -import { getTimeoutTask, logAndThrowError, maskProperties, shuffleList, sleep } from "../../utils/utils"; +import { getTimeoutTask, maskProperties, shuffleList, sleep } from "../../utils/utils"; import { HostRole } from "../../host_role"; import { HostAvailability } from "../../host_availability/host_availability"; import { AwsWrapperError, InternalQueryTimeoutError } from "../../utils/errors"; @@ -31,6 +31,8 @@ export interface ReaderFailoverHandler { failover(hosts: HostInfo[], currentHost: HostInfo): Promise; getReaderConnection(hostList: HostInfo[]): Promise; + + setEnableFailoverStrictReader(enableFailoverStrictReader: boolean): void; } export class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler { @@ -41,7 +43,7 @@ export class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler private readonly initialConnectionProps: Map; private readonly maxFailoverTimeoutMs: number; private readonly timeoutMs: number; - private readonly enableFailoverStrictReader: boolean; + private enableFailoverStrictReader: boolean; private readonly pluginService: PluginService; private taskHandler: ReaderTaskSelectorHandler = new ReaderTaskSelectorHandler(); @@ -59,6 +61,10 @@ export class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler this.enableFailoverStrictReader = enableFailoverStrictReader; } + setEnableFailoverStrictReader(enableFailoverStrictReader: boolean): void { + this.enableFailoverStrictReader = enableFailoverStrictReader; + } + async failover(hosts: HostInfo[], currentHost: HostInfo | null): Promise { if (hosts == null || hosts.length === 0) { logger.info(Messages.get("ClusterAwareReaderFailoverHandler.invalidTopology", "failover")); @@ -113,7 +119,7 @@ export class ClusterAwareReaderFailoverHandler implements ReaderFailoverHandler // Ensure new connection is to a reader host await this.pluginService.refreshHostList(); try { - if ((await this.pluginService.getHostRole(result.client)) !== HostRole.READER) { + if ((await this.pluginService.getHostRole(result.client)) === HostRole.READER) { return result; } } catch (error) { diff --git a/common/lib/plugins/failover/writer_failover_handler.ts b/common/lib/plugins/failover/writer_failover_handler.ts index aa044795..84ce0f25 100644 --- a/common/lib/plugins/failover/writer_failover_handler.ts +++ b/common/lib/plugins/failover/writer_failover_handler.ts @@ -237,7 +237,7 @@ class ReconnectToWriterHandlerTask { props.set(WrapperProperties.HOST.name, this.originalWriterHost.host); this.currentClient = await this.pluginService.forceConnect(this.originalWriterHost, props); await this.pluginService.forceRefreshHostList(this.currentClient); - latestTopology = this.pluginService.getHosts(); + latestTopology = this.pluginService.getAllHosts(); } catch (error) { // 
Propagate errors that are not caused by network errors. if (error instanceof AwsWrapperError && !this.pluginService.isNetworkError(error)) { @@ -384,7 +384,7 @@ class WaitForNewWriterHandlerTask { if (this.currentReaderTargetClient) { await this.pluginService.forceRefreshHostList(this.currentReaderTargetClient); } - const topology = this.pluginService.getHosts(); + const topology = this.pluginService.getAllHosts(); if (topology && topology.length > 0) { if (topology.length === 1) { diff --git a/common/lib/plugins/failover2/failover2_plugin.ts b/common/lib/plugins/failover2/failover2_plugin.ts index ddadb8fa..51de47c3 100644 --- a/common/lib/plugins/failover2/failover2_plugin.ts +++ b/common/lib/plugins/failover2/failover2_plugin.ts @@ -41,6 +41,7 @@ import { HostRole } from "../../host_role"; import { CanReleaseResources } from "../../can_release_resources"; import { MonitoringRdsHostListProvider } from "../../host_list_provider/monitoring/monitoring_host_list_provider"; import { ReaderFailoverResult } from "../failover/reader_failover_result"; +import { logTopology } from "../../utils/utils"; export class Failover2Plugin extends AbstractConnectionPlugin implements CanReleaseResources { private static readonly TELEMETRY_WRITER_FAILOVER = "failover to writer instance"; @@ -119,8 +120,8 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele return ( this.enableFailoverSetting && this._rdsUrlType !== RdsUrlType.RDS_PROXY && - this.pluginService.getHosts() && - this.pluginService.getHosts().length > 0 + this.pluginService.getAllHosts() && + this.pluginService.getAllHosts().length > 0 ); } @@ -382,11 +383,22 @@ export class Failover2Plugin extends AbstractConnectionPlugin implements CanRele this.logAndThrowError(Messages.get("Failover2.unableToFetchTopology")); } - const hosts: HostInfo[] = this.pluginService.getHosts(); + const hosts: HostInfo[] = this.pluginService.getAllHosts(); let writerCandidateClient: ClientWrapper = null; const writerCandidateHostInfo: HostInfo = hosts.find((x) => x.role === HostRole.WRITER); + const allowedHosts = this.pluginService.getHosts(); + if (!allowedHosts.some((hostInfo: HostInfo) => hostInfo.host === writerCandidateHostInfo?.host)) { + const failoverErrorMessage = Messages.get( + "Failover.newWriterNotAllowed", + writerCandidateHostInfo ? 
writerCandidateHostInfo.host : "", + logTopology(allowedHosts, "[Failover.newWriterNotAllowed] ") + ); + logger.error(failoverErrorMessage); + throw new FailoverFailedError(failoverErrorMessage); + } + if (writerCandidateHostInfo) { try { writerCandidateClient = await this.createConnectionForHost(writerCandidateHostInfo); diff --git a/common/lib/plugins/federated_auth/federated_auth_plugin.ts b/common/lib/plugins/federated_auth/federated_auth_plugin.ts index 8ce215c0..6f718f9d 100644 --- a/common/lib/plugins/federated_auth/federated_auth_plugin.ts +++ b/common/lib/plugins/federated_auth/federated_auth_plugin.ts @@ -27,6 +27,7 @@ import { CredentialsProviderFactory } from "./credentials_provider_factory"; import { SamlUtils } from "../../utils/saml_utils"; import { ClientWrapper } from "../../client_wrapper"; import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; +import { RegionUtils } from "../../utils/region_utils"; export class FederatedAuthPlugin extends AbstractConnectionPlugin { protected static readonly tokenCache = new Map(); @@ -70,7 +71,7 @@ export class FederatedAuthPlugin extends AbstractConnectionPlugin { const host = IamAuthUtils.getIamHost(props, hostInfo); const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort()); - const region: string = IamAuthUtils.getRdsRegion(host, this.rdsUtils, props); + const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host); const cacheKey = IamAuthUtils.getCacheKey(port, WrapperProperties.DB_USER.get(props), host, region); const tokenInfo = FederatedAuthPlugin.tokenCache.get(cacheKey); diff --git a/common/lib/plugins/federated_auth/okta_auth_plugin.ts b/common/lib/plugins/federated_auth/okta_auth_plugin.ts index 65be3bba..57474801 100644 --- a/common/lib/plugins/federated_auth/okta_auth_plugin.ts +++ b/common/lib/plugins/federated_auth/okta_auth_plugin.ts @@ -27,6 +27,7 @@ import { Messages } from "../../utils/messages"; import { AwsWrapperError } from "../../utils/errors"; import { ClientWrapper } from "../../client_wrapper"; import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; +import { RegionUtils } from "../../utils/region_utils"; export class OktaAuthPlugin extends AbstractConnectionPlugin { protected static readonly tokenCache = new Map(); @@ -70,7 +71,7 @@ export class OktaAuthPlugin extends AbstractConnectionPlugin { const host = IamAuthUtils.getIamHost(props, hostInfo); const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getDialect().getDefaultPort()); - const region = IamAuthUtils.getRdsRegion(host, this.rdsUtils, props); + const region = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host); const cacheKey = IamAuthUtils.getCacheKey(port, WrapperProperties.DB_USER.get(props), host, region); const tokenInfo = OktaAuthPlugin.tokenCache.get(cacheKey); diff --git a/common/lib/plugins/read_write_splitting_plugin.ts b/common/lib/plugins/read_write_splitting_plugin.ts index 2d4148e3..c8985765 100644 --- a/common/lib/plugins/read_write_splitting_plugin.ts +++ b/common/lib/plugins/read_write_splitting_plugin.ts @@ -222,7 +222,6 @@ export class ReadWriteSplittingPlugin extends AbstractConnectionPlugin implement if (hosts == null || hosts.length === 0) { logAndThrowError(Messages.get("ReadWriteSplittingPlugin.emptyHostList")); } - const currentHost = this.pluginService.getCurrentHostInfo(); if (currentHost == null) { 
logAndThrowError(Messages.get("ReadWriteSplittingPlugin.unavailableHostInfo")); @@ -341,6 +340,11 @@ export class ReadWriteSplittingPlugin extends AbstractConnectionPlugin implement return; } + if (this._readerHostInfo && !hosts.some((hostInfo: HostInfo) => hostInfo.host === this._readerHostInfo?.host)) { + // The old reader cannot be used anymore because it is no longer in the list of allowed hosts. + await this.closeTargetClientIfIdle(this.readerTargetClient); + } + this._inReadWriteSplit = true; if (!(await this.isTargetClientUsable(this.readerTargetClient))) { await this.initializeReaderClient(hosts); diff --git a/common/lib/plugins/stale_dns/stale_dns_helper.ts b/common/lib/plugins/stale_dns/stale_dns_helper.ts index 1e9122c4..b4c94f8c 100644 --- a/common/lib/plugins/stale_dns/stale_dns_helper.ts +++ b/common/lib/plugins/stale_dns/stale_dns_helper.ts @@ -27,7 +27,7 @@ import { AwsWrapperError } from "../../utils/errors"; import { HostChangeOptions } from "../../host_change_options"; import { WrapperProperties } from "../../wrapper_property"; import { ClientWrapper } from "../../client_wrapper"; -import { getWriter } from "../../utils/utils"; +import { getWriter, logTopology } from "../../utils/utils"; import { TelemetryFactory } from "../../utils/telemetry/telemetry_factory"; import { TelemetryCounter } from "../../utils/telemetry/telemetry_counter"; @@ -88,6 +88,8 @@ export class StaleDnsHelper { await this.pluginService.refreshHostList(currentTargetClient); } + logger.debug(logTopology(this.pluginService.getAllHosts(), "[StaleDnsHelper.getVerifiedConnection] ")); + if (!this.writerHostInfo) { const writerCandidate = getWriter(this.pluginService.getHosts()); if (writerCandidate && this.rdsUtils.isRdsClusterDns(writerCandidate.host)) { diff --git a/common/lib/utils/connection_url_parser.ts b/common/lib/utils/connection_url_parser.ts index 77e417a8..1fdb3487 100644 --- a/common/lib/utils/connection_url_parser.ts +++ b/common/lib/utils/connection_url_parser.ts @@ -77,7 +77,7 @@ export abstract class ConnectionUrlParser { const hostsList: HostInfo[] = []; const hosts: string[] = this.getHostPortPairsFromUrl(initialConnection); hosts.forEach((pair, i) => { - let host; + let host: HostInfo; if (singleWriterConnectionString) { const role: HostRole = i > 0 ? HostRole.READER : HostRole.WRITER; host = this.parseHostPortPair(pair, fallbackPort, builderFunc, role); diff --git a/common/lib/utils/errors.ts b/common/lib/utils/errors.ts index 675ea713..14e0f2e4 100644 --- a/common/lib/utils/errors.ts +++ b/common/lib/utils/errors.ts @@ -14,8 +14,6 @@ limitations under the License. 
*/ -import { Messages } from "./messages"; - export class AwsWrapperError extends Error { constructor(message?: string, cause?: any) { super(message); diff --git a/common/lib/utils/iam_auth_utils.ts b/common/lib/utils/iam_auth_utils.ts index 5ee415ae..9b933f43 100644 --- a/common/lib/utils/iam_auth_utils.ts +++ b/common/lib/utils/iam_auth_utils.ts @@ -16,7 +16,7 @@ import { logger } from "../../logutils"; import { HostInfo } from "../host_info"; -import { WrapperProperties } from "../wrapper_property"; +import { WrapperProperties, WrapperProperty } from "../wrapper_property"; import { AwsWrapperError } from "./errors"; import { Messages } from "./messages"; import { RdsUtils } from "./rds_utils"; @@ -49,7 +49,7 @@ export class IamAuthUtils { } } - public static getRdsRegion(hostname: string, rdsUtils: RdsUtils, props: Map): string { + public static getRdsRegion(hostname: string, rdsUtils: RdsUtils, props: Map, wrapperProperty: WrapperProperty): string { const rdsRegion = rdsUtils.getRdsRegion(hostname); if (!rdsRegion) { @@ -58,7 +58,7 @@ export class IamAuthUtils { throw new AwsWrapperError(errorMessage); } - return WrapperProperties.IAM_REGION.get(props) ? WrapperProperties.IAM_REGION.get(props) : rdsRegion; + return wrapperProperty.get(props) ? wrapperProperty.get(props) : rdsRegion; } public static getCacheKey(port: number, user?: string, hostname?: string, region?: string): string { diff --git a/common/lib/utils/locales/en.json b/common/lib/utils/locales/en.json index 7dbba53a..6ce458bd 100644 --- a/common/lib/utils/locales/en.json +++ b/common/lib/utils/locales/en.json @@ -65,6 +65,7 @@ "Failover.transactionResolutionUnknownError": "Unknown transaction resolution error occurred during failover.", "Failover.connectionExplicitlyClosed": "Unable to failover on an explicitly closed connection.", "Failover.timeoutError": "Internal failover task has timed out.", + "Failover.newWriterNotAllowed": "The failover process identified the new writer but the host is not in the list of allowed hosts. New writer host: '%s'. Allowed hosts: '%s'.", "StaleDnsHelper.clusterEndpointDns": "Cluster endpoint resolves to '%s'.", "StaleDnsHelper.writerHostInfo": "Writer host: '%s'.", "StaleDnsHelper.writerInetAddress": "Writer host address: '%s'", @@ -137,18 +138,18 @@ "MonitorService.emptyAliasSet": "Empty alias set passed for '%s'. Set should not be empty.", "PluginService.hostListEmpty": "Current host list is empty.", "PluginService.releaseResources": "Releasing resources.", - "PluginService.hostsChangelistEmpty": "There are no changes in the hosts' availability.", + "PluginService.hostsChangeListEmpty": "There are no changes in the hosts' availability.", "PluginService.failedToRetrieveHostPort": "Could not retrieve Host:Port for connection.", "PluginService.nonEmptyAliases": "fillAliases called when HostInfo already contains the following aliases: '%s'.", "PluginService.forceMonitoringRefreshTimeout": "A timeout exception occurred after waiting '%s' ms for refreshed topology.", "PluginService.requiredBlockingHostListProvider": "The detected host list provider is not a BlockingHostListProvider. A BlockingHostListProvider is required to force refresh the host list. 
Detected host list provider: '%s'.", "MonitoringHostListProvider.requiresMonitor": "The MonitoringRdsHostListProvider could not retrieve or initialize a ClusterTopologyMonitor for refreshing the topology.", "MonitoringHostListProvider.errorForceRefresh": "The MonitoringRdsHostListProvider could not refresh the topology, caught error: '%s'", + "PluginService.currentHostNotAllowed": "The current host is not in the list of allowed hosts. Current host: '%s'. Allowed hosts: '%s'.", "HostMonitoringConnectionPlugin.activatedMonitoring": "Executing method '%s', monitoring is activated.", "HostMonitoringConnectionPlugin.unableToIdentifyConnection": "Unable to identify the given connection: '%s', please ensure the correct host list provider is specified. The host list provider in use is: '%s'.", "HostMonitoringConnectionPlugin.errorIdentifyingConnection": "Error occurred while identifying connection: '%s'.", "HostMonitoringConnectionPlugin.unavailableHost": "Host '%s' is unavailable.", - "PluginServiceImpl.failedToRetrieveHostPort": "PluginServiceImpl.failedToRetrieveHostPort", "AuroraInitialConnectionStrategyPlugin.unsupportedStrategy": "Unsupported host selection strategy '%s'.", "AuroraInitialConnectionStrategyPlugin.requireDynamicProvider": "Dynamic host list provider is required.", "OpenedConnectionTracker.unableToPopulateOpenedConnectionQueue": "The driver is unable to track this opened connection because the instance endpoint is unknown: '%s'", @@ -224,5 +225,19 @@ "HostMonitor.startMonitoring": "Host monitor '%s' started.", "HostMonitor.detectedWriter": "Detected writer: '%s'.", "HostMonitor.endMonitoring": "Host monitor '%s' completed in '%s'.", - "HostMonitor.writerHostChanged": "Writer host has changed from '%s' to '%s'." + "HostMonitor.writerHostChanged": "Writer host has changed from '%s' to '%s'.", + "CustomEndpointPlugin.connectionRequestToCustomEndpoint": "Detected a connection request to a custom endpoint URL: '%s'.", + "CustomEndpointPlugin.errorParsingEndpointIdentifier": "Unable to parse custom endpoint identifier from URL: '%s'.", + "CustomEndpointPlugin.unableToDetermineRegion": "Unable to determine connection region. If you are using a non-standard RDS URL, please set the '%s' property.", + "CustomEndpointPlugin.waitingForCustomEndpointInfo": "Custom endpoint info for '%s' was not found. Waiting '%s' ms for the endpoint monitor to fetch info...", + "CustomEndpointPlugin.closeMonitors": "Closing custom endpoint monitors. Active custom endpoint monitors will be stopped, closed, and removed from the monitor's cache.", + "CustomEndpointPlugin.timedOutWaitingForCustomEndpointInfo": "The custom endpoint plugin timed out after '%s' ms while waiting for custom endpoint info for host '%s'.", + "CustomEndpointMonitorImpl.startingMonitor": "Starting custom endpoint monitor for '%s'.", + "CustomEndpointMonitorImpl.unexpectedNumberOfEndpoints": "Unexpected number of custom endpoints with endpoint identifier '%s' in region '%s'. Expected 1, but found '%s'. Endpoints:\n'%s'.", + "CustomEndpointMonitorImpl.detectedChangeInCustomEndpointInfo": "Detected change in custom endpoint info for '%s': %s", + "CustomEndpointMonitorImpl.error": "Encountered an error while monitoring custom endpoint '%s': '%s'", + "CustomEndpointMonitorImpl.stoppedMonitor": "Stopped custom endpoint monitor for '%s'.", + "CustomEndpointMonitorImpl.stoppingMonitor": "Stopping custom endpoint monitor for '%s'.", + "CustomEndpointMonitorImpl.noEndpoints": "Unable to find any custom endpoints. 
When connecting with a custom endpoint, at least one custom endpoint should be detected.", + "AwsSdk.unsupportedRegion": "Unsupported AWS region '%s'. For supported regions please read https://docs.aws.amazon.com/AmazonRDS/latest/UserGuide/Concepts.RegionsAndAvailabilityZones.html" } diff --git a/common/lib/utils/rds_utils.ts b/common/lib/utils/rds_utils.ts index 0555dc7c..ad2b5e10 100644 --- a/common/lib/utils/rds_utils.ts +++ b/common/lib/utils/rds_utils.ts @@ -165,7 +165,7 @@ export class RdsUtils { RdsUtils.AURORA_OLD_CHINA_DNS_PATTERN, RdsUtils.AURORA_GOV_DNS_PATTERN ); - if (this.getRegexGroup(matcher, RdsUtils.DNS_GROUP)) { + if (this.getRegexGroup(matcher, RdsUtils.DNS_GROUP) !== null) { return this.getRegexGroup(matcher, RdsUtils.INSTANCE_GROUP); } diff --git a/common/lib/utils/region_utils.ts b/common/lib/utils/region_utils.ts new file mode 100644 index 00000000..7b26a619 --- /dev/null +++ b/common/lib/utils/region_utils.ts @@ -0,0 +1,105 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +*/ + +import { RdsUtils } from "./rds_utils"; +import { AwsWrapperError } from "./errors"; +import { Messages } from "./messages"; + +export class RegionUtils { + static readonly REGIONS: string[] = [ + "af-south-1", + "ap-east-1", + "ap-northeast-1", + "ap-northeast-2", + "ap-northeast-3", + "ap-south-1", + "ap-south-2", + "ap-southeast-1", + "ap-southeast-2", + "ap-southeast-3", + "ap-southeast-4", + "ap-southeast-5", + "aws-global", + "aws-cn-global", + "aws-us-gov-global", + "aws-iso-global", + "aws-iso-b-global", + "ca-central-1", + "ca-west-1", + "cn-north-1", + "cn-northwest-1", + "eu-central-1", + "eu-central-2", + "eu-isoe-west-1", + "eu-north-1", + "eu-south-1", + "eu-south-2", + "eu-west-1", + "eu-west-2", + "eu-west-3", + "il-central-1", + "me-central-1", + "me-south-1", + "sa-east-1", + "us-east-1", + "us-east-2", + "us-gov-east-1", + "us-gov-west-1", + "us-iso-east-1", + "us-iso-west-1", + "us-isob-east-1", + "us-west-1", + "us-west-2" + ]; + + protected static readonly rdsUtils = new RdsUtils(); + + static getRegion(regionString: string, host?: string): string | null { + const region = RegionUtils.getRegionFromRegionString(regionString); + + if (region !== null) { + return region; + } + + if (host) { + return RegionUtils.getRegionFromHost(host); + } + + return region; + } + + private static getRegionFromRegionString(regionString: string): string | null { + if (!regionString) { + return null; + } + + const region = regionString.toLowerCase().trim(); + if (!RegionUtils.REGIONS.includes(region)) { + throw new AwsWrapperError(Messages.get("AwsSdk.unsupportedRegion", regionString)); + } + + return region; + } + + private static getRegionFromHost(host: string): string | null { + const regionString = RegionUtils.rdsUtils.getRdsRegion(host); + if (!regionString) { + throw new AwsWrapperError(Messages.get("AwsSdk.unsupportedRegion", regionString)); + } + + return RegionUtils.getRegionFromRegionString(regionString); + } +} diff --git 
a/common/lib/utils/utils.ts b/common/lib/utils/utils.ts index 68b817a0..b5f29c0c 100644 --- a/common/lib/utils/utils.ts +++ b/common/lib/utils/utils.ts @@ -20,6 +20,7 @@ import { WrapperProperties } from "../wrapper_property"; import { HostRole } from "../host_role"; import { logger } from "../../logutils"; import { AwsWrapperError, InternalQueryTimeoutError } from "./errors"; +import { TopologyAwareDatabaseDialect } from "../topology_aware_database_dialect"; export function sleep(ms: number) { return new Promise((resolve) => setTimeout(resolve, ms)); @@ -79,3 +80,7 @@ export function logAndThrowError(message: string) { export function equalsIgnoreCase(value1: string | null, value2: string | null): boolean { return value1 != null && value2 != null && value1.localeCompare(value2, undefined, { sensitivity: "accent" }) === 0; } + +export function isDialectTopologyAware(dialect: any): dialect is TopologyAwareDatabaseDialect { + return dialect; +} diff --git a/common/lib/wrapper_property.ts b/common/lib/wrapper_property.ts index 452aead7..7ef996d4 100644 --- a/common/lib/wrapper_property.ts +++ b/common/lib/wrapper_property.ts @@ -387,6 +387,39 @@ export class WrapperProperties { static readonly PROFILE_NAME = new WrapperProperty("profileName", "Driver configuration profile name", null); + static readonly CUSTOM_ENDPOINT_INFO_REFRESH_RATE = new WrapperProperty( + "customEndpointInfoRefreshRateMs", + "Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds.", + 10_000 + ); + + static readonly WAIT_FOR_CUSTOM_ENDPOINT_INFO = new WrapperProperty( + "waitForCustomEndpointInfo", + "Controls whether to wait for custom endpoint info to become available before connecting or executing a " + + "method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used " + + "recently. Note that disabling this may result in occasional connections to instances outside of the " + + "custom endpoint.", + true + ); + + static readonly WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS = new WrapperProperty( + "waitForCustomEndpointInfoTimeoutMs", + "Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds.", + 10_000 + ); + + static readonly CUSTOM_ENDPOINT_MONITOR_IDLE_EXPIRATION_MS = new WrapperProperty( + "customEndpointMonitorExpirationMs", + "Controls how long a monitor should run without use before expiring and being removed, in milliseconds.", + 900_000 // 15 min + ); + + static readonly CUSTOM_ENDPOINT_REGION = new WrapperProperty( + "customEndpointRegion", + "The region of the cluster's custom endpoints. If not specified, the region will be parsed from the URL.", + null + ); + static removeWrapperProperties(props: Map): Map { const persistingProperties = [ WrapperProperties.USER.name, diff --git a/docs/using-the-nodejs-wrapper/UsingTheNodejsWrapper.md b/docs/using-the-nodejs-wrapper/UsingTheNodejsWrapper.md index 3a78535b..d5895e0b 100644 --- a/docs/using-the-nodejs-wrapper/UsingTheNodejsWrapper.md +++ b/docs/using-the-nodejs-wrapper/UsingTheNodejsWrapper.md @@ -133,6 +133,7 @@ The AWS Advanced NodeJS Wrapper has several built-in plugins that are available | [Aurora Initial Connection Strategy Plugin](./using-plugins/UsingTheAuroraInitialConnectionStrategyPlugin.md) | `initialConnection` | Aurora | Allows users to configure their initial connection strategy to reader cluster endpoints. 
| None | | [Aurora Limitless Connection Plugin](./using-plugins/UsingTheAuroraInitialConnectionStrategyPlugin.md) | `limitless` | Aurora | Allows users to use Aurora Limitless Database and effectively load-balance load between available transaction routers. | None | | [Fastest Response Strategy Plugin](./using-plugins/UsingTheFastestResponseStrategyPlugin.md) | `fastestResponseStrategy` | Aurora | When read-write splitting is enabled, this plugin selects the reader to switch to based on the host with the fastest response time. The plugin achieves this by periodically monitoring the hosts' response times and storing the fastest host in a cache.

:warning:**Note:** the `readerHostSelector` strategy must be set to `fastestResponse` in the user-defined connection properties in order to enable this plugin. See [reader selection strategies](./ReaderSelectionStrategies.md) | None | +| [Custom Endpoint Plugin](./using-plugins/UsingTheCustomEndpointPlugin.md) | `customEndpoint` | Aurora | This plugin will analyse custom endpoint information to ensure instances used in connections are part of the custom endpoint being used. | See the Custom Endpoint Plugin [prerequisites](./using-plugins/UsingTheCustomEndpointPlugin.md#Prerequisites) | In addition to the built-in plugins, you can also create custom plugins more suitable for your needs. For more information, see [Custom Plugins](../development-guide/LoadablePlugins.md#using-custom-plugins). diff --git a/docs/using-the-nodejs-wrapper/using-plugins/UsingTheCustomEndpointPlugin.md b/docs/using-the-nodejs-wrapper/using-plugins/UsingTheCustomEndpointPlugin.md new file mode 100644 index 00000000..de6984a4 --- /dev/null +++ b/docs/using-the-nodejs-wrapper/using-plugins/UsingTheCustomEndpointPlugin.md @@ -0,0 +1,28 @@ +# Custom Endpoint Plugin + +The Custom Endpoint Plugin adds support for [RDS custom endpoints](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/Aurora.Endpoints.Custom.html). When the Custom Endpoint Plugin is in use, the driver will analyse custom endpoint information to ensure instances used in connections are part of the custom endpoint being used. This includes connections used in failover and read-write splitting. + +## Prerequisites + +- This plugin requires the following packages to be installed: + - [@aws-sdk/client-rds](https://www.npmjs.com/package/@aws-sdk/client-rds) + +## How to use the Custom Endpoint Plugin with the AWS Advanced NodeJS Wrapper + +### Enabling the Custom Endpoint Plugin + +1. If needed, create a custom endpoint using the AWS RDS Console: + - If needed, review the documentation about [creating a custom endpoint](https://docs.aws.amazon.com/AmazonRDS/latest/AuroraUserGuide/aurora-custom-endpoint-creating.html). +2. Add the plugin code `customEndpoint` to the [`plugins`](../UsingTheNodejsWrapper.md#connection-plugin-manager-parameters) value, or to the current [driver profile](../UsingTheNodejsWrapper.md#connection-plugin-manager-parameters). +3. If you are using the failover plugin, set the failover parameter `failoverMode` according to the custom endpoint type. For example, if the custom endpoint you are using is of type `READER`, you can set `failoverMode` to `strict-reader`, or if it is of type `ANY`, you can set `failoverMode` to `reader-or-writer`. +4. Specify parameters that are required or specific to your case. + +### Custom Endpoint Plugin Parameters + +| Parameter | Value | Required | Description | Default Value | Example Value | +| ------------------------------------ | :-----: | :------: | :----------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | --------------------- | ------------- | +| `customEndpointRegion` | String | No | The region of the cluster's custom endpoints. If not specified, the region will be parsed from the URL. 
| `null` | `us-west-1` | +| `customEndpointInfoRefreshRateMs` | number | No | Controls how frequently custom endpoint monitors fetch custom endpoint info, in milliseconds. | `10000` | `20000` | +| `customEndpointMonitorExpirationMs` | number | No | Controls how long a monitor should run without use before expiring and being removed, in milliseconds. | `900000` (15 minutes) | `600000` | +| `waitForCustomEndpointInfo` | boolean | No | Controls whether to wait for custom endpoint info to become available before connecting or executing a method. Waiting is only necessary if a connection to a given custom endpoint has not been opened or used recently. Note that disabling this may result in occasional connections to instances outside of the custom endpoint. | `true` | `true` | +| `waitForCustomEndpointInfoTimeoutMs` | number | No | Controls the maximum amount of time that the plugin will wait for custom endpoint info to be made available by the custom endpoint monitor, in milliseconds. | `10000` | `7000` | diff --git a/mysql/lib/client.ts b/mysql/lib/client.ts index 0f1ddc9b..00227188 100644 --- a/mysql/lib/client.ts +++ b/mysql/lib/client.ts @@ -31,6 +31,7 @@ import { ClientUtils } from "../../common/lib/utils/client_utils"; import { RdsMultiAZMySQLDatabaseDialect } from "./dialect/rds_multi_az_mysql_database_dialect"; import { TelemetryTraceLevel } from "../../common/lib/utils/telemetry/telemetry_trace_level"; import { MySQL2DriverDialect } from "./dialect/mysql2_driver_dialect"; +import { isDialectTopologyAware } from "../../common/lib/utils/utils"; export class AwsMySQLClient extends AwsClient { private static readonly knownDialectsByCode: Map = new Map([ @@ -53,6 +54,19 @@ export class AwsMySQLClient extends AwsClient { throw new AwsWrapperError(Messages.get("HostInfo.noHostParameter")); } const result: ClientWrapper = await this.pluginManager.connect(hostInfo, this.properties, true); + if (isDialectTopologyAware(this.pluginService.getDialect())) { + try { + const role = await this.pluginService.getHostRole(result); + // The current host role may be incorrect, use the created client to confirm the host role. 
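+ // Note: the role recorded on the initial HostInfo is only a best guess. A cluster or custom endpoint may resolve
+ // to a host whose actual role differs (for example, a writer endpoint whose DNS still points at the old writer
+ // shortly after a failover), so the role reported by the live connection is used to correct hostInfo below.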
+ if (role !== result.hostInfo.role) { + result.hostInfo.role = role; + this.pluginService.setCurrentHostInfo(result.hostInfo); + this.pluginService.setInitialConnectionHostInfo(result.hostInfo); + } + } catch (error) { + // Ignore + } + } await this.pluginService.setCurrentClient(result, result.hostInfo); await this.internalPostConnect(); }); diff --git a/mysql/lib/dialect/mysql_database_dialect.ts b/mysql/lib/dialect/mysql_database_dialect.ts index 04fbb34a..bb17334c 100644 --- a/mysql/lib/dialect/mysql_database_dialect.ts +++ b/mysql/lib/dialect/mysql_database_dialect.ts @@ -28,6 +28,7 @@ import { ErrorHandler } from "../../../common/lib/error_handler"; import { MySQLErrorHandler } from "../mysql_error_handler"; import { SessionState } from "../../../common/lib/session_state"; import { Messages } from "../../../common/lib/utils/messages"; +import { HostRole } from "../../../common/lib/host_role"; export class MySQLDatabaseDialect implements DatabaseDialect { protected dialectName: string = this.constructor.name; @@ -206,4 +207,8 @@ export class MySQLDatabaseDialect implements DatabaseDialect { doesStatementSetSchema(statement: string): string | undefined { return undefined; } + + async getHostRole(targetClient: ClientWrapper): Promise { + throw new UnsupportedMethodError(`Method getHostRole not supported for dialect: ${this.dialectName}`); + } } diff --git a/pg/lib/client.ts b/pg/lib/client.ts index 4ba7e53a..80638d99 100644 --- a/pg/lib/client.ts +++ b/pg/lib/client.ts @@ -30,6 +30,7 @@ import { HostInfo } from "../../common/lib/host_info"; import { TelemetryTraceLevel } from "../../common/lib/utils/telemetry/telemetry_trace_level"; import { NodePostgresDriverDialect } from "./dialect/node_postgres_driver_dialect"; import { TransactionIsolationLevel } from "../../common/lib/utils/transaction_isolation_level"; +import { isDialectTopologyAware } from "../../common/lib/utils/utils"; export class AwsPGClient extends AwsClient { private static readonly knownDialectsByCode: Map = new Map([ @@ -52,6 +53,19 @@ export class AwsPGClient extends AwsClient { throw new AwsWrapperError(Messages.get("HostInfo.noHostParameter")); } const result: ClientWrapper = await this.pluginManager.connect(hostInfo, this.properties, true); + if (isDialectTopologyAware(this.pluginService.getDialect())) { + try { + const role = await this.pluginService.getHostRole(result); + // The current host role may be incorrect, use the created client to confirm the host role. 
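+ // As in the MySQL client, the initially assumed role may not match the host this endpoint actually resolved to,
+ // so the role reported by the live connection is used to correct hostInfo before it is stored.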
+ if (role !== result.hostInfo.role) { + result.hostInfo.role = role; + this.pluginService.setCurrentHostInfo(result.hostInfo); + this.pluginService.setInitialConnectionHostInfo(result.hostInfo); + } + } catch (error) { + // Ignore + } + } await this.pluginService.setCurrentClient(result, result.hostInfo); await this.internalPostConnect(); }); diff --git a/pg/lib/dialect/pg_database_dialect.ts b/pg/lib/dialect/pg_database_dialect.ts index 9d1d2d10..f259028d 100644 --- a/pg/lib/dialect/pg_database_dialect.ts +++ b/pg/lib/dialect/pg_database_dialect.ts @@ -25,8 +25,8 @@ import { ClientWrapper } from "../../../common/lib/client_wrapper"; import { FailoverRestriction } from "../../../common/lib/plugins/failover/failover_restriction"; import { ErrorHandler } from "../../../common/lib/error_handler"; import { PgErrorHandler } from "../pg_error_handler"; -import { SessionState } from "../../../common/lib/session_state"; import { Messages } from "../../../common/lib/utils/messages"; +import { HostRole } from "../../../common/lib/host_role"; export class PgDatabaseDialect implements DatabaseDialect { protected dialectName: string = this.constructor.name; @@ -189,4 +189,8 @@ export class PgDatabaseDialect implements DatabaseDialect { return undefined; } + + async getHostRole(targetClient: ClientWrapper): Promise { + throw new UnsupportedMethodError(`Method getHostRole not supported for dialect: ${this.dialectName}`); + } } diff --git a/tests/integration/container/tests/custom_endpoint.test.ts b/tests/integration/container/tests/custom_endpoint.test.ts new file mode 100644 index 00000000..f5e58403 --- /dev/null +++ b/tests/integration/container/tests/custom_endpoint.test.ts @@ -0,0 +1,453 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +import { features, instanceCount } from "./config"; +import { TestEnvironmentFeatures } from "./utils/test_environment_features"; +import { TestEnvironment } from "./utils/test_environment"; +import { AuroraTestUtility } from "./utils/aurora_test_utility"; +import { DriverHelper } from "./utils/driver_helper"; +import { AwsWrapperError, FailoverSuccessError } from "../../../../common/lib/utils/errors"; +import { + CreateDBClusterEndpointCommand, + DBClusterEndpoint, + DeleteDBClusterEndpointCommand, + DescribeDBClusterEndpointsCommand, + ModifyDBClusterEndpointCommand, + RDSClient +} from "@aws-sdk/client-rds"; +import { sleep } from "../../../../common/lib/utils/utils"; +import { randomUUID } from "node:crypto"; +import { TestInstanceInfo } from "./utils/test_instance_info"; +import { logger } from "../../../../common/logutils"; +import { ProxyHelper } from "./utils/proxy_helper"; +import { PluginManager } from "../../../../common/lib"; +import { TestDriver } from "./utils/test_driver"; + +const itIf = + features.includes(TestEnvironmentFeatures.FAILOVER_SUPPORTED) && + !features.includes(TestEnvironmentFeatures.PERFORMANCE) && + !features.includes(TestEnvironmentFeatures.RUN_AUTOSCALING_TESTS_ONLY) && + instanceCount >= 3 + ? it + : it.skip; + +const endpointId1 = `test-endpoint-1-${randomUUID()}`; +const endpointId2 = `test-endpoint-2-${randomUUID()}`; +let endpointId3: string; +let endpointInfo1: DBClusterEndpoint; +let endpointInfo2: DBClusterEndpoint; +let endpointInfo3: DBClusterEndpoint; +let instance1: string; +let instance2: string; + +let env: TestEnvironment; +let driver: TestDriver; +let client: any; +let rdsClient: RDSClient; +let initClientFunc: (props: any) => any; +let currentWriter: string; + +let auroraTestUtility: AuroraTestUtility; + +async function initDefaultConfig(host: string, port: number, connectToProxy: boolean, failoverMode: string, usingFailover1: boolean): Promise { + let config: any = { + user: env.databaseInfo.username, + host: host, + database: env.databaseInfo.defaultDbName, + password: env.databaseInfo.password, + port: port, + plugins: "customEndpoint,readWriteSplitting,failover", + failoverTimeoutMs: 250000, + failoverMode: failoverMode, + enableTelemetry: true, + telemetryTracesBackend: "OTLP", + telemetryMetricsBackend: "OTLP" + }; + if (usingFailover1) { + config["plugins"] = "customEndpoint,readWriteSplitting,failover"; + } else { + config["plugins"] = "customEndpoint,readWriteSplitting,failover2"; + } + if (connectToProxy) { + config["clusterInstanceHostPattern"] = "?." 
+ env.proxyDatabaseInfo.instanceEndpointSuffix; + } + config = DriverHelper.addDriverSpecificConfiguration(config, env.engine); + return config; +} + +async function createEndpoint(clusterId: string, instances: TestInstanceInfo[], endpointId: string, endpointType: string) { + const instanceIds = instances.map((instance: TestInstanceInfo) => instance.instanceId); + const input = { + DBClusterEndpointIdentifier: endpointId, + DBClusterIdentifier: clusterId, + EndpointType: endpointType, + StaticMembers: instanceIds + }; + const createEndpointCommand = new CreateDBClusterEndpointCommand(input); + await rdsClient.send(createEndpointCommand); +} + +async function waitUntilEndpointAvailable(endpointId: string): Promise { + const timeoutEndMs = Date.now() + 300000; // 5 minutes + let available = false; + + while (Date.now() < timeoutEndMs) { + const input = { + DBClusterEndpointIdentifier: endpointId, + Filters: [ + { + Name: "db-cluster-endpoint-type", + Values: ["custom"] + } + ] + }; + const command = new DescribeDBClusterEndpointsCommand(input); + const result = await rdsClient.send(command); + const endpoints = result.DBClusterEndpoints; + if (endpoints.length !== 1) { + // Endpoint needs more time to get created + await sleep(3000); + continue; + } + + const responseEndpoint = endpoints[0]; + const endpointInfo = responseEndpoint; + + available = responseEndpoint.Status === "available"; + if (available) { + return endpointInfo; + } + + await sleep(3000); + } + + if (!available) { + throw Error(`The test setup step timed out while waiting for the custom endpoint to become available: '${endpointId}'.`); + } +} + +async function waitUntilEndpointHasMembers(endpointId: string, membersList: string[]): Promise { + const start = Date.now(); + + const timeoutEndMs = Date.now() + 1200000; // 20 minutes + let hasCorrectState = false; + while (Date.now() < timeoutEndMs) { + const input = { + DBClusterEndpointIdentifier: endpointId + }; + const command = new DescribeDBClusterEndpointsCommand(input); + const result = await rdsClient.send(command); + const endpoints = result.DBClusterEndpoints; + if (endpoints.length !== 1) { + fail( + `Unexpected number of endpoints returned while waiting for custom endpoint to have the specified list of members. Expected 1, got ${endpoints.length}.` + ); + } + + const endpoint = endpoints[0]; + membersList.sort(); + endpoint.StaticMembers.sort(); + hasCorrectState = endpoint.Status === "available" && arraysAreEqual(membersList, endpoint.StaticMembers); + if (hasCorrectState) { + break; + } + + await sleep(3000); + } + + if (!hasCorrectState) { + fail(`Timed out while waiting for the custom endpoint to stabilize: '${endpointId}'.`); + } + + logger.info(`waitUntilEndpointHasMembers took ${(Date.now() - start) / 1000} seconds`); +} + +function arraysAreEqual(array1: any[], array2: any[]): boolean { + if (array1.length !== array2.length) { + return false; + } + + array1.sort(); + array2.sort(); + + for (let i = 0; i < array1.length; i++) { + if (array1[i] !== array2[i]) { + return false; + } + } + + return true; +} + +async function deleteEndpoint(rdsClient: RDSClient, endpointId: string): Promise { + const input = { + DBClusterEndpointIdentifier: endpointId + }; + const deleteEndpointCommand = new DeleteDBClusterEndpointCommand(input); + try { + await rdsClient.send(deleteEndpointCommand); + } catch (e: any) { + // Custom endpoint already does not exist - do nothing. 
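+ // Deleting an endpoint that no longer exists typically surfaces as a DBClusterEndpointNotFoundFault from the
+ // RDS API; swallowing it keeps test setup and teardown idempotent.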
+ } +} + +describe("custom endpoint", () => { + beforeAll(async () => { + env = await TestEnvironment.getCurrent(); + const clusterId = env.auroraClusterName; + const region = env.region; + rdsClient = new RDSClient({ region: region }); + + auroraTestUtility = new AuroraTestUtility(env.region); + driver = DriverHelper.getDriverForDatabaseEngine(env.engine); + initClientFunc = DriverHelper.getClient(driver); + await ProxyHelper.enableAllConnectivity(); + + await TestEnvironment.verifyClusterStatus(); + const instances = env.databaseInfo.instances; + instance1 = instances[0].instanceId; + instance2 = instances[1].instanceId; + await createEndpoint(clusterId, instances.slice(0, 1), endpointId1, "ANY"); + await createEndpoint(clusterId, instances.slice(0, 2), endpointId2, "ANY"); + endpointInfo1 = await waitUntilEndpointAvailable(endpointId1); + endpointInfo2 = await waitUntilEndpointAvailable(endpointId2); + }, 1000000); + + afterAll(async () => { + try { + await deleteEndpoint(rdsClient, endpointId1); + await deleteEndpoint(rdsClient, endpointId2); + } finally { + rdsClient.destroy(); + } + }); + + beforeEach(async () => { + await TestEnvironment.verifyClusterStatus(); + currentWriter = await auroraTestUtility.getClusterWriterInstanceId(env.info.auroraClusterName); + logger.info(`Test started: ${expect.getState().currentTestName}`); + }, 1000000); + + afterEach(async () => { + if (client !== null) { + try { + await client.end(); + } catch (error) { + // pass + } + } + + await PluginManager.releaseResources(); + }, 1000000); + + itIf.each([true, false])( + "test custom endpoint failover - strict reader", + async (usingFailover1: boolean) => { + endpointId3 = `test-endpoint-3-${randomUUID()}`; + await createEndpoint(env.auroraClusterName, env.instances.slice(0, 2), endpointId3, "READER"); + endpointInfo3 = await waitUntilEndpointAvailable(endpointId3); + + const config = await initDefaultConfig(endpointInfo3.Endpoint, env.databaseInfo.instanceEndpointPort, false, "strict-reader", usingFailover1); + client = initClientFunc(config); + + await client.connect(); + + const endpointMembers = endpointInfo3.StaticMembers; + const instanceId = await auroraTestUtility.queryInstanceId(client); + expect(endpointMembers.includes(instanceId)).toBeTruthy(); + expect(instanceId).not.toBe(currentWriter); + + // Use failover API to break connection. + await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged( + currentWriter, + env.info.auroraClusterName, + instanceId === instance1 ? 
instance1 : instance2 + ); + + await expect(auroraTestUtility.queryInstanceId(client)).rejects.toThrow(FailoverSuccessError); + + endpointInfo3 = await waitUntilEndpointAvailable(endpointId3); + const newEndpointMembers = endpointInfo3.StaticMembers; + + const newInstanceId: string = await auroraTestUtility.queryInstanceId(client); + expect(newEndpointMembers.includes(newInstanceId)).toBeTruthy(); + + const newWriter = await auroraTestUtility.getClusterWriterInstanceId(env.info.auroraClusterName); + expect(newInstanceId).not.toBe(newWriter); + + await deleteEndpoint(rdsClient, endpointId3); + }, + 1000000 + ); + + itIf.each([true, false])( + "test custom endpoint read write splitting with custom endpoint changes", + async (usingFailover1: boolean) => { + const config = await initDefaultConfig( + endpointInfo1.Endpoint, + env.databaseInfo.instanceEndpointPort, + false, + "reader-or-writer", + usingFailover1 + ); + // This setting is not required for the test, but it allows us to also test re-creation of expired monitors since it + // takes more than 30 seconds to modify the cluster endpoint (usually around 140s). + config.customEndpointMonitorExpirationMs = 30000; + client = initClientFunc(config); + + await client.connect(); + + const endpointMembers = endpointInfo1.StaticMembers; + const instanceId1 = await auroraTestUtility.queryInstanceId(client); + expect(endpointMembers.includes(instanceId1)).toBeTruthy(); + + // Attempt to switch to an instance of the opposite role. This should fail since the custom endpoint consists only + // of the current host. + const newReadOnlyValue = currentWriter === instanceId1; + if (newReadOnlyValue) { + // We are connected to the writer. Attempting to switch to the reader will not work but will intentionally not + // throw an error. In this scenario we log a warning and purposefully stick with the writer. + await client.setReadOnly(newReadOnlyValue); + const newInstanceId = await auroraTestUtility.queryInstanceId(client); + expect(newInstanceId).toBe(instanceId1); + } else { + // We are connected to the reader. Attempting to switch to the writer will throw an error. + logger.info("Initial connection is to a reader. Attempting to switch to writer..."); + await expect(client.setReadOnly(newReadOnlyValue)).rejects.toThrow(AwsWrapperError); + } + + let newMember: string; + if (currentWriter === instanceId1) { + newMember = env.databaseInfo.instances[1].instanceId; + } else { + newMember = currentWriter; + } + + const modifyEndpointCommand = new ModifyDBClusterEndpointCommand({ + DBClusterEndpointIdentifier: endpointId1, + StaticMembers: [instanceId1, newMember] + }); + await rdsClient.send(modifyEndpointCommand); + + try { + await waitUntilEndpointHasMembers(endpointId1, [instanceId1, newMember]); + + // We should now be able to switch to newMember. + await client.setReadOnly(newReadOnlyValue); + const instanceId2 = await auroraTestUtility.queryInstanceId(client); + expect(instanceId2).toBe(newMember); + + // Switch back to original instance. + await client.setReadOnly(!newReadOnlyValue); + } finally { + const modifyEndpointCommand = new ModifyDBClusterEndpointCommand({ + DBClusterEndpointIdentifier: endpointId1, + StaticMembers: [instanceId1] + }); + await rdsClient.send(modifyEndpointCommand); + await waitUntilEndpointHasMembers(endpointId1, [instanceId1]); + } + + // We should not be able to switch again because newMember was removed from the custom endpoint. + if (newReadOnlyValue) { + // We are connected to the writer. 
Attempting to switch to the reader will not work but will intentionally not + // throw an error. In this scenario we log a warning and purposefully stick with the writer. + await client.setReadOnly(newReadOnlyValue); + const newInstanceId = await auroraTestUtility.queryInstanceId(client); + expect(newInstanceId).toBe(instanceId1); + } else { + // We are connected to the reader. Attempting to switch to the writer will throw an error. + await expect(client.setReadOnly(newReadOnlyValue)).rejects.toThrow(AwsWrapperError); + } + }, + 1000000 + ); + + itIf.each([true, false])( + "test custom endpoint failover - strict writer", + async (usingFailover1: boolean) => { + const config = await initDefaultConfig(endpointInfo2.Endpoint, env.databaseInfo.instanceEndpointPort, false, "strict-writer", usingFailover1); + client = initClientFunc(config); + + await client.connect(); + + const endpointMembers = endpointInfo2.StaticMembers; + const instanceId = await auroraTestUtility.queryInstanceId(client); + expect(endpointMembers.includes(instanceId)).toBeTruthy(); + + const connectedToWriter = instanceId === currentWriter; + let nextWriter: string; + if (connectedToWriter) { + nextWriter = instanceId === instance1 ? instance2 : instance1; + } else { + nextWriter = instanceId === instance1 ? instance1 : instance2; + } + + // Use failover API to break connection. + await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(currentWriter, env.info.auroraClusterName, nextWriter); + + await expect(auroraTestUtility.queryInstanceId(client)).rejects.toThrow(FailoverSuccessError); + + endpointInfo2 = await waitUntilEndpointAvailable(endpointId2); + const newEndpointMembers = endpointInfo2.StaticMembers; + + const newInstanceId: string = await auroraTestUtility.queryInstanceId(client); + expect(newEndpointMembers.includes(newInstanceId)).toBeTruthy(); + + const newWriter = await auroraTestUtility.getClusterWriterInstanceId(env.info.auroraClusterName); + expect(newInstanceId).toBe(newWriter); + }, + 1000000 + ); + + itIf.each([true, false])( + "test custom endpoint failover - reader or writer mode", + async (usingFailover1: boolean) => { + const config = await initDefaultConfig( + endpointInfo1.Endpoint, + env.databaseInfo.instanceEndpointPort, + false, + "reader-or-writer", + usingFailover1 + ); + client = initClientFunc(config); + + await client.connect(); + + const endpointMembers = endpointInfo1.StaticMembers; + const instanceId = await auroraTestUtility.queryInstanceId(client); + expect(endpointMembers.includes(instanceId)).toBeTruthy(); + + // Use failover API to break connection. + const connectedToWriter = instanceId === currentWriter; + let nextWriter: string; + if (connectedToWriter) { + nextWriter = instanceId === instance1 ? instance2 : instance1; + } else { + nextWriter = instanceId === instance1 ? 
instance1 : instance2; + } + await auroraTestUtility.failoverClusterAndWaitUntilWriterChanged(currentWriter, env.info.auroraClusterName, nextWriter); + + await expect(auroraTestUtility.queryInstanceId(client)).rejects.toThrow(FailoverSuccessError); + + endpointInfo1 = await waitUntilEndpointAvailable(endpointId1); + const newEndpointMembers = endpointInfo1.StaticMembers; + + const newInstanceId: string = await auroraTestUtility.queryInstanceId(client); + expect(newEndpointMembers.includes(newInstanceId)).toBeTruthy(); + }, + 1000000 + ); +}); diff --git a/tests/integration/container/tests/session_state.test.ts b/tests/integration/container/tests/session_state.test.ts index 3ce23be8..04d5bfa7 100644 --- a/tests/integration/container/tests/session_state.test.ts +++ b/tests/integration/container/tests/session_state.test.ts @@ -74,111 +74,115 @@ class TestAwsPGClient extends AwsPGClient { } describe("session state", () => { - it.only("test update state", async () => { - const env = await TestEnvironment.getCurrent(); - const driver = DriverHelper.getDriverForDatabaseEngine(env.engine); - let initClientFunc; - switch (driver) { - case TestDriver.MYSQL: - initClientFunc = (options: any) => new TestAwsMySQLClient(options); - break; - case TestDriver.PG: - initClientFunc = (options: any) => new TestAwsPGClient(options); - break; - default: - throw new Error("invalid driver"); - } - - let props = { - user: env.databaseInfo.username, - host: env.databaseInfo.clusterEndpoint, - database: env.databaseInfo.defaultDbName, - password: env.databaseInfo.password, - port: env.databaseInfo.clusterEndpointPort - }; - props = DriverHelper.addDriverSpecificConfiguration(props, env.engine); - client = initClientFunc(props); - - const newClient = initClientFunc(props); + itIf( + "test update state", + async () => { + const env = await TestEnvironment.getCurrent(); + const driver = DriverHelper.getDriverForDatabaseEngine(env.engine); + let initClientFunc; + switch (driver) { + case TestDriver.MYSQL: + initClientFunc = (options: any) => new TestAwsMySQLClient(options); + break; + case TestDriver.PG: + initClientFunc = (options: any) => new TestAwsPGClient(options); + break; + default: + throw new Error("invalid driver"); + } - try { - await client.connect(); - await newClient.connect(); - const targetClient = client.targetClient; - const newTargetClient = newClient.targetClient; - - expect(targetClient).not.toEqual(newTargetClient); - if (driver === TestDriver.MYSQL) { - await DriverHelper.executeQuery(env.engine, client, "CREATE DATABASE IF NOT EXISTS testSessionState"); - await client.setReadOnly(true); - await client.setCatalog("testSessionState"); - await client.setTransactionIsolation(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); - await client.setAutoCommit(false); - - // Assert new client's session states are using server default values. 
- let readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_read_only AS readonly"); - let catalog = await DriverHelper.executeQuery(env.engine, newClient, "SELECT DATABASE() AS catalog"); - let autoCommit = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.autocommit AS autocommit"); - let transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_isolation AS level"); - expect(readOnly[0][0].readonly).toEqual(0); - expect(catalog[0][0].catalog).toEqual(env.databaseInfo.defaultDbName); - expect(autoCommit[0][0].autocommit).toEqual(1); - expect(transactionIsolation[0][0].level).toEqual("REPEATABLE-READ"); - - await client.getPluginService().setCurrentClient(newClient.targetClient); - - expect(client.targetClient).not.toEqual(targetClient); - expect(client.targetClient).toEqual(newTargetClient); - - // Assert new client's session states are set. - readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_read_only AS readonly"); - catalog = await DriverHelper.executeQuery(env.engine, newClient, "SELECT DATABASE() AS catalog"); - autoCommit = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.autocommit AS autocommit"); - transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_isolation AS level"); - expect(readOnly[0][0].readonly).toEqual(1); - expect(catalog[0][0].catalog).toEqual("testSessionState"); - expect(autoCommit[0][0].autocommit).toEqual(0); - expect(transactionIsolation[0][0].level).toEqual("SERIALIZABLE"); - - await client.setReadOnly(false); - await client.setAutoCommit(true); - await DriverHelper.executeQuery(env.engine, client, "DROP DATABASE IF EXISTS testSessionState"); - } else if (driver === TestDriver.PG) { - // End any current transaction before we can create a new test database. - await DriverHelper.executeQuery(env.engine, client, "END TRANSACTION"); - await DriverHelper.executeQuery(env.engine, client, "DROP DATABASE IF EXISTS testSessionState"); - await DriverHelper.executeQuery(env.engine, client, "CREATE DATABASE testSessionState"); - await client.setReadOnly(true); - await client.setSchema("testSessionState"); - await client.setTransactionIsolation(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); - - // Assert new client's session states are using server default values. - let readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SHOW transaction_read_only"); - let schema = await DriverHelper.executeQuery(env.engine, newClient, "SHOW search_path"); - let transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SHOW TRANSACTION ISOLATION LEVEL"); - expect(readOnly.rows[0]["transaction_read_only"]).toEqual("off"); - expect(schema.rows[0]["search_path"]).not.toEqual("testSessionState"); - expect(transactionIsolation.rows[0]["transaction_isolation"]).toEqual("read committed"); - - await client.getPluginService().setCurrentClient(newClient.targetClient); - expect(client.targetClient).not.toEqual(targetClient); - expect(client.targetClient).toEqual(newTargetClient); - - // Assert new client's session states are set. 
- readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SHOW transaction_read_only"); - schema = await DriverHelper.executeQuery(env.engine, newClient, "SHOW search_path"); - transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SHOW TRANSACTION ISOLATION LEVEL"); - expect(readOnly.rows[0]["transaction_read_only"]).toEqual("on"); - expect(schema.rows[0]["search_path"]).toEqual("testsessionstate"); - expect(transactionIsolation.rows[0]["transaction_isolation"]).toEqual("serializable"); - - await client.setReadOnly(false); - await DriverHelper.executeQuery(env.engine, client, "DROP DATABASE IF EXISTS testSessionState"); + let props = { + user: env.databaseInfo.username, + host: env.databaseInfo.clusterEndpoint, + database: env.databaseInfo.defaultDbName, + password: env.databaseInfo.password, + port: env.databaseInfo.clusterEndpointPort + }; + props = DriverHelper.addDriverSpecificConfiguration(props, env.engine); + client = initClientFunc(props); + + const newClient = initClientFunc(props); + + try { + await client.connect(); + await newClient.connect(); + const targetClient = client.targetClient; + const newTargetClient = newClient.targetClient; + + expect(targetClient).not.toEqual(newTargetClient); + if (driver === TestDriver.MYSQL) { + await DriverHelper.executeQuery(env.engine, client, "CREATE DATABASE IF NOT EXISTS testSessionState"); + await client.setReadOnly(true); + await client.setCatalog("testSessionState"); + await client.setTransactionIsolation(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); + await client.setAutoCommit(false); + + // Assert new client's session states are using server default values. + let readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_read_only AS readonly"); + let catalog = await DriverHelper.executeQuery(env.engine, newClient, "SELECT DATABASE() AS catalog"); + let autoCommit = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.autocommit AS autocommit"); + let transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_isolation AS level"); + expect(readOnly[0][0].readonly).toEqual(0); + expect(catalog[0][0].catalog).toEqual(env.databaseInfo.defaultDbName); + expect(autoCommit[0][0].autocommit).toEqual(1); + expect(transactionIsolation[0][0].level).toEqual("REPEATABLE-READ"); + + await client.getPluginService().setCurrentClient(newClient.targetClient); + + expect(client.targetClient).not.toEqual(targetClient); + expect(client.targetClient).toEqual(newTargetClient); + + // Assert new client's session states are set. 
+ readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_read_only AS readonly"); + catalog = await DriverHelper.executeQuery(env.engine, newClient, "SELECT DATABASE() AS catalog"); + autoCommit = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.autocommit AS autocommit"); + transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SELECT @@SESSION.transaction_isolation AS level"); + expect(readOnly[0][0].readonly).toEqual(1); + expect(catalog[0][0].catalog).toEqual("testSessionState"); + expect(autoCommit[0][0].autocommit).toEqual(0); + expect(transactionIsolation[0][0].level).toEqual("SERIALIZABLE"); + + await client.setReadOnly(false); + await client.setAutoCommit(true); + await DriverHelper.executeQuery(env.engine, client, "DROP DATABASE IF EXISTS testSessionState"); + } else if (driver === TestDriver.PG) { + // End any current transaction before we can create a new test database. + await DriverHelper.executeQuery(env.engine, client, "END TRANSACTION"); + await DriverHelper.executeQuery(env.engine, client, "DROP DATABASE IF EXISTS testSessionState"); + await DriverHelper.executeQuery(env.engine, client, "CREATE DATABASE testSessionState"); + await client.setReadOnly(true); + await client.setSchema("testSessionState"); + await client.setTransactionIsolation(TransactionIsolationLevel.TRANSACTION_SERIALIZABLE); + + // Assert new client's session states are using server default values. + let readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SHOW transaction_read_only"); + let schema = await DriverHelper.executeQuery(env.engine, newClient, "SHOW search_path"); + let transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SHOW TRANSACTION ISOLATION LEVEL"); + expect(readOnly.rows[0]["transaction_read_only"]).toEqual("off"); + expect(schema.rows[0]["search_path"]).not.toEqual("testSessionState"); + expect(transactionIsolation.rows[0]["transaction_isolation"]).toEqual("read committed"); + + await client.getPluginService().setCurrentClient(newClient.targetClient); + expect(client.targetClient).not.toEqual(targetClient); + expect(client.targetClient).toEqual(newTargetClient); + + // Assert new client's session states are set. 
+ readOnly = await DriverHelper.executeQuery(env.engine, newClient, "SHOW transaction_read_only"); + schema = await DriverHelper.executeQuery(env.engine, newClient, "SHOW search_path"); + transactionIsolation = await DriverHelper.executeQuery(env.engine, newClient, "SHOW TRANSACTION ISOLATION LEVEL"); + expect(readOnly.rows[0]["transaction_read_only"]).toEqual("on"); + expect(schema.rows[0]["search_path"]).toEqual("testsessionstate"); + expect(transactionIsolation.rows[0]["transaction_isolation"]).toEqual("serializable"); + + await client.setReadOnly(false); + await DriverHelper.executeQuery(env.engine, client, "DROP DATABASE IF EXISTS testSessionState"); + } + } catch (e) { + await client.end(); + await newClient.end(); + throw e; } - } catch (e) { - await client.end(); - await newClient.end(); - throw e; - } - }, 1320000); + }, + 1320000 + ); }); diff --git a/tests/integration/container/tests/utils/aurora_test_utility.ts b/tests/integration/container/tests/utils/aurora_test_utility.ts index b4cd0eac..c2101377 100644 --- a/tests/integration/container/tests/utils/aurora_test_utility.ts +++ b/tests/integration/container/tests/utils/aurora_test_utility.ts @@ -16,7 +16,6 @@ import { CreateDBInstanceCommand, - CreateDBInstanceCommandOutput, DBInstanceAlreadyExistsFault, DBInstanceNotFoundFault, DeleteDBInstanceCommand, @@ -135,7 +134,7 @@ export class AuroraTestUtility { return clusters[0]; } - async failoverClusterAndWaitUntilWriterChanged(initialWriter?: string, clusterId?: string) { + async failoverClusterAndWaitUntilWriterChanged(initialWriter?: string, clusterId?: string, targetWriterId?: string) { if (this.isNullOrUndefined(clusterId)) { clusterId = (await TestEnvironment.getCurrent()).info.auroraClusterName; } @@ -148,7 +147,7 @@ export class AuroraTestUtility { const clusterEndpoint = databaseInfo.clusterEndpoint; const initialClusterAddress = await dns.promises.lookup(clusterEndpoint); - await this.failoverCluster(clusterId); + await this.failoverClusterToTarget(clusterId, targetWriterId); let remainingAttempts: number = 5; while (!(await this.writerChanged(initialWriter, clusterId, 300))) { @@ -157,7 +156,7 @@ export class AuroraTestUtility { throw new Error("failover request unsuccessful"); } - await this.failoverCluster(clusterId); + await this.failoverClusterToTarget(clusterId, targetWriterId); } let clusterAddress: dns.LookupAddress = await dns.promises.lookup(clusterEndpoint); @@ -167,7 +166,7 @@ export class AuroraTestUtility { } } - async failoverCluster(clusterId?: string) { + async failoverClusterToTarget(clusterId?: string, targetInstanceId?: string): Promise { const info = (await TestEnvironment.getCurrent()).info; if (clusterId == null) { clusterId = info.auroraClusterName; @@ -175,10 +174,15 @@ export class AuroraTestUtility { await this.waitUntilClusterHasDesiredStatus(clusterId); - let remainingAttempts = 10; - const command = new FailoverDBClusterCommand({ + const input: any = { DBClusterIdentifier: clusterId - }); + }; + if (targetInstanceId) { + input.TargetDBInstanceIdentifier = targetInstanceId; + } + + let remainingAttempts = 10; + const command = new FailoverDBClusterCommand(input); const auroraUtility = new AuroraTestUtility(info.region); while (remainingAttempts-- > 0) { try { diff --git a/tests/integration/container/tests/utils/test_environment.ts b/tests/integration/container/tests/utils/test_environment.ts index a0c4243f..7d9d2718 100644 --- a/tests/integration/container/tests/utils/test_environment.ts +++ 
b/tests/integration/container/tests/utils/test_environment.ts @@ -388,6 +388,10 @@ export class TestEnvironment { return this.info.request.deployment; } + get auroraClusterName(): string { + return this.info.auroraClusterName; + } + private static createProxyUrl(host: string, port: number) { return `http://${host}:${port}`; } diff --git a/tests/unit/aurora_connection_tracker.test.ts b/tests/unit/aurora_connection_tracker.test.ts index a14d6cf1..fbbd83d1 100644 --- a/tests/unit/aurora_connection_tracker.test.ts +++ b/tests/unit/aurora_connection_tracker.test.ts @@ -29,6 +29,7 @@ import { ClientWrapper } from "../../common/lib/client_wrapper"; import { HostInfo } from "../../common/lib/host_info"; import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; import { jest } from "@jest/globals"; +import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; const props = new Map(); const SQL_ARGS = ["sql"]; @@ -49,7 +50,7 @@ const mockClient = mock(AwsClient); const mockHostInfo = mock(HostInfo); const mockClientInstance = instance(mockClient); -const mockClientWrapper: ClientWrapper = new MySQLClientWrapper(undefined, mockHostInfo, props); +const mockClientWrapper: ClientWrapper = new MySQLClientWrapper(undefined, mockHostInfo, props, new MySQL2DriverDialect()); mockClientInstance.targetClient = mockClientWrapper; @@ -90,7 +91,7 @@ describe("aurora connection tracker tests", () => { .withRole(HostRole.WRITER) .build(); new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }).withHost("new-host").withRole(HostRole.WRITER).build(); - when(mockPluginService.getHosts()).thenReturn([originalHost]); + when(mockPluginService.getAllHosts()).thenReturn([originalHost]); const plugin = new AuroraConnectionTrackerPlugin(instance(mockPluginService), instance(mockRdsUtils), instance(mockTracker)); @@ -103,7 +104,7 @@ describe("aurora connection tracker tests", () => { const originalHost = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }).withHost("host").build(); const failoverTargetHost = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() }).withHost("host2").build(); - when(mockPluginService.getHosts()).thenReturn([originalHost]).thenReturn([failoverTargetHost]); + when(mockPluginService.getAllHosts()).thenReturn([originalHost]).thenReturn([failoverTargetHost]); mockSqlFunc.mockResolvedValueOnce("1").mockRejectedValueOnce(expectedError); const plugin = new AuroraConnectionTrackerPlugin(instance(mockPluginService), instance(mockRdsUtils), instance(mockTracker)); diff --git a/tests/unit/aurora_initial_connection_strategy_plugin.test.ts b/tests/unit/aurora_initial_connection_strategy_plugin.test.ts index bc9457c2..f9dc43b7 100644 --- a/tests/unit/aurora_initial_connection_strategy_plugin.test.ts +++ b/tests/unit/aurora_initial_connection_strategy_plugin.test.ts @@ -30,6 +30,7 @@ import { AwsWrapperError } from "../../common/lib/utils/errors"; import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper"; import { jest } from "@jest/globals"; import { PgClientWrapper } from "../../common/lib/pg_client_wrapper"; +import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect"; const mockPluginService = mock(PluginService); const mockHostListProviderService = mock(); @@ -62,8 +63,8 @@ describe("Aurora initial connection strategy plugin", () => { plugin.initHostProvider(hostInfo, props, instance(mockHostListProviderService), mockFunc); 
WrapperProperties.OPEN_CONNECTION_RETRY_TIMEOUT_MS.set(props, 1000); - writerClient = new MySQLClientWrapper(undefined, writerHostInfo, new Map()); - readerClient = new MySQLClientWrapper(undefined, readerHostInfo, new Map()); + writerClient = new MySQLClientWrapper(undefined, writerHostInfo, new Map(), new MySQL2DriverDialect()); + readerClient = new MySQLClientWrapper(undefined, readerHostInfo, new Map(), new MySQL2DriverDialect()); }); afterEach(() => { @@ -100,13 +101,13 @@ describe("Aurora initial connection strategy plugin", () => { it("test writer - not found", async () => { when(mockRdsUtils.identifyRdsType(anything())).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); - when(mockPluginService.getHosts()).thenReturn([hostInfoBuilder.withRole(HostRole.READER).build()]); + when(mockPluginService.getAllHosts()).thenReturn([hostInfoBuilder.withRole(HostRole.READER).build()]); expect(await plugin.connect(hostInfo, props, true, mockFuncUndefined)).toBe(undefined); }); it("test writer - resolves to reader", async () => { when(mockRdsUtils.identifyRdsType(anything())).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); - when(mockPluginService.getHosts()).thenReturn([hostInfoBuilder.withRole(HostRole.WRITER).build()]); + when(mockPluginService.getAllHosts()).thenReturn([hostInfoBuilder.withRole(HostRole.WRITER).build()]); when(mockPluginService.connect(anything(), anything())).thenResolve(instance(readerClient)); expect(await plugin.connect(hostInfo, props, true, mockFunc)).toBe(undefined); @@ -114,7 +115,7 @@ describe("Aurora initial connection strategy plugin", () => { it("test writer - resolve to writer", async () => { when(mockRdsUtils.identifyRdsType(anything())).thenReturn(RdsUrlType.RDS_WRITER_CLUSTER); - when(mockPluginService.getHosts()).thenReturn([hostInfoBuilder.withRole(HostRole.WRITER).build()]); + when(mockPluginService.getAllHosts()).thenReturn([hostInfoBuilder.withRole(HostRole.WRITER).build()]); when(mockPluginService.getHostRole(writerClient)).thenReturn(Promise.resolve(HostRole.WRITER)); when(mockPluginService.connect(anything(), anything())).thenResolve(writerClient); @@ -152,7 +153,7 @@ describe("Aurora initial connection strategy plugin", () => { it("test reader - return writer", async () => { when(mockRdsUtils.identifyRdsType(anything())).thenReturn(RdsUrlType.RDS_READER_CLUSTER); - when(mockPluginService.getHosts()) + when(mockPluginService.getAllHosts()) .thenReturn([hostInfoBuilder.withRole(HostRole.READER).build()]) .thenReturn([hostInfoBuilder.withRole(HostRole.WRITER).build()]); when(mockPluginService.connect(anything(), anything())).thenResolve(writerClient); diff --git a/tests/unit/custom_endpoint_monitor_impl.test.ts b/tests/unit/custom_endpoint_monitor_impl.test.ts new file mode 100644 index 00000000..2aaa6203 --- /dev/null +++ b/tests/unit/custom_endpoint_monitor_impl.test.ts @@ -0,0 +1,149 @@ +/* + Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"). + You may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. 
+*/ + +import { anything, capture, instance, mock, spy, verify, when } from "ts-mockito"; +import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; +import { RDSClient } from "@aws-sdk/client-rds"; +import { PluginService } from "../../common/lib/plugin_service"; +import { CustomEndpointMonitorImpl } from "../../common/lib/plugins/custom_endpoint/custom_endpoint_monitor_impl"; +import { HostInfoBuilder } from "../../common/lib/host_info_builder"; +import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availability/simple_host_availability_strategy"; +import { sleep } from "../../common/lib/utils/utils"; +import { CustomEndpointInfo } from "../../common/lib/plugins/custom_endpoint/custom_endpoint_info"; +import { CustomEndpointRoleType } from "../../common/lib/plugins/custom_endpoint/custom_endpoint_role_type"; +import { MemberListType } from "../../common/lib/plugins/custom_endpoint/member_list_type"; + +const customEndpointUrl1 = "custom1.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"; +const customEndpointUrl2 = "custom2.cluster-custom-XYZ.us-east-1.rds.amazonaws.com"; +const endpointId = "custom1"; +const clusterId = "cluster1"; +const staticMembersSet = new Set(["member1", "member2"]); + +const rdsSendResult1 = { + $metadata: { + httpStatusCode: 200, + requestId: "", + extendedRequestId: undefined, + cfId: undefined, + attempts: 1, + totalRetryDelay: 0 + }, + DBClusterEndpoints: [ + { + DBClusterEndpointIdentifier: endpointId, + DBClusterIdentifier: clusterId, + Endpoint: customEndpointUrl1, + EndpointType: "CUSTOM", + CustomEndpointType: "ANY", + StaticMembers: staticMembersSet, + ExcludedMembers: [], + DBClusterEndpointArn: "", + DBClusterEndpointResourceIdentifier: "", + Status: "available" + } + ] +}; +const rdsSendResult2 = { + $metadata: { + httpStatusCode: 200, + requestId: "", + extendedRequestId: undefined, + cfId: undefined, + attempts: 1, + totalRetryDelay: 0 + }, + DBClusterEndpoints: [ + { + DBClusterEndpointIdentifier: endpointId, + DBClusterIdentifier: clusterId, + Endpoint: customEndpointUrl1, + EndpointType: "CUSTOM", + CustomEndpointType: "ANY", + StaticMembers: staticMembersSet, + ExcludedMembers: [], + DBClusterEndpointArn: "", + DBClusterEndpointResourceIdentifier: "", + Status: "available" + }, + { + DBClusterEndpointIdentifier: endpointId, + DBClusterIdentifier: clusterId, + Endpoint: customEndpointUrl2, + EndpointType: "CUSTOM", + CustomEndpointType: "ANY", + StaticMembers: staticMembersSet, + ExcludedMembers: [], + DBClusterEndpointArn: "", + DBClusterEndpointResourceIdentifier: "", + Status: "available" + } + ] +}; + +const mockRdsClient = mock(RDSClient); +when(mockRdsClient.send(anything())).thenResolve(rdsSendResult2).thenResolve(rdsSendResult1); +const mockRdsClientFunc = () => instance(mockRdsClient); +const mockPluginService = mock(PluginService); +when(mockPluginService.getTelemetryFactory()).thenReturn(new NullTelemetryFactory()); + +const props = new Map(); +const host = new HostInfoBuilder({ + host: "custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com", + port: 1234, + hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() +}).build(); + +const expectedInfo = new CustomEndpointInfo( + endpointId, + clusterId, + customEndpointUrl1, + CustomEndpointRoleType.ANY, + staticMembersSet, + MemberListType.STATIC_LIST +); + +class TestCustomEndpointMonitorImpl extends CustomEndpointMonitorImpl { + static getCache() { + return TestCustomEndpointMonitorImpl.customEndpointInfoCache; + } + + 
getStop() {
+    return this.stop;
+  }
+}
+
+describe("testCustomEndpoint", () => {
+  beforeEach(() => {
+    props.clear();
+  });
+
+  it("testRun", async () => {
+    const monitor = new TestCustomEndpointMonitorImpl(instance(mockPluginService), host, endpointId, "us-east-1", 50, mockRdsClientFunc);
+
+    // Wait for 2 run cycles. The first will return an unexpected number of endpoints in the API response, the second
+    // will return the expected number of endpoints (one).
+    await sleep(100);
+    expect(TestCustomEndpointMonitorImpl.getCache().get(host.host)).toStrictEqual(expectedInfo);
+    await monitor.close();
+
+    const captureResult = capture(mockPluginService.setAllowedAndBlockedHosts).last();
+    expect(captureResult[0].getAllowedHostIds()).toStrictEqual(staticMembersSet);
+    expect(captureResult[0].getBlockedHostIds()).toBeNull();
+
+    expect(monitor.getStop()).toBe(true);
+    verify(mockRdsClient.destroy()).once();
+  }, 100000);
+});
diff --git a/tests/unit/custom_endpoint_plugin.test.ts b/tests/unit/custom_endpoint_plugin.test.ts
new file mode 100644
index 00000000..54f2fbb2
--- /dev/null
+++ b/tests/unit/custom_endpoint_plugin.test.ts
@@ -0,0 +1,127 @@
+/*
+  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+
+  Licensed under the Apache License, Version 2.0 (the "License").
+  You may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+*/
+
+import { WrapperProperties } from "../../common/lib/wrapper_property";
+import { anything, instance, mock, spy, verify, when } from "ts-mockito";
+import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory";
+import { RDSClient } from "@aws-sdk/client-rds";
+import { CustomEndpointPlugin } from "../../common/lib/plugins/custom_endpoint/custom_endpoint_plugin";
+import { PluginService } from "../../common/lib/plugin_service";
+import { CustomEndpointMonitorImpl } from "../../common/lib/plugins/custom_endpoint/custom_endpoint_monitor_impl";
+import { HostInfoBuilder } from "../../common/lib/host_info_builder";
+import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availability/simple_host_availability_strategy";
+import { AwsWrapperError } from "../../common/lib/utils/errors";
+
+const mockRdsClientFunc = () => instance(mock(RDSClient));
+const mockPluginService = mock(PluginService);
+when(mockPluginService.getTelemetryFactory()).thenReturn(new NullTelemetryFactory());
+const mockMonitor = mock(CustomEndpointMonitorImpl);
+
+const props = new Map();
+const writerClusterHost = new HostInfoBuilder({
+  host: "writer.cluster-XYZ.us-east-1.rds.amazonaws.com",
+  port: 1234,
+  hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy()
+}).build();
+const host = new HostInfoBuilder({
+  host: "custom.cluster-custom-XYZ.us-east-1.rds.amazonaws.com",
+  port: 1234,
+  hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy()
+}).build();
+
+let connectCounter = 0;
+function mockConnectFunc(): Promise<any> {
+  connectCounter++;
+  return Promise.resolve();
+}
+let executeCounter = 0;
+function mockExecuteFunc(): Promise<any> {
+  executeCounter++;
+  return Promise.resolve();
+}
+
+function getPlugins() {
const plugin = new CustomEndpointPlugin(instance(mockPluginService), props, mockRdsClientFunc); + const spyPlugin = spy(plugin); + when(spyPlugin.createMonitorIfAbsent(anything())).thenReturn(instance(mockMonitor)); + return [plugin, spyPlugin]; +} + +class TestCustomEndpointPlugin extends CustomEndpointPlugin { + static getMonitors() { + return TestCustomEndpointPlugin.monitors; + } +} + +describe("testCustomEndpoint", () => { + beforeEach(() => { + connectCounter = 0; + executeCounter = 0; + props.clear(); + }); + + it("testConnect_monitorNotCreatedIfNotCustomEndpointHost", async () => { + const [plugin, spyPlugin] = getPlugins(); + await plugin.connect(writerClusterHost, props, true, mockConnectFunc); + + expect(connectCounter).toBe(1); + verify(spyPlugin.createMonitorIfAbsent(anything())).never(); + }); + + it("testConnect_monitorCreated", async () => { + when(mockMonitor.hasCustomEndpointInfo()).thenReturn(true); + const [plugin, spyPlugin] = getPlugins(); + + await plugin.connect(host, props, true, mockConnectFunc); + expect(connectCounter).toBe(1); + verify(spyPlugin.createMonitorIfAbsent(anything())).once(); + }); + + it("testConnect_timeoutWaitingForInfo", async () => { + props.set(WrapperProperties.WAIT_FOR_CUSTOM_ENDPOINT_INFO_TIMEOUT_MS.name, 1); + when(mockMonitor.hasCustomEndpointInfo()).thenReturn(false); + const [plugin, spyPlugin] = getPlugins(); + + await expect(plugin.connect(host, props, true, mockConnectFunc)).rejects.toThrow(AwsWrapperError); + expect(connectCounter).toBe(0); + verify(spyPlugin.createMonitorIfAbsent(anything())).once(); + }); + + it("testExecute_monitorNotCreatedIfNotCustomEndpointHost", async () => { + when(mockMonitor.hasCustomEndpointInfo()).thenReturn(false); + const [plugin, spyPlugin] = getPlugins(); + + await plugin.execute("execute", mockConnectFunc, []); + expect(connectCounter).toBe(1); + verify(spyPlugin.createMonitorIfAbsent(anything())).never(); + }); + + it("testExecute_monitorCreated", async () => { + when(mockMonitor.hasCustomEndpointInfo()).thenReturn(true); + const [plugin, spyPlugin] = getPlugins(); + + await plugin.connect(host, props, true, mockConnectFunc); + await plugin.execute("execute", mockExecuteFunc, []); + expect(executeCounter).toBe(1); + verify(spyPlugin.createMonitorIfAbsent(anything())).twice(); + }); + + it("testCloseMonitors", async () => { + TestCustomEndpointPlugin.getMonitors().computeIfAbsent("test-monitor", () => instance(mockMonitor), BigInt(30_000_000_000)); + TestCustomEndpointPlugin.closeMonitors(); + verify(mockMonitor.close()).atLeast(1); + }); +}); diff --git a/tests/unit/failover2_plugin.test.ts b/tests/unit/failover2_plugin.test.ts index 1ffa78ad..bfc8882b 100644 --- a/tests/unit/failover2_plugin.test.ts +++ b/tests/unit/failover2_plugin.test.ts @@ -133,6 +133,7 @@ describe("reader failover handler", () => { when(mockHostInfo.allAliases).thenReturn(new Set(["alias1", "alias2"])); when(mockHostInfo.getRawAvailability()).thenReturn(HostAvailability.AVAILABLE); when(mockPluginService.getHosts()).thenReturn(hosts); + when(mockPluginService.getAllHosts()).thenReturn(hosts); when(mockPluginService.forceMonitoringRefresh(true, anything())).thenResolve(true); when(mockPluginService.connect(mockHostInfo, anything())).thenReject(test); @@ -182,6 +183,7 @@ describe("reader failover handler", () => { when(mockHostInfo.allAliases).thenReturn(new Set(["alias1", "alias2"])); when(mockPluginService.getHosts()).thenReturn(hosts); + when(mockPluginService.getAllHosts()).thenReturn(hosts); 
 when(mockPluginService.forceMonitoringRefresh(true, anything())).thenResolve(true);
 when(mockPluginService.connect(mockHostInfo, anything())).thenResolve(null);
@@ -210,6 +212,7 @@ describe("reader failover handler", () => {
     when(mockHostInfo.allAliases).thenReturn(new Set(["alias1", "alias2"]));
     when(mockPluginService.getHosts()).thenReturn(hosts);
+    when(mockPluginService.getAllHosts()).thenReturn(hosts);
     when(mockPluginService.forceMonitoringRefresh(true, anything())).thenResolve(true);
     when(mockPluginService.connect(hostInfo, anything())).thenResolve(mockClientWrapper);
     when(mockPluginService.getHostRole(mockClientWrapper)).thenResolve(HostRole.WRITER);
diff --git a/tests/unit/failover_plugin.test.ts b/tests/unit/failover_plugin.test.ts
index 4ab1d414..32a57b81 100644
--- a/tests/unit/failover_plugin.test.ts
+++ b/tests/unit/failover_plugin.test.ts
@@ -38,6 +38,7 @@ import { Messages } from "../../common/lib/utils/messages";
 import { HostChangeOptions } from "../../common/lib/host_change_options";
 import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory";
 import { MySQLClientWrapper } from "../../common/lib/mysql_client_wrapper";
+import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect";
 const builder = new HostInfoBuilder({ hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy() });
@@ -53,7 +54,7 @@ const mockWriterFailoverHandler: ClusterAwareWriterFailoverHandler = mock(Cluste
 const mockReaderResult: ReaderFailoverResult = mock(ReaderFailoverResult);
 const mockWriterResult: WriterFailoverResult = mock(WriterFailoverResult);
-const mockClientWrapper = new MySQLClientWrapper(undefined, mockHostInfo, new Map());
+const mockClientWrapper = new MySQLClientWrapper(undefined, mockHostInfo, new Map(), new MySQL2DriverDialect());
 const properties: Map<string, any> = new Map();
@@ -148,7 +149,7 @@ describe("reader failover handler", () => {
     verify(mockPluginService.refreshHostList()).never();
     // Test with hosts
-    when(mockPluginService.getHosts()).thenReturn([builder.withHost("host").build()]);
+    when(mockPluginService.getAllHosts()).thenReturn([builder.withHost("host").build()]);
     // Test updateTopology with forceUpdate == true
     await plugin.updateTopology(true);
@@ -255,7 +256,7 @@ describe("reader failover handler", () => {
     const test = new AwsWrapperError("test");
     when(mockHostInfo.allAliases).thenReturn(new Set(["alias1", "alias2"]));
-    when(mockPluginService.getHosts()).thenReturn(hosts);
+    when(mockPluginService.getAllHosts()).thenReturn(hosts);
     when(mockWriterResult.error).thenReturn(test);
     when(mockWriterFailoverHandler.failover(anything())).thenResolve(instance(mockWriterResult));
@@ -276,7 +277,7 @@ describe("reader failover handler", () => {
     const hosts = [hostInfo];
     when(mockHostInfo.allAliases).thenReturn(new Set(["alias1", "alias2"]));
-    when(mockPluginService.getHosts()).thenReturn(hosts);
+    when(mockPluginService.getAllHosts()).thenReturn(hosts);
     when(mockWriterResult.isConnected).thenReturn(false);
     when(mockWriterFailoverHandler.failover(anything())).thenResolve(instance(mockWriterResult));
@@ -306,7 +307,7 @@ describe("reader failover handler", () => {
     const hosts = [hostInfo];
     when(mockHostInfo.allAliases).thenReturn(new Set(["alias1", "alias2"]));
-    when(mockPluginService.getHosts()).thenReturn(hosts);
+    when(mockPluginService.getAllHosts()).thenReturn(hosts);
     when(mockWriterResult.isConnected).thenReturn(false);
     when(mockWriterResult.topology).thenReturn(hosts);
 when(mockWriterFailoverHandler.failover(anything())).thenResolve(instance(mockWriterResult));
diff --git a/tests/unit/plugin_service.test.ts b/tests/unit/plugin_service.test.ts
new file mode 100644
index 00000000..6b878cc4
--- /dev/null
+++ b/tests/unit/plugin_service.test.ts
@@ -0,0 +1,111 @@
+/*
+  Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
+
+  Licensed under the Apache License, Version 2.0 (the "License").
+  You may not use this file except in compliance with the License.
+  You may obtain a copy of the License at
+
+  http://www.apache.org/licenses/LICENSE-2.0
+
+  Unless required by applicable law or agreed to in writing, software
+  distributed under the License is distributed on an "AS IS" BASIS,
+  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+  See the License for the specific language governing permissions and
+  limitations under the License.
+*/
+
+import { PluginService } from "../../common/lib/plugin_service";
+import { PluginServiceManagerContainer } from "../../common/lib/plugin_service_manager_container";
+import { mock } from "ts-mockito";
+import { AwsClient } from "../../common/lib/aws_client";
+import { DatabaseDialect, DatabaseType } from "../../common/lib/database_dialect/database_dialect";
+import { DatabaseDialectCodes } from "../../common/lib/database_dialect/database_dialect_codes";
+import { MySQLDatabaseDialect } from "../../mysql/lib/dialect/mysql_database_dialect";
+import { RdsMySQLDatabaseDialect } from "../../mysql/lib/dialect/rds_mysql_database_dialect";
+import { AuroraMySQLDatabaseDialect } from "../../mysql/lib/dialect/aurora_mysql_database_dialect";
+import { RdsMultiAZMySQLDatabaseDialect } from "../../mysql/lib/dialect/rds_multi_az_mysql_database_dialect";
+import { MySQL2DriverDialect } from "../../mysql/lib/dialect/mysql2_driver_dialect";
+import { AllowedAndBlockedHosts } from "../../common/lib/AllowedAndBlockedHosts";
+import { HostInfoBuilder } from "../../common/lib/host_info_builder";
+import { SimpleHostAvailabilityStrategy } from "../../common/lib/host_availability/simple_host_availability_strategy";
+import { HostInfo } from "../../common/lib/host_info";
+
+function createHost(host: string) {
+  return new HostInfoBuilder({
+    host: host,
+    hostId: host,
+    port: 1234,
+    hostAvailabilityStrategy: new SimpleHostAvailabilityStrategy()
+  }).build();
+}
+
+const host1 = createHost("host-1");
+const host2 = createHost("host-2");
+const host3 = createHost("host-3");
+const host4 = createHost("host-4");
+const allHosts = [host1, host2, host3, host4];
+
+class TestPluginService extends PluginService {
+  setHosts(hosts: HostInfo[]) {
+    this.hosts = hosts;
+  }
+}
+
+const knownDialectsByCode: Map<string, DatabaseDialect> = new Map([
+  [DatabaseDialectCodes.MYSQL, new MySQLDatabaseDialect()],
+  [DatabaseDialectCodes.RDS_MYSQL, new RdsMySQLDatabaseDialect()],
+  [DatabaseDialectCodes.AURORA_MYSQL, new AuroraMySQLDatabaseDialect()],
+  [DatabaseDialectCodes.RDS_MULTI_AZ_MYSQL, new RdsMultiAZMySQLDatabaseDialect()]
+]);
+
+const mockAwsClient: AwsClient = mock(AwsClient);
+let pluginService: TestPluginService;
+
+describe("testCustomEndpoint", () => {
+  beforeEach(() => {
+    pluginService = new TestPluginService(
+      new PluginServiceManagerContainer(),
+      mockAwsClient,
+      DatabaseType.MYSQL,
+      knownDialectsByCode,
+      new Map(),
+      new MySQL2DriverDialect()
+    );
+  });
+
+  it("test get hosts - blocked hosts empty", async () => {
+    pluginService.setHosts(allHosts);
+    const allowedHosts = new Set(["host-1", "host-2"]);
+    const blockedHosts = new Set<string>();
+    const allowedAndBlockedHosts = new AllowedAndBlockedHosts(allowedHosts, blockedHosts);
+    pluginService.setAllowedAndBlockedHosts(allowedAndBlockedHosts);
+    const hosts = pluginService.getHosts();
+    expect(hosts.length).toBe(2);
+    expect(hosts.includes(host1)).toBeTruthy();
+    expect(hosts.includes(host2)).toBeTruthy();
+  });
+
+  it("test get hosts - allowed hosts empty", async () => {
+    pluginService.setHosts(allHosts);
+    const allowedHosts = new Set<string>();
+    const blockedHosts = new Set(["host-1", "host-2"]);
+    const allowedAndBlockedHosts = new AllowedAndBlockedHosts(allowedHosts, blockedHosts);
+    pluginService.setAllowedAndBlockedHosts(allowedAndBlockedHosts);
+    const hosts = pluginService.getHosts();
+    expect(hosts.length).toBe(2);
+    expect(hosts.includes(host3)).toBeTruthy();
+    expect(hosts.includes(host4)).toBeTruthy();
+  });
+
+  it("test get hosts - allowed and blocked hosts not empty", async () => {
+    pluginService.setHosts(allHosts);
+    const allowedHosts = new Set(["host-1", "host-2"]);
+    const blockedHosts = new Set(["host-3", "host-4"]);
+    const allowedAndBlockedHosts = new AllowedAndBlockedHosts(allowedHosts, blockedHosts);
+    pluginService.setAllowedAndBlockedHosts(allowedAndBlockedHosts);
+    const hosts = pluginService.getHosts();
+    expect(hosts.length).toBe(2);
+    expect(hosts.includes(host1)).toBeTruthy();
+    expect(hosts.includes(host2)).toBeTruthy();
+  });
+});
diff --git a/tests/unit/read_write_splitting.test.ts b/tests/unit/read_write_splitting.test.ts
index 2e86031d..27ffa2b1 100644
--- a/tests/unit/read_write_splitting.test.ts
+++ b/tests/unit/read_write_splitting.test.ts
@@ -83,6 +83,7 @@ const mockExecuteFuncThrowsFailoverSuccessError = jest.fn(() => {
 describe("reader write splitting test", () => {
   beforeEach(() => {
     when(mockPluginService.getHostListProvider()).thenReturn(instance(mockHostListProvider));
+    when(mockPluginService.getAllHosts()).thenReturn(defaultHosts);
     when(mockPluginService.getHosts()).thenReturn(defaultHosts);
     when(mockPluginService.isInTransaction()).thenReturn(false);
     when(mockPluginService.getDialect()).thenReturn(mockDialect);
@@ -105,7 +106,7 @@ describe("reader write splitting test", () => {
   it("test set read only true", async () => {
     const mockPluginServiceInstance = instance(mockPluginService);
-    when(mockPluginService.getHosts()).thenReturn(singleReaderTopology);
+    when(mockPluginService.getAllHosts()).thenReturn(singleReaderTopology);
     when(mockPluginService.getHostInfoByStrategy(anything(), anything())).thenReturn(readerHost1);
     when(mockPluginService.getCurrentClient()).thenReturn(instance(mockWriterClient));
     when(await mockWriterClient.isValid()).thenReturn(true);
@@ -129,7 +130,7 @@ describe("reader write splitting test", () => {
   it("test set read only false", async () => {
     const mockPluginServiceInstance = instance(mockPluginService);
-    when(mockPluginService.getHosts()).thenReturn(singleReaderTopology);
+    when(mockPluginService.getAllHosts()).thenReturn(singleReaderTopology);
     when(mockPluginService.getHostInfoByStrategy(anything(), anything())).thenReturn(writerHost);
     when(mockPluginService.getCurrentClient()).thenReturn(instance(mockReaderClient));
     when(await mockReaderClient.isValid()).thenReturn(true);
@@ -153,7 +154,7 @@ describe("reader write splitting test", () => {
     const mockPluginServiceInstance = instance(mockPluginService);
     const mockHostListProviderServiceInstance = instance(mockHostListProviderService);
-    when(mockPluginService.getHosts()).thenReturn(singleReaderTopology);
+
when(mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); when(mockPluginService.getHostInfoByStrategy(anything(), anything())).thenReturn(readerHost1); when(mockPluginService.getCurrentClient()).thenReturn(instance(mockReaderClient)); when(await mockReaderClient.isValid()).thenReturn(true); @@ -177,7 +178,7 @@ describe("reader write splitting test", () => { it("test set read only false already on reader", async () => { const mockPluginServiceInstance = instance(mockPluginService); const mockHostListProviderServiceInstance = instance(mockHostListProviderService); - when(mockPluginService.getHosts()).thenReturn(singleReaderTopology); + when(mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); when(mockPluginService.getHostInfoByStrategy(anything(), anything())).thenReturn(readerHost1); when(mockPluginService.getCurrentClient()).thenReturn(instance(mockWriterClient)); when(await mockWriterClient.isValid()).thenReturn(true); @@ -254,7 +255,7 @@ describe("reader write splitting test", () => { it("test set read only false writer connection failed", async () => { const mockPluginServiceInstance = instance(mockPluginService); - when(mockPluginService.getHosts()).thenReturn(singleReaderTopology); + when(mockPluginService.getAllHosts()).thenReturn(singleReaderTopology); when(mockPluginService.getHostInfoByStrategy(anything(), anything())).thenReturn(readerHost1); when(mockPluginService.getCurrentClient()).thenReturn(instance(mockReaderClient)); when(mockPluginService.getCurrentHostInfo()).thenReturn(readerHost1); diff --git a/tests/unit/stale_dns_helper.test.ts b/tests/unit/stale_dns_helper.test.ts index dc46656a..462bf968 100644 --- a/tests/unit/stale_dns_helper.test.ts +++ b/tests/unit/stale_dns_helper.test.ts @@ -113,6 +113,7 @@ describe("test_stale_dns_helper", () => { const mockHostListProviderServiceInstance = instance(mockHostListProviderService); when(mockPluginService.getHosts()).thenReturn(readerHostList); + when(mockPluginService.getAllHosts()).thenReturn(readerHostList); when(mockPluginService.getCurrentHostInfo()).thenReturn(readerA); @@ -140,6 +141,7 @@ describe("test_stale_dns_helper", () => { const mockHostListProviderServiceInstance = instance(mockHostListProviderService); when(mockPluginService.getHosts()).thenReturn(clusterHostList); + when(mockPluginService.getAllHosts()).thenReturn(clusterHostList); const lookupAddress = { address: "5.5.5.5", family: 0 }; when(target.lookupResult(anything())).thenResolve(lookupAddress); @@ -161,6 +163,7 @@ describe("test_stale_dns_helper", () => { const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); const targetInstance = instance(target); when(mockPluginService.getHosts()).thenReturn(instanceHostList); + when(mockPluginService.getAllHosts()).thenReturn(instanceHostList); const mockHostListProviderServiceInstance = instance(mockHostListProviderService); @@ -185,6 +188,7 @@ describe("test_stale_dns_helper", () => { const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); const targetInstance = instance(target); when(mockPluginService.getHosts()).thenReturn(readerHostList); + when(mockPluginService.getAllHosts()).thenReturn(readerHostList); const mockHostListProviderServiceInstance = instance(mockHostListProviderService); const firstCall = { address: "5.5.5.5", family: 0 }; @@ -209,6 +213,7 @@ describe("test_stale_dns_helper", () => { const target: StaleDnsHelper = spy(new StaleDnsHelper(instance(mockPluginService))); const targetInstance = instance(target); 
when(mockPluginService.getHosts()).thenReturn(instanceHostList); + when(mockPluginService.getAllHosts()).thenReturn(instanceHostList); const mockHostListProviderServiceInstance = instance(mockHostListProviderService); const firstCall = { address: "5.5.5.5", family: 0 }; @@ -234,6 +239,7 @@ describe("test_stale_dns_helper", () => { const targetInstance = instance(target); when(mockPluginService.getHosts()).thenReturn(clusterHostList); + when(mockPluginService.getAllHosts()).thenReturn(clusterHostList); const mockHostListProviderServiceInstance = instance(mockHostListProviderService); targetInstance["writerHostInfo"] = writerCluster; @@ -260,6 +266,7 @@ describe("test_stale_dns_helper", () => { const targetInstance = instance(target); when(mockPluginService.getHosts()).thenReturn(clusterHostList); + when(mockPluginService.getAllHosts()).thenReturn(clusterHostList); const mockHostListProviderServiceInstance = instance(mockHostListProviderService); targetInstance["writerHostInfo"] = writerCluster; when(mockHostListProviderService.getInitialConnectionHostInfo()).thenReturn(writerCluster); diff --git a/tests/unit/writer_failover_handler.test.ts b/tests/unit/writer_failover_handler.test.ts index 54bbe12e..740fb884 100644 --- a/tests/unit/writer_failover_handler.test.ts +++ b/tests/unit/writer_failover_handler.test.ts @@ -70,7 +70,7 @@ describe("writer failover handler", () => { when(mockPluginService.forceConnect(writer, anything())).thenResolve(mockClientWrapper); when(mockPluginService.forceConnect(readerA, anything())).thenThrow(new AwsWrapperError()); when(mockPluginService.forceConnect(readerB, anything())).thenThrow(new AwsWrapperError()); - when(mockPluginService.getHosts()).thenReturn(topology); + when(mockPluginService.getAllHosts()).thenReturn(topology); when(mockReaderFailover.getReaderConnection(anything())).thenThrow(new AwsWrapperError()); const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); @@ -90,7 +90,7 @@ describe("writer failover handler", () => { when(mockPluginService.forceConnect(writer, anything())).thenResolve(mockClientWrapper); when(mockPluginService.forceConnect(readerA, anything())).thenResolve(mockClientWrapperB); when(mockPluginService.forceConnect(readerB, anything())).thenThrow(new AwsWrapperError()); - when(mockPluginService.getHosts()).thenReturn(topology).thenReturn(newTopology); + when(mockPluginService.getAllHosts()).thenReturn(topology).thenReturn(newTopology); when(mockReaderFailover.getReaderConnection(anything())).thenCall(async () => { await new Promise((resolve, reject) => { timeoutId = setTimeout(resolve, 5000); @@ -122,7 +122,7 @@ describe("writer failover handler", () => { when(mockPluginService.getCurrentClient()).thenReturn(mockClientInstance); when(mockPluginService.forceConnect(readerA, anything())).thenResolve(mockClientWrapperB); when(mockPluginService.forceConnect(readerB, anything())).thenThrow(new AwsWrapperError()); - when(mockPluginService.getHosts()).thenReturn(topology); + when(mockPluginService.getAllHosts()).thenReturn(topology); when(mockReaderFailover.getReaderConnection(anything())).thenResolve(new ReaderFailoverResult(mockClientWrapper, readerA, true)); const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); @@ -150,7 +150,7 @@ describe("writer failover handler", () => { when(mockPluginService.forceConnect(readerA, anything())).thenResolve(mockClientWrapperB); 
when(mockPluginService.forceConnect(readerB, anything())).thenThrow(new AwsWrapperError()); when(mockPluginService.getCurrentClient()).thenReturn(mockClientInstance); - when(mockPluginService.getHosts()).thenReturn(newTopology); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); when(mockReaderFailover.getReaderConnection(anything())).thenResolve(new ReaderFailoverResult(mockClientWrapper, readerA, true)); const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); @@ -186,7 +186,7 @@ describe("writer failover handler", () => { const newTopology = [newWriterHost, writer, readerA, readerB]; when(mockPluginService.getCurrentClient()).thenReturn(mockClientInstance); - when(mockPluginService.getHosts()).thenReturn(newTopology); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); when(mockReaderFailover.getReaderConnection(anything())).thenResolve(new ReaderFailoverResult(mockClientWrapper, readerA, true)); const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); @@ -222,7 +222,7 @@ describe("writer failover handler", () => { }); when(mockPluginService.getCurrentClient()).thenReturn(mockClientInstance); - when(mockPluginService.getHosts()).thenReturn(newTopology); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); when(mockReaderFailover.getReaderConnection(anything())).thenResolve(new ReaderFailoverResult(mockClientWrapper, readerA, true)); const mockReaderFailoverInstance = instance(mockReaderFailover); const mockPluginServiceInstance = instance(mockPluginService); @@ -247,7 +247,7 @@ describe("writer failover handler", () => { when(mockPluginService.forceConnect(writer, anything())).thenThrow(error); when(mockPluginService.forceConnect(newWriterHost, anything())).thenThrow(error); when(mockPluginService.isNetworkError(error)).thenReturn(true); - when(mockPluginService.getHosts()).thenReturn(newTopology); + when(mockPluginService.getAllHosts()).thenReturn(newTopology); when(mockPluginService.getCurrentClient()).thenReturn(mockClientInstance); when(mockReaderFailover.getReaderConnection(anything())).thenResolve(new ReaderFailoverResult(mockClientWrapper, readerA, true)); const mockReaderFailoverInstance = instance(mockReaderFailover);