fix issue with custom domains during failover
sergiyvamz committed Jan 30, 2025
1 parent 2daf268 commit be6deb0
Showing 6 changed files with 83 additions and 37 deletions.
DialectManager.java
@@ -29,6 +29,7 @@
import software.amazon.jdbc.Driver;
import software.amazon.jdbc.HostSpec;
import software.amazon.jdbc.PluginService;
import software.amazon.jdbc.PropertyDefinition;
import software.amazon.jdbc.util.CacheMap;
import software.amazon.jdbc.util.ConnectionUrlParser;
import software.amazon.jdbc.util.Messages;
@@ -81,7 +82,11 @@ public class DialectManager implements DialectProvider {
private Dialect dialect = null;
private String dialectCode;

private PluginService pluginService;
private final PluginService pluginService;

static {
PropertyDefinition.registerPluginProperties(DialectManager.class);
}

public DialectManager(PluginService pluginService) {
this.pluginService = pluginService;
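The static block added above follows the property-registration pattern used across the wrapper's plugins. Here is a minimal sketch of that pattern, under the assumption that registerPluginProperties collects the public static AwsWrapperProperty fields of the given class and makes them known to the driver; the class and property names below are illustrative, not from this commit.

import software.amazon.jdbc.AwsWrapperProperty;
import software.amazon.jdbc.PropertyDefinition;

public class ExamplePlugin {
  // A property declared the same way the diffs in this commit declare theirs.
  public static final AwsWrapperProperty EXAMPLE_PROPERTY =
      new AwsWrapperProperty(
          "examplePluginProperty",
          "defaultValue",
          "Illustrative description of an example property.");

  static {
    // Register this class's AwsWrapperProperty fields so the driver accepts them
    // as supported connection properties.
    PropertyDefinition.registerPluginProperties(ExamplePlugin.class);
  }
}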
AuroraConnectionTrackerPlugin.java
@@ -49,6 +49,8 @@ public class AuroraConnectionTrackerPlugin extends AbstractConnectionPlugin impl
Collections.unmodifiableSet(new HashSet<String>() {
{
addAll(SubscribedMethodHelper.NETWORK_BOUND_METHODS);
add(METHOD_CLOSE);
add(METHOD_ABORT);
add("connect");
add("notifyNodeListChanged");
}
@@ -87,7 +89,7 @@ public Connection connect(final String driverProtocol, final HostSpec hostSpec,

if (conn != null) {
final RdsUrlType type = this.rdsHelper.identifyRdsType(hostSpec.getHost());
if (type.isRdsCluster()) {
if (type.isRdsCluster() || type == RdsUrlType.OTHER) {
hostSpec.resetAliases();
this.pluginService.fillAliases(conn, hostSpec);
}
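The new `type == RdsUrlType.OTHER` branch is what makes alias collection run for host names that match no RDS endpoint pattern, such as a custom domain (CNAME) pointing at the cluster. A hedged illustration of the classification this relies on; the host names and the exact patterns RdsUtils recognizes are assumptions, not taken from this commit.

RdsUtils rdsUtils = new RdsUtils();

// An Aurora cluster endpoint is recognized as an RDS cluster URL, so its aliases
// were already collected before this change.
rdsUtils.identifyRdsType("my-cluster.cluster-abc123.us-east-1.rds.amazonaws.com"); // isRdsCluster() == true

// A custom domain that merely resolves to the cluster is not recognized, so it is
// classified as RdsUrlType.OTHER; with this change its aliases are now filled too.
rdsUtils.identifyRdsType("db.example.internal"); // RdsUrlType.OTHER (assumed)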
wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java
@@ -19,11 +19,14 @@
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.concurrent.TimeUnit;
import java.util.logging.Logger;
import java.util.regex.Pattern;
import software.amazon.jdbc.AwsWrapperProperty;
import software.amazon.jdbc.HostListProviderService;
import software.amazon.jdbc.HostRole;
@@ -32,9 +35,11 @@
import software.amazon.jdbc.PluginService;
import software.amazon.jdbc.PropertyDefinition;
import software.amazon.jdbc.hostavailability.HostAvailability;
import software.amazon.jdbc.plugin.failover.FailoverMode;
import software.amazon.jdbc.util.Messages;
import software.amazon.jdbc.util.RdsUrlType;
import software.amazon.jdbc.util.RdsUtils;
import software.amazon.jdbc.util.StringUtils;
import software.amazon.jdbc.util.WrapperUtils;

public class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlugin {
@@ -67,16 +72,47 @@ public class AuroraInitialConnectionStrategyPlugin extends AbstractConnectionPlu
"1000",
"Time between each retry of opening a connection.");

public static final AwsWrapperProperty VERIFY_OPENED_CONNECTION_TYPE =
new AwsWrapperProperty(
"verifyOpenedConnectionType",
null,
"Force to verify an opened connection to be either a writer or a reader.");

private enum VerifyOpenedConnectionType {
WRITER,
READER;

private static final Map<String, VerifyOpenedConnectionType> nameToValue =
new HashMap<String, VerifyOpenedConnectionType>() {
{
put("writer", WRITER);
put("reader", READER);
}
};

public static VerifyOpenedConnectionType fromValue(String value) {
if (value == null) {
return null;
}
return nameToValue.get(value.toLowerCase());
}
}

private final PluginService pluginService;
private HostListProviderService hostListProviderService;
private final RdsUtils rdsUtils = new RdsUtils();

private VerifyOpenedConnectionType verifyOpenedConnectionType = null;

(Qodana Community for JVM warning on line 105 in wrapper/src/main/java/software/amazon/jdbc/plugin/AuroraInitialConnectionStrategyPlugin.java: unused assignment; the `null` initializer for `verifyOpenedConnectionType` is redundant.)
static {
PropertyDefinition.registerPluginProperties(AuroraInitialConnectionStrategyPlugin.class);
}

public AuroraInitialConnectionStrategyPlugin(final PluginService pluginService, final Properties properties) {
this.pluginService = pluginService;
this.verifyOpenedConnectionType =
VerifyOpenedConnectionType.fromValue(VERIFY_OPENED_CONNECTION_TYPE.getString(properties));
}

@Override
@@ -110,12 +146,8 @@ public Connection connect(

final RdsUrlType type = this.rdsUtils.identifyRdsType(hostSpec.getHost());

if (!type.isRdsCluster()) {
// It's not a cluster endpoint. Continue with a normal workflow.
return connectFunc.call();
}

if (type == RdsUrlType.RDS_WRITER_CLUSTER) {
if (type == RdsUrlType.RDS_WRITER_CLUSTER
|| this.verifyOpenedConnectionType == VerifyOpenedConnectionType.WRITER) {
Connection writerCandidateConn = this.getVerifiedWriterConnection(props, isInitialConnection, connectFunc);
if (writerCandidateConn == null) {
// Can't get writer connection. Continue with a normal workflow.
@@ -124,7 +156,8 @@
return writerCandidateConn;
}

if (type == RdsUrlType.RDS_READER_CLUSTER) {
if (type == RdsUrlType.RDS_READER_CLUSTER
|| this.verifyOpenedConnectionType == VerifyOpenedConnectionType.READER) {
Connection readerCandidateConn = this.getVerifiedReaderConnection(props, isInitialConnection, connectFunc);
if (readerCandidateConn == null) {
// Can't get a reader connection. Continue with a normal workflow.
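Together with the changes above, the new verifyOpenedConnectionType property lets an application that connects through a custom domain request the same writer/reader verification that previously ran only for recognized cluster endpoints. A hedged usage sketch; the JDBC URL, plugin codes, credentials, and host name are illustrative assumptions, not part of this commit.

import java.sql.Connection;
import java.sql.DriverManager;
import java.util.Properties;

public class CustomDomainExample {
  public static void main(String[] args) throws Exception {
    Properties props = new Properties();
    // Plugin codes are assumed here; use whichever plugins your deployment enables.
    props.setProperty("wrapperPlugins", "initialConnection,auroraConnectionTracker,failover");
    // Accepted values are "writer" and "reader", per the enum added in this commit.
    props.setProperty("verifyOpenedConnectionType", "writer");
    props.setProperty("user", "dbuser");         // illustrative credentials
    props.setProperty("password", "dbpassword");

    // db.example.internal stands in for a custom CNAME pointing at the Aurora cluster
    // writer; the wrapper cannot infer the role from the host name, so it verifies the
    // connection after opening it.
    try (Connection conn = DriverManager.getConnection(
        "jdbc:aws-wrapper:postgresql://db.example.internal:5432/postgres", props)) {
      // use the verified writer connection
    }
  }
}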
OpenedConnectionTracker.java
@@ -20,13 +20,14 @@
import java.sql.Connection;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.HashSet;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.logging.Logger;
@@ -35,6 +36,7 @@
import software.amazon.jdbc.util.Messages;
import software.amazon.jdbc.util.RdsUtils;
import software.amazon.jdbc.util.StringUtils;
import software.amazon.jdbc.util.SynchronousExecutor;
import software.amazon.jdbc.util.telemetry.TelemetryContext;
import software.amazon.jdbc.util.telemetry.TelemetryFactory;
import software.amazon.jdbc.util.telemetry.TelemetryTraceLevel;
@@ -50,16 +52,17 @@ public class OpenedConnectionTracker {
invalidateThread.setDaemon(true);
return invalidateThread;
});
private static final ExecutorService abortConnectionExecutorService =
Executors.newCachedThreadPool(
r -> {
final Thread abortThread = new Thread(r);
abortThread.setDaemon(true);
return abortThread;
});
private static final Executor abortConnectionExecutor = new SynchronousExecutor();

private static final Logger LOGGER = Logger.getLogger(OpenedConnectionTracker.class.getName());
private static final RdsUtils rdsUtils = new RdsUtils();

private static final Set<String> safeCheckIfClosed = new HashSet<>(Arrays.asList(
"HikariProxyConnection",
"org.postgresql.jdbc.PgConnection",
"com.mysql.cj.jdbc.ConnectionImpl",
"org.mariadb.jdbc.Connection"));

private final PluginService pluginService;

public OpenedConnectionTracker(final PluginService pluginService) {
@@ -72,6 +75,7 @@ public void populateOpenedConnectionQueue(final HostSpec hostSpec, final Connect
// Check if the connection was established using an instance endpoint
if (rdsUtils.isRdsInstance(hostSpec.getHost())) {
trackConnection(hostSpec.getHostAndPort(), conn);
logOpenedConnections();
return;
}

@@ -80,14 +84,17 @@ public void populateOpenedConnectionQueue(final HostSpec hostSpec, final Connect
.max(String::compareToIgnoreCase)
.orElse(null);

if (instanceEndpoint == null) {
LOGGER.finest(
Messages.get("OpenedConnectionTracker.unableToPopulateOpenedConnectionQueue",
new Object[] {hostSpec.getHost()}));
if (instanceEndpoint != null) {
trackConnection(instanceEndpoint, conn);
logOpenedConnections();
return;
}

trackConnection(instanceEndpoint, conn);
// No RDS instance host was found. The host might be a custom domain name, so track the connection under all of its aliases.
for (String alias : aliases) {
trackConnection(alias, conn);
}
logOpenedConnections();
}

/**
@@ -100,22 +107,21 @@ public void invalidateAllConnections(final HostSpec hostSpec) {
invalidateAllConnections(hostSpec.getAliases().toArray(new String[] {}));
}

public void invalidateAllConnections(final String... node) {
public void invalidateAllConnections(final String... keys) {
TelemetryFactory telemetryFactory = this.pluginService.getTelemetryFactory();
TelemetryContext telemetryContext = telemetryFactory.openTelemetryContext(
TELEMETRY_INVALIDATE_CONNECTIONS, TelemetryTraceLevel.NESTED);

try {
final Optional<String> instanceEndpoint = Arrays.stream(node)
.filter(x -> rdsUtils.isRdsInstance(rdsUtils.removePort(x)))
.findFirst();
if (!instanceEndpoint.isPresent()) {
return;
for (String key : keys) {
try {
final Queue<WeakReference<Connection>> connectionQueue = openedConnections.get(key);
logConnectionQueue(key, connectionQueue);
invalidateConnections(connectionQueue);
} catch (Exception ex) {
// ignore and continue
}
}
final Queue<WeakReference<Connection>> connectionQueue = openedConnections.get(instanceEndpoint.get());
logConnectionQueue(instanceEndpoint.get(), connectionQueue);
invalidateConnections(openedConnections.get(instanceEndpoint.get()));

} finally {
telemetryContext.closeContext();
}
@@ -144,7 +150,6 @@ private void trackConnection(final String instanceEndpoint, final Connection con
instanceEndpoint,
(k) -> new ConcurrentLinkedQueue<>());
connectionQueue.add(new WeakReference<>(connection));
logOpenedConnections();
}

private void invalidateConnections(final Queue<WeakReference<Connection>> connectionQueue) {
Expand All @@ -157,7 +162,7 @@ private void invalidateConnections(final Queue<WeakReference<Connection>> connec
}

try {
conn.abort(abortConnectionExecutorService);
conn.abort(abortConnectionExecutor);
} catch (final SQLException e) {
// swallow this exception, current connection should be useless anyway.
}
@@ -204,7 +209,10 @@ public void pruneNullConnections() {
if (conn == null) {
return true;
}
if (conn.getClass().getSimpleName().equals("HikariProxyConnection")) {
// The following classes do not check connection validity by calling the DB server,
// so it is safe to check whether the connection is already closed.
if (safeCheckIfClosed.contains(conn.getClass().getSimpleName())
|| safeCheckIfClosed.contains(conn.getClass().getName())) {
try {
return conn.isClosed();
} catch (SQLException ex) {
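Connection.abort(Executor) now receives a SynchronousExecutor instead of a cached thread pool, so the cleanup work of an aborted connection runs on the calling thread and invalidation no longer spawns daemon threads per call. The wrapper's actual SynchronousExecutor is not shown in this diff; a minimal sketch, assuming it simply runs each task inline:

import java.util.concurrent.Executor;

// Executes every submitted task on the submitting thread: no pool, no queue, no extra threads.
public class SynchronousExecutor implements Executor {
  @Override
  public void execute(final Runnable command) {
    command.run();
  }
}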
(changed file in package software.amazon.jdbc.targetdriverdialect; file name not shown in this view)
@@ -16,7 +16,6 @@

package software.amazon.jdbc.targetdriverdialect;

import com.mysql.cj.jdbc.MysqlConnectionPoolDataSource;
import com.mysql.cj.jdbc.MysqlDataSource;
import java.sql.DriverManager;
import java.sql.SQLException;
(driver messages .properties resource file)
@@ -363,7 +363,6 @@ AuroraStaleDnsHelper.staleDnsDetected=Stale DNS data detected. Opening a connect
AuroraStaleDnsHelper.reset=Reset stored writer host.

# Opened Connection Tracker
OpenedConnectionTracker.unableToPopulateOpenedConnectionQueue=The driver is unable to track this opened connection because the instance endpoint is unknown: ''{0}''
OpenedConnectionTracker.invalidatingConnections=Invalidating opened connections to host: ''{0}''

# Util