diff --git a/common/lib/aws_client.ts b/common/lib/aws_client.ts index 790437ae..1f05f7bd 100644 --- a/common/lib/aws_client.ts +++ b/common/lib/aws_client.ts @@ -63,13 +63,16 @@ export abstract class AwsClient extends EventEmitter { this.telemetryFactory = new DefaultTelemetryFactory(this.properties); const container = new PluginServiceManagerContainer(); - this.pluginService = new PluginService(container, this, dbType, knownDialectsByCode, this.properties, driverDialect); - this.pluginManager = new PluginManager( + this.pluginService = new PluginService( container, + this, + dbType, + knownDialectsByCode, this.properties, - new ConnectionProviderManager(new DriverConnectionProvider(), null), - this.telemetryFactory + driverDialect, + new ConnectionProviderManager(new DriverConnectionProvider(), null) ); + this.pluginManager = new PluginManager(container, this.properties, this.telemetryFactory); } private async setup() { diff --git a/common/lib/connection_plugin_chain_builder.ts b/common/lib/connection_plugin_chain_builder.ts index 84cea62b..e0593600 100644 --- a/common/lib/connection_plugin_chain_builder.ts +++ b/common/lib/connection_plugin_chain_builder.ts @@ -33,7 +33,6 @@ import { OktaAuthPluginFactory } from "./plugins/federated_auth/okta_auth_plugin import { HostMonitoringPluginFactory } from "./plugins/efm/host_monitoring_plugin_factory"; import { AuroraInitialConnectionStrategyFactory } from "./plugins/aurora_initial_connection_strategy_plugin_factory"; import { AuroraConnectionTrackerPluginFactory } from "./plugins/connection_tracker/aurora_connection_tracker_plugin_factory"; -import { ConnectionProviderManager } from "./connection_provider_manager"; import { DeveloperConnectionPluginFactory } from "./plugins/dev/developer_connection_plugin_factory"; /* @@ -67,11 +66,7 @@ export class ConnectionPluginChainBuilder { ["executeTime", { factory: ExecuteTimePluginFactory, weight: ConnectionPluginChainBuilder.WEIGHT_RELATIVE_TO_PRIOR_PLUGIN }] ]); - static async getPlugins( - pluginService: PluginService, - props: Map, - connectionProviderManager: ConnectionProviderManager - ): Promise { + static async getPlugins(pluginService: PluginService, props: Map): Promise { const plugins: ConnectionPlugin[] = []; let pluginCodes: string = props.get(WrapperProperties.PLUGINS.name); if (pluginCodes == null) { @@ -119,7 +114,7 @@ export class ConnectionPluginChainBuilder { } } - plugins.push(new DefaultPlugin(pluginService, connectionProviderManager)); + plugins.push(new DefaultPlugin(pluginService)); return plugins; } diff --git a/common/lib/connection_provider_manager.ts b/common/lib/connection_provider_manager.ts index 9b208de6..955b7cbf 100644 --- a/common/lib/connection_provider_manager.ts +++ b/common/lib/connection_provider_manager.ts @@ -68,7 +68,7 @@ export class ConnectionProviderManager { } } - if (this.effectiveProvider?.acceptsStrategy(role, strategy)) { + if (!host && this.effectiveProvider?.acceptsStrategy(role, strategy)) { try { host = this.effectiveProvider.getHostInfoByStrategy(hosts, role, strategy, props); } catch { diff --git a/common/lib/plugin_manager.ts b/common/lib/plugin_manager.ts index e7acfb8d..4756216d 100644 --- a/common/lib/plugin_manager.ts +++ b/common/lib/plugin_manager.ts @@ -26,7 +26,6 @@ import { OldConnectionSuggestionAction } from "./old_connection_suggestion_actio import { HostRole } from "./host_role"; import { ClientWrapper } from "./client_wrapper"; import { CanReleaseResources } from "./can_release_resources"; -import { ConnectionProviderManager } from "./connection_provider_manager"; import { TelemetryFactory } from "./utils/telemetry/telemetry_factory"; import { TelemetryTraceLevel } from "./utils/telemetry/telemetry_trace_level"; import { ConnectionProvider } from "./connection_provider"; @@ -71,19 +70,12 @@ export class PluginManager { private static readonly GET_HOST_INFO_BY_STRATEGY_METHOD: string = "getHostInfoByStrategy"; private readonly props: Map; private _plugins: ConnectionPlugin[] = []; - private readonly connectionProviderManager: ConnectionProviderManager; private pluginServiceManagerContainer: PluginServiceManagerContainer; protected telemetryFactory: TelemetryFactory; - constructor( - pluginServiceManagerContainer: PluginServiceManagerContainer, - props: Map, - connectionProviderManager: ConnectionProviderManager, - telemetryFactory: TelemetryFactory - ) { + constructor(pluginServiceManagerContainer: PluginServiceManagerContainer, props: Map, telemetryFactory: TelemetryFactory) { this.pluginServiceManagerContainer = pluginServiceManagerContainer; this.pluginServiceManagerContainer.pluginManager = this; - this.connectionProviderManager = connectionProviderManager; this.props = props; this.telemetryFactory = telemetryFactory; } @@ -95,11 +87,7 @@ export class PluginManager { if (plugins) { this._plugins = plugins; } else { - this._plugins = await ConnectionPluginChainBuilder.getPlugins( - this.pluginServiceManagerContainer.pluginService, - this.props, - this.connectionProviderManager - ); + this._plugins = await ConnectionPluginChainBuilder.getPlugins(this.pluginServiceManagerContainer.pluginService, this.props); } } } @@ -274,7 +262,7 @@ export class PluginManager { return false; } - getHostInfoByStrategy(role: HostRole, strategy: string): HostInfo { + getHostInfoByStrategy(role: HostRole, strategy: string): HostInfo | undefined { for (const plugin of this._plugins) { const pluginSubscribedMethods = plugin.getSubscribedMethods(); const isSubscribed = @@ -291,8 +279,6 @@ export class PluginManager { } } } - - throw new AwsWrapperError("The driver does not support the requested host selection strategy: " + strategy); } async releaseResources() { @@ -306,10 +292,6 @@ export class PluginManager { } } - getConnectionProvider(hostInfo: HostInfo | null, props: Map): ConnectionProvider { - return this.connectionProviderManager.getConnectionProvider(hostInfo, props); - } - private implementsCanReleaseResources(plugin: any): plugin is CanReleaseResources { return plugin.releaseResources !== undefined; } diff --git a/common/lib/plugin_service.ts b/common/lib/plugin_service.ts index 5572ac08..0ca0a317 100644 --- a/common/lib/plugin_service.ts +++ b/common/lib/plugin_service.ts @@ -44,6 +44,7 @@ import { getWriter } from "./utils/utils"; import { ConnectionProvider } from "./connection_provider"; import { TelemetryFactory } from "./utils/telemetry/telemetry_factory"; import { DriverDialect } from "./driver_dialect/driver_dialect"; +import { ConnectionProviderManager } from "./connection_provider_manager"; export class PluginService implements ErrorHandler, HostListProviderService { private readonly _currentClient: AwsClient; @@ -52,6 +53,7 @@ export class PluginService implements ErrorHandler, HostListProviderService { private _initialConnectionHostInfo?: HostInfo; private _isInTransaction: boolean = false; private pluginServiceManagerContainer: PluginServiceManagerContainer; + private readonly connectionProviderManager: ConnectionProviderManager; protected hosts: HostInfo[] = []; private dbDialectProvider: DatabaseDialectProvider; private readonly initialHost: string; @@ -67,7 +69,8 @@ export class PluginService implements ErrorHandler, HostListProviderService { dbType: DatabaseType, knownDialectsByCode: Map, props: Map, - driverDialect: DriverDialect + driverDialect: DriverDialect, + connectionProviderManager: ConnectionProviderManager ) { this._currentClient = client; this.pluginServiceManagerContainer = container; @@ -76,6 +79,7 @@ export class PluginService implements ErrorHandler, HostListProviderService { this.driverDialect = driverDialect; this.initialHost = props.get(WrapperProperties.HOST.name); this.sessionStateService = new SessionStateServiceImpl(this, this.props); + this.connectionProviderManager = connectionProviderManager; container.pluginService = this; this.dialect = this.dbDialectProvider.getDialect(this.props); @@ -114,8 +118,23 @@ export class PluginService implements ErrorHandler, HostListProviderService { } getHostInfoByStrategy(role: HostRole, strategy: string): HostInfo | undefined { + if (role === HostRole.UNKNOWN) { + logger.debug("unknown role requested"); // TODO provide message using Messages.get was: DefaultConnectionPlugin.unknownRoleRequested - + return; + } + const pluginManager = this.pluginServiceManagerContainer.pluginManager; - return pluginManager?.getHostInfoByStrategy(role, strategy); + const host = pluginManager?.getHostInfoByStrategy(role, strategy); + if (host) { + return host; + } + + if (this.getHosts().length === 0) { + logger.debug("no Hosts Available"); // TODO provide message using Messages.get was: DefaultConnectionPlugin.noHostsAvailable + return; + } + + return this.connectionProviderManager.getHostInfoByStrategy(this.getHosts(), role, strategy, this.props); } getCurrentHostInfo(): HostInfo | null { @@ -152,10 +171,7 @@ export class PluginService implements ErrorHandler, HostListProviderService { } getConnectionProvider(hostInfo: HostInfo | null, props: Map): ConnectionProvider { - if (!this.pluginServiceManagerContainer.pluginManager) { - throw new AwsWrapperError("Plugin manager should not be undefined"); - } - return this.pluginServiceManagerContainer.pluginManager.getConnectionProvider(hostInfo, props); + return this.connectionProviderManager.getConnectionProvider(hostInfo, props); } getDialect(): DatabaseDialect { @@ -175,7 +191,10 @@ export class PluginService implements ErrorHandler, HostListProviderService { } acceptsStrategy(role: HostRole, strategy: string): boolean { - return this.pluginServiceManagerContainer.pluginManager?.acceptsStrategy(role, strategy) ?? false; + return ( + (this.pluginServiceManagerContainer.pluginManager?.acceptsStrategy(role, strategy) ?? false) || + this.connectionProviderManager.acceptsStrategy(role, strategy) + ); } async forceRefreshHostList(): Promise; diff --git a/common/lib/plugins/default_plugin.ts b/common/lib/plugins/default_plugin.ts index d97c49b0..12d654f5 100644 --- a/common/lib/plugins/default_plugin.ts +++ b/common/lib/plugins/default_plugin.ts @@ -22,11 +22,8 @@ import { HostInfo } from "../host_info"; import { AbstractConnectionPlugin } from "../abstract_connection_plugin"; import { HostChangeOptions } from "../host_change_options"; import { OldConnectionSuggestionAction } from "../old_connection_suggestion_action"; -import { HostRole } from "../host_role"; import { PluginService } from "../plugin_service"; -import { ConnectionProviderManager } from "../connection_provider_manager"; import { ConnectionProvider } from "../connection_provider"; -import { AwsWrapperError } from "../utils/errors"; import { HostAvailability } from "../host_availability/host_availability"; import { ClientWrapper } from "../client_wrapper"; import { TelemetryTraceLevel } from "../utils/telemetry/telemetry_trace_level"; @@ -34,16 +31,14 @@ import { TelemetryTraceLevel } from "../utils/telemetry/telemetry_trace_level"; export class DefaultPlugin extends AbstractConnectionPlugin { id: string = uniqueId("_defaultPlugin"); private readonly pluginService: PluginService; - private readonly connectionProviderManager: ConnectionProviderManager; - constructor(pluginService: PluginService, connectionProviderManager: ConnectionProviderManager) { + constructor(pluginService: PluginService) { super(); this.pluginService = pluginService; - this.connectionProviderManager = connectionProviderManager; } override getSubscribedMethods(): Set { - return new Set(["*"]); + return new Set(["*"]); // TODO verify Subscribed Methods } override async forceConnect( @@ -52,7 +47,7 @@ export class DefaultPlugin extends AbstractConnectionPlugin { isInitialConnection: boolean, forceConnectFunc: () => Promise ): Promise { - return await this.connectInternal(hostInfo, props, this.connectionProviderManager.getConnectionProvider(hostInfo, props)); + return await this.connectInternal(hostInfo, props, this.pluginService.getConnectionProvider(hostInfo, props)); } override initHostProvider( @@ -70,7 +65,7 @@ export class DefaultPlugin extends AbstractConnectionPlugin { isInitialConnection: boolean, connectFunc: () => Promise ): Promise { - return await this.connectInternal(hostInfo, props, this.connectionProviderManager.getConnectionProvider(hostInfo, props)); + return await this.connectInternal(hostInfo, props, this.pluginService.getConnectionProvider(hostInfo, props)); } private async connectInternal(hostInfo: HostInfo, props: Map, connProvider: ConnectionProvider): Promise { @@ -105,25 +100,4 @@ export class DefaultPlugin extends AbstractConnectionPlugin { override notifyHostListChanged(changes: Map>): Promise { return Promise.resolve(); } - - override acceptsStrategy(role: HostRole, strategy: string): boolean { - if (role === HostRole.UNKNOWN) { - // Users must request either a writer or a reader role. - return false; - } - return this.connectionProviderManager.acceptsStrategy(role, strategy); - } - - override getHostInfoByStrategy(role: HostRole, strategy: string): HostInfo { - if (role === HostRole.UNKNOWN) { - throw new AwsWrapperError(Messages.get("DefaultConnectionPlugin.unknownRoleRequested")); - } - - const hosts = this.pluginService.getHosts(); - if (hosts.length < 1) { - throw new AwsWrapperError(Messages.get("DefaultConnectionPlugin.noHostsAvailable")); - } - - return this.connectionProviderManager.getHostInfoByStrategy(hosts, role, strategy, this.pluginService.props); - } } diff --git a/tests/plugin_benchmarks.ts b/tests/plugin_benchmarks.ts index 099cd502..6f404a90 100644 --- a/tests/plugin_benchmarks.ts +++ b/tests/plugin_benchmarks.ts @@ -15,7 +15,6 @@ */ import { anything, instance, mock, when } from "ts-mockito"; -import { ConnectionProvider } from "../common/lib/connection_provider"; import { PluginService } from "../common/lib/plugin_service"; import { PluginServiceManagerContainer } from "../common/lib/plugin_service_manager_container"; import { WrapperProperties } from "../common/lib/wrapper_property"; @@ -26,10 +25,8 @@ import { HostInfoBuilder } from "../common/lib/host_info_builder"; import { SimpleHostAvailabilityStrategy } from "../common/lib/host_availability/simple_host_availability_strategy"; import { AwsPGClient } from "../pg/lib"; import { NullTelemetryFactory } from "../common/lib/utils/telemetry/null_telemetry_factory"; -import { ConnectionProviderManager } from "../common/lib/connection_provider_manager"; import { PgClientWrapper } from "../common/lib/pg_client_wrapper"; -const mockConnectionProvider = mock(); const mockPluginService = mock(PluginService); const mockClient = mock(AwsPGClient); @@ -59,24 +56,9 @@ WrapperProperties.HOST.set(propsExecute, connectionString); WrapperProperties.HOST.set(propsReadWrite, connectionString); WrapperProperties.HOST.set(props, connectionString); -const pluginManagerExecute = new PluginManager( - pluginServiceManagerContainer, - propsExecute, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - telemetryFactory -); -const pluginManagerReadWrite = new PluginManager( - pluginServiceManagerContainer, - propsReadWrite, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - telemetryFactory -); -const pluginManager = new PluginManager( - pluginServiceManagerContainer, - props, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - new NullTelemetryFactory() -); +const pluginManagerExecute = new PluginManager(pluginServiceManagerContainer, propsExecute, telemetryFactory); +const pluginManagerReadWrite = new PluginManager(pluginServiceManagerContainer, propsReadWrite, telemetryFactory); +const pluginManager = new PluginManager(pluginServiceManagerContainer, props, new NullTelemetryFactory()); suite( "Plugin benchmarks", diff --git a/tests/plugin_manager_benchmarks.ts b/tests/plugin_manager_benchmarks.ts index 1372738e..bc5a8807 100644 --- a/tests/plugin_manager_benchmarks.ts +++ b/tests/plugin_manager_benchmarks.ts @@ -28,7 +28,6 @@ import { WrapperProperties } from "../common/lib/wrapper_property"; import { DefaultPlugin } from "../common/lib/plugins/default_plugin"; import { BenchmarkPluginFactory } from "./testplugin/benchmark_plugin_factory"; import { NullTelemetryFactory } from "../common/lib/utils/telemetry/null_telemetry_factory"; -import { ConnectionProviderManager } from "../common/lib/connection_provider_manager"; import { PgDatabaseDialect } from "../pg/lib/dialect/pg_database_dialect"; import { NodePostgresDriverDialect } from "../pg/lib/dialect/node_postgres_driver_dialect"; @@ -48,25 +47,15 @@ const propsWithPlugins = new Map(); WrapperProperties.PLUGINS.set(propsWithNoPlugins, ""); -const pluginManagerWithNoPlugins = new PluginManager( - pluginServiceManagerContainer, - propsWithNoPlugins, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - telemetryFactory -); -const pluginManagerWithPlugins = new PluginManager( - pluginServiceManagerContainer, - propsWithPlugins, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - telemetryFactory -); +const pluginManagerWithNoPlugins = new PluginManager(pluginServiceManagerContainer, propsWithNoPlugins, telemetryFactory); +const pluginManagerWithPlugins = new PluginManager(pluginServiceManagerContainer, propsWithPlugins, telemetryFactory); async function createPlugins(pluginService: PluginService, connectionProvider: ConnectionProvider, props: Map) { const plugins = new Array(); for (let i = 0; i < 10; i++) { plugins.push(await new BenchmarkPluginFactory().getInstance(pluginService, props)); } - plugins.push(new DefaultPlugin(pluginService, new ConnectionProviderManager(instance(mockConnectionProvider), null))); + plugins.push(new DefaultPlugin(pluginService)); return plugins; } @@ -80,22 +69,12 @@ suite( }), add("initPluginManagerWithPlugins", async () => { - const manager = new PluginManager( - pluginServiceManagerContainer, - propsWithPlugins, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - new NullTelemetryFactory() - ); + const manager = new PluginManager(pluginServiceManagerContainer, propsWithPlugins, new NullTelemetryFactory()); await manager.init(await createPlugins(instance(mockPluginService), instance(mockConnectionProvider), propsWithPlugins)); }), add("initPluginManagerWithNoPlugins", async () => { - const manager = new PluginManager( - pluginServiceManagerContainer, - propsWithNoPlugins, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - new NullTelemetryFactory() - ); + const manager = new PluginManager(pluginServiceManagerContainer, propsWithNoPlugins, new NullTelemetryFactory()); await manager.init(); }), diff --git a/tests/plugin_manager_telemetry_benchmarks.ts b/tests/plugin_manager_telemetry_benchmarks.ts index 9bd7a8b5..fb29af38 100644 --- a/tests/plugin_manager_telemetry_benchmarks.ts +++ b/tests/plugin_manager_telemetry_benchmarks.ts @@ -1,12 +1,12 @@ /* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. - + Licensed under the Apache License, Version 2.0 (the "License"). You may not use this file except in compliance with the License. You may obtain a copy of the License at - + http://www.apache.org/licenses/LICENSE-2.0 - + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. @@ -41,7 +41,6 @@ import { AWSXRayPropagator } from "@opentelemetry/propagator-aws-xray"; import { HttpInstrumentation } from "@opentelemetry/instrumentation-http"; import { AwsInstrumentation } from "@opentelemetry/instrumentation-aws-sdk"; import { AWSXRayIdGenerator } from "@opentelemetry/id-generator-aws-xray"; -import { ConnectionProviderManager } from "../common/lib/connection_provider_manager"; import { PgDatabaseDialect } from "../pg/lib/dialect/pg_database_dialect"; import { NodePostgresDriverDialect } from "../pg/lib/dialect/node_postgres_driver_dialect"; @@ -67,25 +66,15 @@ WrapperProperties.ENABLE_TELEMETRY.set(propsWithPlugins, true); WrapperProperties.TELEMETRY_METRICS_BACKEND.set(propsWithPlugins, "OTLP"); WrapperProperties.TELEMETRY_TRACES_BACKEND.set(propsWithPlugins, "OTLP"); -const pluginManagerWithNoPlugins = new PluginManager( - pluginServiceManagerContainer, - propsWithNoPlugins, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - telemetryFactory -); -const pluginManagerWithPlugins = new PluginManager( - pluginServiceManagerContainer, - propsWithPlugins, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - telemetryFactory -); +const pluginManagerWithNoPlugins = new PluginManager(pluginServiceManagerContainer, propsWithNoPlugins, telemetryFactory); +const pluginManagerWithPlugins = new PluginManager(pluginServiceManagerContainer, propsWithPlugins, telemetryFactory); async function createPlugins(pluginService: PluginService, connectionProvider: ConnectionProvider, props: Map) { const plugins = new Array(); for (let i = 0; i < 10; i++) { plugins.push(await new BenchmarkPluginFactory().getInstance(pluginService, props)); } - plugins.push(new DefaultPlugin(pluginService, new ConnectionProviderManager(instance(mockConnectionProvider), null))); + plugins.push(new DefaultPlugin(pluginService)); return plugins; } @@ -141,22 +130,12 @@ suite( }), add("initPluginManagerWithPlugins", async () => { - const manager = new PluginManager( - pluginServiceManagerContainer, - propsWithPlugins, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - new NullTelemetryFactory() - ); + const manager = new PluginManager(pluginServiceManagerContainer, propsWithPlugins, new NullTelemetryFactory()); await manager.init(await createPlugins(instance(mockPluginService), instance(mockConnectionProvider), propsWithPlugins)); }), add("initPluginManagerWithNoPlugins", async () => { - const manager = new PluginManager( - pluginServiceManagerContainer, - propsWithNoPlugins, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - new NullTelemetryFactory() - ); + const manager = new PluginManager(pluginServiceManagerContainer, propsWithNoPlugins, new NullTelemetryFactory()); await manager.init(); }), diff --git a/tests/plugin_telemetry_benchmarks.ts b/tests/plugin_telemetry_benchmarks.ts index a395ddb2..586b8de7 100644 --- a/tests/plugin_telemetry_benchmarks.ts +++ b/tests/plugin_telemetry_benchmarks.ts @@ -15,7 +15,6 @@ */ import { anything, instance, mock, when } from "ts-mockito"; -import { ConnectionProvider } from "../common/lib/connection_provider"; import { PluginService } from "../common/lib/plugin_service"; import { PluginServiceManagerContainer } from "../common/lib/plugin_service_manager_container"; import { WrapperProperties } from "../common/lib/wrapper_property"; @@ -39,10 +38,8 @@ import { AWSXRayPropagator } from "@opentelemetry/propagator-aws-xray"; import { HttpInstrumentation } from "@opentelemetry/instrumentation-http"; import { AwsInstrumentation } from "@opentelemetry/instrumentation-aws-sdk"; import { AWSXRayIdGenerator } from "@opentelemetry/id-generator-aws-xray"; -import { ConnectionProviderManager } from "../common/lib/connection_provider_manager"; import { PgClientWrapper } from "../common/lib/pg_client_wrapper"; -const mockConnectionProvider = mock(); const mockPluginService = mock(PluginService); const mockClient = mock(AwsPGClient); @@ -81,24 +78,9 @@ WrapperProperties.TELEMETRY_TRACES_BACKEND.set(propsExecute, "OTLP"); WrapperProperties.TELEMETRY_TRACES_BACKEND.set(propsReadWrite, "OTLP"); WrapperProperties.TELEMETRY_TRACES_BACKEND.set(props, "OTLP"); -const pluginManagerExecute = new PluginManager( - pluginServiceManagerContainer, - propsExecute, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - telemetryFactory -); -const pluginManagerReadWrite = new PluginManager( - pluginServiceManagerContainer, - propsReadWrite, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - telemetryFactory -); -const pluginManager = new PluginManager( - pluginServiceManagerContainer, - props, - new ConnectionProviderManager(instance(mockConnectionProvider), null), - new NullTelemetryFactory() -); +const pluginManagerExecute = new PluginManager(pluginServiceManagerContainer, propsExecute, telemetryFactory); +const pluginManagerReadWrite = new PluginManager(pluginServiceManagerContainer, propsReadWrite, telemetryFactory); +const pluginManager = new PluginManager(pluginServiceManagerContainer, props, new NullTelemetryFactory()); const traceExporter = new OTLPTraceExporter({ url: "http://localhost:4317" }); const resource = Resource.default().merge( diff --git a/tests/unit/connection_plugin_chain_builder.test.ts b/tests/unit/connection_plugin_chain_builder.test.ts index 33b944fe..73820951 100644 --- a/tests/unit/connection_plugin_chain_builder.test.ts +++ b/tests/unit/connection_plugin_chain_builder.test.ts @@ -18,15 +18,12 @@ import { WrapperProperties } from "../../common/lib/wrapper_property"; import { instance, mock, when } from "ts-mockito"; import { ConnectionPluginChainBuilder } from "../../common/lib/connection_plugin_chain_builder"; import { PluginService } from "../../common/lib/plugin_service"; -import { ConnectionProvider } from "../../common/lib/connection_provider"; -import { DriverConnectionProvider } from "../../common/lib/driver_connection_provider"; import { FailoverPlugin } from "../../common/lib/plugins/failover/failover_plugin"; import { IamAuthenticationPlugin } from "../../common/lib/authentication/iam_authentication_plugin"; import { DefaultPlugin } from "../../common/lib/plugins/default_plugin"; import { ExecuteTimePlugin } from "../../common/lib/plugins/execute_time_plugin"; import { ConnectTimePlugin } from "../../common/lib/plugins/connect_time_plugin"; import { StaleDnsPlugin } from "../../common/lib/plugins/stale_dns/stale_dns_plugin"; -import { ConnectionProviderManager } from "../../common/lib/connection_provider_manager"; import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; import { AbstractConnectionPlugin } from "../../common/lib/abstract_connection_plugin"; import { ConnectionPluginFactory } from "../../common/lib/plugin_factory"; @@ -34,8 +31,6 @@ import { PluginManager } from "../../common/lib/plugin_manager"; const mockPluginService: PluginService = mock(PluginService); const mockPluginServiceInstance: PluginService = instance(mockPluginService); -const mockDefaultConnProvider: ConnectionProvider = mock(DriverConnectionProvider); -const mockEffectiveConnProvider: ConnectionProvider = mock(DriverConnectionProvider); describe("testConnectionPluginChainBuilder", () => { beforeAll(() => { @@ -46,11 +41,7 @@ describe("testConnectionPluginChainBuilder", () => { const props = new Map(); props.set(WrapperProperties.PLUGINS.name, plugins); - const result = await ConnectionPluginChainBuilder.getPlugins( - mockPluginServiceInstance, - props, - new ConnectionProviderManager(mockDefaultConnProvider, mockEffectiveConnProvider) - ); + const result = await ConnectionPluginChainBuilder.getPlugins(mockPluginServiceInstance, props); expect(result.length).toBe(4); expect(result[0]).toBeInstanceOf(StaleDnsPlugin); @@ -64,11 +55,7 @@ describe("testConnectionPluginChainBuilder", () => { props.set(WrapperProperties.PLUGINS.name, "iam,staleDns,failover"); props.set(WrapperProperties.AUTO_SORT_PLUGIN_ORDER.name, false); - const result = await ConnectionPluginChainBuilder.getPlugins( - mockPluginServiceInstance, - props, - new ConnectionProviderManager(mockDefaultConnProvider, mockEffectiveConnProvider) - ); + const result = await ConnectionPluginChainBuilder.getPlugins(mockPluginServiceInstance, props); expect(result.length).toBe(4); expect(result[0]).toBeInstanceOf(IamAuthenticationPlugin); @@ -82,11 +69,7 @@ describe("testConnectionPluginChainBuilder", () => { props.set(WrapperProperties.PLUGINS.name, "executeTime,connectTime,iam"); - let result = await ConnectionPluginChainBuilder.getPlugins( - mockPluginServiceInstance, - props, - new ConnectionProviderManager(mockDefaultConnProvider, mockEffectiveConnProvider) - ); + let result = await ConnectionPluginChainBuilder.getPlugins(mockPluginServiceInstance, props); expect(result.length).toBe(4); expect(result[0]).toBeInstanceOf(ExecuteTimePlugin); @@ -96,11 +79,7 @@ describe("testConnectionPluginChainBuilder", () => { // Test again to make sure the previous sort does not impact future plugin chains props.set(WrapperProperties.PLUGINS.name, "iam,executeTime,connectTime,failover"); - result = await ConnectionPluginChainBuilder.getPlugins( - mockPluginServiceInstance, - props, - new ConnectionProviderManager(mockDefaultConnProvider, mockEffectiveConnProvider) - ); + result = await ConnectionPluginChainBuilder.getPlugins(mockPluginServiceInstance, props); expect(result.length).toBe(5); expect(result[0]).toBeInstanceOf(FailoverPlugin); @@ -116,11 +95,7 @@ describe("testConnectionPluginChainBuilder", () => { const props = new Map(); props.set(WrapperProperties.PLUGINS.name, "test"); - const result = await ConnectionPluginChainBuilder.getPlugins( - mockPluginServiceInstance, - props, - new ConnectionProviderManager(mockDefaultConnProvider, mockEffectiveConnProvider) - ); + const result = await ConnectionPluginChainBuilder.getPlugins(mockPluginServiceInstance, props); expect(result.length).toBe(2); expect(result[0]).toBeInstanceOf(TestPlugin); diff --git a/tests/unit/database_dialect.test.ts b/tests/unit/database_dialect.test.ts index b76dfcef..07872436 100644 --- a/tests/unit/database_dialect.test.ts +++ b/tests/unit/database_dialect.test.ts @@ -33,8 +33,10 @@ import { RdsMultiAZMySQLDatabaseDialect } from "../../mysql/lib/dialect/rds_mult import { RdsMultiAZPgDatabaseDialect } from "../../pg/lib/dialect/rds_multi_az_pg_database_dialect"; import { DatabaseDialectManager } from "../../common/lib/database_dialect/database_dialect_manager"; import { NodePostgresDriverDialect } from "../../pg/lib/dialect/node_postgres_driver_dialect"; -import { mock } from "ts-mockito"; +import { mock, instance } from "ts-mockito"; import { PgClientWrapper } from "../../common/lib/pg_client_wrapper"; +import { ConnectionProvider } from "../../common/lib/connection_provider"; +import { ConnectionProviderManager } from "../../common/lib/connection_provider_manager"; const LOCALHOST = "localhost"; const RDS_DATABASE = "database-1.xyz.us-east-2.rds.amazonaws.com"; @@ -180,6 +182,7 @@ const expectedDialectMapping: Map = ne const pluginServiceManagerContainer = new PluginServiceManagerContainer(); const mockClient = new AwsPGClient({}); const mockDriverDialect = mock(NodePostgresDriverDialect); +const mockConnectionProvider = mock(); class MockTargetClient { readonly expectedInputs: string[]; @@ -276,7 +279,8 @@ describe("test database dialects", () => { databaseType, expectedDialect!.dialects, props, - mockDriverDialect + mockDriverDialect, + new ConnectionProviderManager(instance(mockConnectionProvider), null) ); await pluginService.updateDialect(mockClientWrapper); expect(pluginService.getDialect()).toBe(expectedDialectClass); diff --git a/tests/unit/notification_pipeline.test.ts b/tests/unit/notification_pipeline.test.ts index 263c65ae..84d0d1e9 100644 --- a/tests/unit/notification_pipeline.test.ts +++ b/tests/unit/notification_pipeline.test.ts @@ -21,8 +21,6 @@ import { PluginServiceManagerContainer } from "../../common/lib/plugin_service_m import { DefaultPlugin } from "../../common/lib/plugins/default_plugin"; import { instance, mock } from "ts-mockito"; import { PluginService } from "../../common/lib/plugin_service"; -import { DriverConnectionProvider } from "../../common/lib/driver_connection_provider"; -import { ConnectionProviderManager } from "../../common/lib/connection_provider_manager"; import { NullTelemetryFactory } from "../../common/lib/utils/telemetry/null_telemetry_factory"; class TestPlugin extends DefaultPlugin { @@ -54,13 +52,8 @@ describe("notificationPipelineTest", () => { let plugin: TestPlugin; beforeEach(() => { - pluginManager = new PluginManager( - container, - props, - new ConnectionProviderManager(new DriverConnectionProvider(), null), - new NullTelemetryFactory() - ); - plugin = new TestPlugin(instance(mockPluginService), new ConnectionProviderManager(new DriverConnectionProvider(), null)); + pluginManager = new PluginManager(container, props, new NullTelemetryFactory()); + plugin = new TestPlugin(instance(mockPluginService)); pluginManager["_plugins"] = [plugin]; });