Skip to content

Commit

Permalink
feat: custom endpoint plugin
Browse files Browse the repository at this point in the history
  • Loading branch information
crystall-bitquill committed Feb 4, 2025
1 parent 98fcda0 commit 11c4155
Show file tree
Hide file tree
Showing 53 changed files with 2,014 additions and 203 deletions.
33 changes: 33 additions & 0 deletions common/lib/AllowedAndBlockedHosts.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
/*
Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License").
You may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/

export class AllowedAndBlockedHosts {
private readonly allowedHostIds: Set<string>;
private readonly blockedHostIds: Set<string>;

constructor(allowedHostIds: Set<string>, blockedHostIds: Set<string>) {
this.allowedHostIds = allowedHostIds;
this.blockedHostIds = blockedHostIds;
}

getAllowedHostIds() {
return this.allowedHostIds;
}

getBlockedHostIds() {
return this.blockedHostIds;
}
}
7 changes: 3 additions & 4 deletions common/lib/authentication/iam_authentication_plugin.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,24 +15,23 @@
*/

import { PluginService } from "../plugin_service";
import { RdsUtils } from "../utils/rds_utils";
import { Messages } from "../utils/messages";
import { logger } from "../../logutils";
import { AwsWrapperError } from "../utils/errors";
import { HostInfo } from "../host_info";
import { AwsCredentialsManager } from "./aws_credentials_manager";
import { AbstractConnectionPlugin } from "../abstract_connection_plugin";
import { WrapperProperties } from "../wrapper_property";
import { WrapperProperties, WrapperProperty } from "../wrapper_property";
import { IamAuthUtils, TokenInfo } from "../utils/iam_auth_utils";
import { ClientWrapper } from "../client_wrapper";
import { RegionUtils } from "../utils/region_utils";

export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
private static readonly SUBSCRIBED_METHODS = new Set<string>(["connect", "forceConnect"]);
protected static readonly tokenCache = new Map<string, TokenInfo>();
private readonly telemetryFactory;
private readonly fetchTokenCounter;
private pluginService: PluginService;
rdsUtil: RdsUtils = new RdsUtils();

