Skip to content

Commit

Permalink
feat: custom endpoint plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
crystall-bitquill committed Jan 14, 2025
1 parent aa04779 commit 67a7625
Show file tree
Hide file tree
Showing 41 changed files with 1,611 additions and 64 deletions.
33 changes: 33 additions & 0 deletions common/lib/AllowedAndBlockedHosts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

export class AllowedAndBlockedHosts {
private readonly allowedHostIds: Set<string>;
private readonly blockedHostIds: Set<string>;

constructor(allowedHostIds: Set<string>, blockedHostIds: Set<string>) {
this.allowedHostIds = allowedHostIds;
this.blockedHostIds = blockedHostIds;
}

getAllowedHostIds() {
return this.allowedHostIds;
}

getBlockedHostIds() {
return this.blockedHostIds;
}
}
7 changes: 3 additions & 4 deletions common/lib/authentication/iam_authentication_plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,23 @@
*/

import { PluginService } from "../plugin_service";
import { RdsUtils } from "../utils/rds_utils";
import { Messages } from "../utils/messages";
import { logger } from "../../logutils";
import { AwsWrapperError } from "../utils/errors";
import { HostInfo } from "../host_info";
import { AwsCredentialsManager } from "./aws_credentials_manager";
import { AbstractConnectionPlugin } from "../abstract_connection_plugin";
import { WrapperProperties } from "../wrapper_property";
import { WrapperProperties, WrapperProperty } from "../wrapper_property";
import { IamAuthUtils, TokenInfo } from "../utils/iam_auth_utils";
import { ClientWrapper } from "../client_wrapper";
import { RegionUtils } from "../utils/region_utils";

export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
private static readonly SUBSCRIBED_METHODS = new Set<string>(["connect", "forceConnect"]);
protected static readonly tokenCache = new Map<string, TokenInfo>();
private readonly telemetryFactory;
private readonly fetchTokenCounter;
private pluginService: PluginService;
rdsUtil: RdsUtils = new RdsUtils();

constructor(pluginService: PluginService) {
super();
Expand Down Expand Up @@ -75,7 +74,7 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
}