constructor(pluginService: PluginService) {
super();
Expand Down Expand Up @@ -75,7 +74,7 @@ export class IamAuthenticationPlugin extends AbstractConnectionPlugin {
}

const host = IamAuthUtils.getIamHost(props, hostInfo);
const region: string = IamAuthUtils.getRdsRegion(host, this.rdsUtil, props);
const region: string = RegionUtils.getRegion(props.get(WrapperProperties.IAM_REGION.name), host);
const port = IamAuthUtils.getIamPort(props, hostInfo, this.pluginService.getCurrentClient().defaultPort);
const tokenExpirationSec = WrapperProperties.IAM_TOKEN_EXPIRATION.get(props);
if (tokenExpirationSec < 0) {
Expand Down
2 changes: 2 additions & 0 deletions common/lib/connection_plugin_chain_builder.ts
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ import { DeveloperConnectionPluginFactory } from "./plugins/dev/developer_connec
import { ConnectionPluginFactory } from "./plugin_factory";
import { LimitlessConnectionPluginFactory } from "./plugins/limitless/limitless_connection_plugin_factory";
import { FastestResponseStrategyPluginFactory } from "./plugins/strategy/fastest_response/fastest_respose_strategy_plugin_factory";
import { CustomEndpointPluginFactory } from "./plugins/custom_endpoint/custom_endpoint_plugin_factory";
import { ConfigurationProfile } from "./profile/configuration_profile";

/*
Expand All @@ -54,6 +55,7 @@ export class ConnectionPluginChainBuilder {
static readonly WEIGHT_RELATIVE_TO_PRIOR_PLUGIN = -1;

static readonly PLUGIN_FACTORIES = new Map<string, PluginFactoryInfo>([
["customEndpoint", { factory: CustomEndpointPluginFactory, weight: 380 }],
["initialConnection", { factory: AuroraInitialConnectionStrategyFactory, weight: 390 }],
["auroraConnectionTracker", { factory: AuroraConnectionTrackerPluginFactory, weight: 400 }],
["staleDns", { factory: StaleDnsPluginFactory, weight: 500 }],
Expand Down
2 changes: 2 additions & 0 deletions common/lib/database_dialect/database_dialect.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { FailoverRestriction } from "../plugins/failover/failover_restriction";
import { ErrorHandler } from "../error_handler";
import { SessionState } from "../session_state";
import { TransactionIsolationLevel } from "../utils/transaction_isolation_level";
import { HostRole } from "../host_role";

export enum DatabaseType {
MYSQL,
Expand All @@ -39,6 +40,7 @@ export interface DatabaseDialect {
getSetSchemaQuery(schema: string): string;
getDialectUpdateCandidates(): string[];
getErrorHandler(): ErrorHandler;
getHostRole(targetClient: ClientWrapper): Promise<HostRole>;
isDialect(targetClient: ClientWrapper): Promise<boolean>;
getHostListProvider(props: Map<string, any>, originalUrl: string, hostListProviderService: HostListProviderService): HostListProvider;
isClientValid(targetClient: ClientWrapper): Promise<boolean>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import { Messages } from "../../utils/messages";
import { WrapperProperties } from "../../wrapper_property";
import { BlockingHostListProvider } from "../host_list_provider";
import { logger } from "../../../logutils";
import { isDialectTopologyAware } from "../../utils/utils";

export class MonitoringRdsHostListProvider extends RdsHostListProvider implements BlockingHostListProvider {
static readonly CACHE_CLEANUP_NANOS: bigint = BigInt(60_000_000_000); // 1 minute.
Expand Down Expand Up @@ -76,7 +77,7 @@ export class MonitoringRdsHostListProvider extends RdsHostListProvider implement

async sqlQueryForTopology(targetClient: ClientWrapper): Promise<HostInfo[]> {
const dialect: DatabaseDialect = this.hostListProviderService.getDialect();
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
if (!isDialectTopologyAware(dialect)) {
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
}
return await dialect.queryForTopology(targetClient, this).then((res: any) => this.processQueryResults(res));
Expand Down
15 changes: 5 additions & 10 deletions common/lib/host_list_provider/rds_host_list_provider.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,7 @@ import { WrapperProperties } from "../wrapper_property";
import { logger } from "../../logutils";
import { HostAvailability } from "../host_availability/host_availability";
import { CacheMap } from "../utils/cache_map";
import { logTopology } from "../utils/utils";
import { TopologyAwareDatabaseDialect } from "../topology_aware_database_dialect";
import { isDialectTopologyAware, logTopology } from "../utils/utils";
import { DatabaseDialect } from "../database_dialect/database_dialect";
import { ClientWrapper } from "../client_wrapper";

Expand Down Expand Up @@ -137,7 +136,7 @@ export class RdsHostListProvider implements DynamicHostListProvider {
}

async getHostRole(client: ClientWrapper, dialect: DatabaseDialect): Promise<HostRole> {
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
if (!isDialectTopologyAware(dialect)) {
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
}

Expand All @@ -150,7 +149,7 @@ export class RdsHostListProvider implements DynamicHostListProvider {

async getWriterId(client: ClientWrapper): Promise<string | null> {
const dialect = this.hostListProviderService.getDialect();
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
if (!isDialectTopologyAware(dialect)) {
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
}

Expand All @@ -162,7 +161,7 @@ export class RdsHostListProvider implements DynamicHostListProvider {
}

async identifyConnection(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise<HostInfo | null> {
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
if (!isDialectTopologyAware(dialect)) {
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
}
const instanceName = await dialect.identifyConnection(targetClient);
Expand Down Expand Up @@ -276,12 +275,8 @@ export class RdsHostListProvider implements DynamicHostListProvider {
}
}

protected isTopologyAwareDatabaseDialect(arg: any): arg is TopologyAwareDatabaseDialect {
return arg;
}

async queryForTopology(targetClient: ClientWrapper, dialect: DatabaseDialect): Promise<HostInfo[]> {
if (!this.isTopologyAwareDatabaseDialect(dialect)) {
if (!isDialectTopologyAware(dialect)) {
throw new TypeError(Messages.get("RdsHostListProvider.incorrectDialect"));
}

Expand Down
3 changes: 3 additions & 0 deletions common/lib/host_list_provider_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import { DatabaseDialect } from "./database_dialect/database_dialect";
import { HostInfoBuilder } from "./host_info_builder";
import { ConnectionUrlParser } from "./utils/connection_url_parser";
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts";

export interface HostListProviderService {
getHostListProvider(): HostListProvider | null;
Expand Down Expand Up @@ -50,4 +51,6 @@ export interface HostListProviderService {
isClientValid(targetClient: any): Promise<boolean>;

getTelemetryFactory(): TelemetryFactory;

setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts): void;
}
82 changes: 77 additions & 5 deletions common/lib/plugin_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,11 +40,13 @@ import { ClientWrapper } from "./client_wrapper";
import { logger } from "../logutils";
import { Messages } from "./utils/messages";
import { DatabaseDialectCodes } from "./database_dialect/database_dialect_codes";
import { getWriter } from "./utils/utils";
import { getWriter, logTopology } from "./utils/utils";
import { TelemetryFactory } from "./utils/telemetry/telemetry_factory";
import { DriverDialect } from "./driver_dialect/driver_dialect";
import { AllowedAndBlockedHosts } from "./AllowedAndBlockedHosts";

export class PluginService implements ErrorHandler, HostListProviderService {
private static readonly DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO = 5 * 60_000_000_000; // 5 minutes
private readonly _currentClient: AwsClient;
private _currentHostInfo?: HostInfo;
private _hostListProvider?: HostListProvider;
Expand All @@ -59,6 +61,7 @@ export class PluginService implements ErrorHandler, HostListProviderService {
protected readonly sessionStateService: SessionStateService;
protected static readonly hostAvailabilityExpiringCache: CacheMap<string, HostAvailability> = new CacheMap<string, HostAvailability>();
readonly props: Map<string, any>;
private allowedAndBlockedHosts: AllowedAndBlockedHosts | null = null;

constructor(
container: PluginServiceManagerContainer,
Expand Down Expand Up @@ -114,17 +117,29 @@ export class PluginService implements ErrorHandler, HostListProviderService {
this._currentHostInfo = this._initialConnectionHostInfo;

if (!this._currentHostInfo) {
if (this.getHosts().length === 0) {
if (this.getAllHosts().length === 0) {
throw new AwsWrapperError(Messages.get("PluginService.hostListEmpty"));
}

const writerHost = getWriter(this.getHosts());
const writerHost = getWriter(this.getAllHosts());
if (!this.getHosts().some((hostInfo: HostInfo) => hostInfo.host === writerHost?.host)) {
throw new AwsWrapperError(
Messages.get(
"PluginService.currentHostNotAllowed",
this._currentHostInfo ? this._currentHostInfo.host : "<null>",
logTopology(this.hosts, "[PluginService.currentHostNotAllowed] ")
)
);
}

if (writerHost) {
this._currentHostInfo = writerHost;
} else {
this._currentHostInfo = this.getHosts()[0];
}
}

logger.debug(`Set current host to: ${this._currentHostInfo.host}`);
}

return this._currentHostInfo;
Expand Down Expand Up @@ -286,11 +301,64 @@ export class PluginService implements ErrorHandler, HostListProviderService {
}
}

getHosts(): HostInfo[] {
getAllHosts(): HostInfo[] {
return this.hosts;
}

setAvailability(hostAliases: Set<string>, availability: HostAvailability) {}
getHosts(): HostInfo[] {
const hostPermissions = this.allowedAndBlockedHosts;
if (!hostPermissions) {
return this.hosts;
}

let hosts = this.hosts;
const allowedHostIds = hostPermissions.getAllowedHostIds();
const blockedHostIds = hostPermissions.getBlockedHostIds();

if (allowedHostIds && allowedHostIds.size > 0) {
hosts = hosts.filter((host: HostInfo) => allowedHostIds.has(host.hostId));
}

if (blockedHostIds && blockedHostIds.size > 0) {
hosts = hosts.filter((host: HostInfo) => !blockedHostIds.has(host.hostId));
}

return hosts;
}

setAvailability(hostAliases: Set<string>, availability: HostAvailability) {
if (hostAliases.size === 0) {
return;
}

const hostsToChange = [
...new Set(
this.getAllHosts().filter(
(host: HostInfo) => hostAliases.has(host.asAlias) || [...host.aliases].some((hostAlias: string) => hostAliases.has(hostAlias))
)
)
];

if (hostsToChange.length === 0) {
logger.debug(Messages.get("PluginService.hostsChangeListEmpty"));
return;
}

const changes = new Map<string, Set<HostChangeOptions>>();
for (const host of hostsToChange) {
const currentAvailability = host.getAvailability();
PluginService.hostAvailabilityExpiringCache.put(host.url, availability, PluginService.DEFAULT_HOST_AVAILABILITY_CACHE_EXPIRE_NANO);
if (currentAvailability !== availability) {
let hostChanges = new Set<HostChangeOptions>();
if (availability === HostAvailability.AVAILABLE) {
hostChanges = new Set([HostChangeOptions.WENT_UP, HostChangeOptions.HOST_CHANGED]);
} else {
hostChanges = new Set([HostChangeOptions.WENT_DOWN, HostChangeOptions.HOST_CHANGED]);
}
changes.set(host.url, hostChanges);
}
}
}

updateConfigWithProperties(props: Map<string, any>) {
this._currentClient.config = Object.fromEntries(props.entries());
Expand Down Expand Up @@ -527,4 +595,8 @@ export class PluginService implements ErrorHandler, HostListProviderService {
attachNoOpErrorListener(clientWrapper: ClientWrapper | undefined): void {
this.getDialect().getErrorHandler().attachNoOpErrorListener(clientWrapper);
}

setAllowedAndBlockedHosts(allowedAndBlockedHosts: AllowedAndBlockedHosts) {
this.allowedAndBlockedHosts = allowedAndBlockedHosts;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,7 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
}

private getWriter(): HostInfo | null {
return this.pluginService.getHosts().find((x) => x.role === HostRole.WRITER) ?? null;
return this.pluginService.getAllHosts().find((x) => x.role === HostRole.WRITER) ?? null;
}

private getReader(props: Map<string, any>): HostInfo | undefined {
Expand All @@ -278,6 +278,6 @@ export class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
}

private hasNoReaders(): boolean {
return this.pluginService.getHosts().find((x) => x.role === HostRole.READER) !== undefined;
return this.pluginService.getAllHosts().find((x) => x.role === HostRole.READER) !== undefined;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl
}

private async checkWriterChanged(): Promise<void> {
const hostInfoAfterFailover = this.getWriter(this.pluginService.getHosts());
const hostInfoAfterFailover = this.getWriter(this.pluginService.getAllHosts());
if (this.currentWriter === null) {
this.currentWriter = hostInfoAfterFailover;
this.needUpdateCurrentWriter = false;
Expand All @@ -114,7 +114,7 @@ export class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl

private rememberWriter(): void {
if (this.currentWriter === null || this.needUpdateCurrentWriter) {
this.currentWriter = this.getWriter(this.pluginService.getHosts());
this.currentWriter = this.getWriter(this.pluginService.getAllHosts());
this.needUpdateCurrentWriter = false;
}
}
Expand Down
Loading

0 comments on commit 11c4155

Please sign in to comment.