const host = IamAuthUtils.getIamHost(props, hostInfo);
const region: string = IamAuthUtils.getRdsRegion(host, this.rdsUtil, props);
const region: string = RegionUtils.getRegion(props, WrapperProperties.IAM_REGION.name, host);
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort);
const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
if (tokenExpirationSec < 0) {
Expand Down
2 changes: 2 additions & 0 deletions common/lib/connection_plugin_chain_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ import { DeveloperConnectionPluginFactory } from "./plugins/dev/developer_connec
import { ConnectionPluginFactory } from "./plugin_factory";
import { LimitlessConnectionPluginFactory } from "./plugins/limitless/limitless_connection_plugin_factory";
import { FastestResponseStrategyPluginFactory } from "./plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory";
import { CustomEndpointPluginFactory } from "./plugins/custom_endpoint/custom_endpoint_plugin_factory";
import { ConfigurationProfile } from "./profile/configuration_profile";

/*
Expand All @@ -53,6 +54,7 @@ export class ConnectionPluginChainBuilder {
static readonly WEIGHT_RELATIVE_TO_PRIOR_PLUGIN = -1;

static readonly PLUGIN_FACTORIES = new Map<string, PluginFactoryInfo>([
["customEndpoint", { factory: CustomEndpointPluginFactory, weight: 380 }],
["initialConnection", { factory: AuroraInitialConnectionStrategyFactory, weight: 390 }],
["auroraConnectionTracker", { factory: AuroraConnectionTrackerPluginFactory, weight: 400 }],
["staleDns", { factory: StaleDnsPluginFactory, weight: 500 }],
Expand Down
3 changes: 3 additions & 0 deletions common/lib/host_list_provider_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { DatabaseDialect } from "./database_dialect/database_dialect";
import { HostInfoBuilder } from "./host_info_builder";
import { ConnectionUrlParser } from "./utils/connection_url_parser";
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts";

export interface HostListProviderService {
getHostListProvider(): HostListProvider | null;
Expand Down Expand Up @@ -50,4 +51,6 @@ export interface HostListProviderService {
isClientValid(targetClient: any): Promise<boolean>;

getTelemetryFactory(): TelemetryFactory;

setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts): void;
}
76 changes: 70 additions & 6 deletions common/lib/plugin_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -43,10 +43,10 @@ import { DatabaseDialectCodes } from "./database_dialect/database_dialect_codes"
import { getWriter, logTopology } from "./utils/utils";
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
import { DriverDialect } from "./driver_dialect/driver_dialect";
import { ConfigurationProfile } from "./profile/configuration_profile";
import { SessionState } from "./session_state";
import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts";

export class PluginService implements ErrorHandler, HostListProviderService {
private static readonly DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO = 5 * 60_000_000_000; // 5 minutes
private readonly _currentClient: AwsClient;
private _currentHostInfo?: HostInfo;
private _hostListProvider?: HostListProvider;
Expand All @@ -61,6 +61,7 @@ export class PluginService implements ErrorHandler, HostListProviderService {
protected readonly sessionStateService: SessionStateService;
protected static readonly hostAvailabilityExpiringCache: CacheMap<string, HostAvailability> = new CacheMap<string, HostAvailability>();
readonly props: Map<string, any>;
private allowedAndBlockedHosts: AllowedAndBlockedHosts;

constructor(
container: PluginServiceManagerContainer,
Expand Down Expand Up @@ -116,17 +117,29 @@ export class PluginService implements ErrorHandler, HostListProviderService {
this._currentHostInfo = this._initialConnectionHostInfo;

if (!this._currentHostInfo) {
if (this.getHosts().length === 0) {
if (this.getAllHosts().length === 0) {
throw new AwsWrapperError(Messages.get("PluginService.hostListEmpty"));
}

const writerHost = getWriter(this.getHosts());
const writerHost = getWriter(this.getAllHosts());
if (!this.hosts.includes(writerHost)) {
throw new AwsWrapperError(
Messages.get(
"PluginServiceImpl.currentHostNotAllowed",
this._currentHostInfo ? "<null>" : this._currentHostInfo.host,
logTopology(this.hosts, "")
)
);
}

if (writerHost) {
this._currentHostInfo = writerHost;
} else {
this._currentHostInfo = this.getHosts()[0];
}
}

logger.debug(`Set current host to: ${this._currentHostInfo.host}`);
}

return this._currentHostInfo;
Expand Down Expand Up @@ -260,11 +273,58 @@ export class PluginService implements ErrorHandler, HostListProviderService {
}
}

getHosts(): HostInfo[] {
getAllHosts(): HostInfo[] {
return this.hosts;
}

setAvailability(hostAliases: Set<string>, availability: HostAvailability) {}
getHosts(): HostInfo[] {
const hostPermissions = this.allowedAndBlockedHosts;
if (!hostPermissions) {
return this.hosts;
}

let hosts = this.hosts;
const allowedHostIds = hostPermissions.getAllowedHostIds();
const blockedHostIds = hostPermissions.getBlockedHostIds();

if (allowedHostIds && allowedHostIds.size > 0) {
hosts = hosts.filter((host: HostInfo) => allowedHostIds.has(host.hostId));
}

if (blockedHostIds && blockedHostIds.size > 0) {
hosts = hosts.filter((host: HostInfo) => blockedHostIds.has(host.hostId));
}

return hosts;
}

setAvailability(hostAliases: Set<string>, availability: HostAvailability) {
if (hostAliases.size === 0) {
return;
}

const hostsToChange = [...new Set(this.getAllHosts().filter((host: HostInfo) => hostAliases.has(host.asAlias) || host.aliases))];

if (hostsToChange.length === 0) {
logger.debug(Messages.get("PluginServiceImpl.hostsChangelistEmpty"));
return;
}

const changes = new Map<string, Set<HostChangeOptions>>();
for (const host of hostsToChange) {
const currentAvailability = host.getAvailability();
PluginService.hostAvailabilityExpiringCache.put(host.url, availability, PluginService.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO);
if (currentAvailability !== availability) {
let hostChanges = new Set<HostChangeOptions>();
if (availability === HostAvailability.AVAILABLE) {
hostChanges = new Set([HostChangeOptions.WENT_UP, HostChangeOptions.HOST_CHANGED]);
} else {
hostChanges = new Set([HostChangeOptions.WENT_DOWN, HostChangeOptions.HOST_CHANGED]);
}
changes.set(host.url, hostChanges);
}
}
}

updateConfigWithProperties(props: Map<string, any>) {
this._currentClient.config = Object.fromEntries(props.entries());
Expand Down Expand Up @@ -501,4 +561,8 @@ export class PluginService implements ErrorHandler, HostListProviderService {
attachNoOpErrorListener(clientWrapper: ClientWrapper | undefined): void {
this.getDialect().getErrorHandler().attachNoOpErrorListener(clientWrapper);
}

setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts) {
this.allowedAndBlockedHosts = allowedAndBlockedHosts;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
}

private getWriter(): HostInfo | null {
return this.pluginService.getHosts().find((x) => x.role === HostRole.WRITER) ?? null;
return this.pluginService.getAllHosts().find((x) => x.role === HostRole.WRITER) ?? null;
}

private getReader(props: Map<string, any>): HostInfo | undefined {
Expand All @@ -278,6 +278,6 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
}

private hasNoReaders(): boolean {
return this.pluginService.getHosts().find((x) => x.role === HostRole.READER) !== undefined;
return this.pluginService.getAllHosts().find((x) => x.role === HostRole.READER) !== undefined;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl
}

private async checkWriterChanged(): Promise<void> {
const hostInfoAfterFailover = this.getWriter(this.pluginService.getHosts());
const hostInfoAfterFailover = this.getWriter(this.pluginService.getAllHosts());
if (this.currentWriter === null) {
this.currentWriter = hostInfoAfterFailover;
this.needUpdateCurrentWriter = false;
Expand All @@ -114,7 +114,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl

private rememberWriter(): void {
if (this.currentWriter === null || this.needUpdateCurrentWriter) {
this.currentWriter = this.getWriter(this.pluginService.getHosts());
this.currentWriter = this.getWriter(this.pluginService.getAllHosts());
this.needUpdateCurrentWriter = false;
}
}
Expand Down
106 changes: 106 additions & 0 deletions common/lib/plugins/custom_endpoint/custom_endpoint_info.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

import { CustomEndpointRoleType, customEndpointRoleTypeFromValue } from "./custom_endpoint_role_type";
import { MemberListType } from "./member_list_type";

export class CustomEndpointInfo {
private endpointIdentifier: string; // ID portion of the custom endpoint URL.
private readonly clusterIdentifier: string; // ID of the cluster that the custom endpoint belongs to.
private readonly url: string;
private readonly roleType: CustomEndpointRoleType;

// A given custom endpoint will either specify a static list or an exclusion list, as indicated by `memberListType`.
// If the list is a static list, 'members' specifies instances included in the custom endpoint, and new cluster
// instances will not be automatically added to the custom endpoint. If it is an exclusion list, 'members' specifies
// instances excluded by the custom endpoint, and new cluster instances will be added to the custom endpoint.
private readonly memberListType: MemberListType;
private readonly members: Set<string>;

constructor(
endpointIdentifier: string,
clusterIdentifier: string,
url: string,
roleType: CustomEndpointRoleType,
members: Set<string>,
memberListType: MemberListType
) {
this.endpointIdentifier = endpointIdentifier;
this.clusterIdentifier = clusterIdentifier;
this.url = url;
this.roleType = roleType;
this.members = members;
this.memberListType = memberListType;
}

getMemberListType(): MemberListType {
return this.memberListType;
}

static fromDbClusterEndpoint(responseEndpointInfo: any): CustomEndpointInfo {
let members: Set<string>;
let memberListType: MemberListType;

if (responseEndpointInfo.StaticMembers) {
members = responseEndpointInfo.StaticMembers;
memberListType = MemberListType.STATIC_LIST;
} else {
members = responseEndpointInfo.ExcludedMembers;
memberListType = MemberListType.EXCLUSION_LIST;
}

return new CustomEndpointInfo(
responseEndpointInfo.DBClusterEndpointIdentifier,
responseEndpointInfo.DBClusterIdentifier,
responseEndpointInfo.Endpoint,
customEndpointRoleTypeFromValue(responseEndpointInfo.CustomEndpointType),
members,
memberListType
);
}

getStaticMembers(): Set<string> {
return this.memberListType === MemberListType.STATIC_LIST ? this.members : null;
}

getExcludedMembers(): Set<string> {
return this.memberListType === MemberListType.EXCLUSION_LIST ? this.members : null;
}

equals(obj: any): boolean {
if (!obj) {
return false;
}

if (obj === this) {
return true;
}

const info = obj as CustomEndpointInfo;
return (
this.endpointIdentifier === info.endpointIdentifier &&
this.clusterIdentifier === info.clusterIdentifier &&
this.url === info.url &&
this.roleType === info.roleType &&
this.members === info.members &&
this.memberListType === info.memberListType
);
}

toString(): string {
return `CustomEndpointInfo[url=${this.url}, clusterIdentifier=${this.clusterIdentifier}, customEndpointType=${this.roleType}, memberListType=${this.memberListType}, members=${this.members}]`;
}
}
21 changes: 21 additions & 0 deletions common/lib/plugins/custom_endpoint/custom_endpoint_monitor.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

export interface CustomEndpointMonitor {
shouldDispose(): boolean;
hasCustomEndpointInfo(): boolean;
close(): Promise<void>;
}
Loading

0 comments on commit 67a7625

Please sign in to comment.