Skip to content

Commit

Permalink
ZOOKEEPER-4236 Java Client SendThread create many unnecessary Login o…
Browse files Browse the repository at this point in the history
…bjects (#2128)

(cherry picked from commit 16abb9c)
Signed-off-by: Andor Molnar <[email protected]>
  • Loading branch information
anmolnar committed Feb 6, 2024
1 parent 3a97437 commit 886ba84
Show file tree
Hide file tree
Showing 3 changed files with 64 additions and 38 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import java.util.concurrent.LinkedBlockingDeque;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.ThreadLocalRandom;
import java.util.concurrent.atomic.AtomicReference;
import javax.security.auth.login.LoginException;
import javax.security.sasl.SaslException;
import org.apache.jute.BinaryInputArchive;
Expand Down Expand Up @@ -869,6 +870,7 @@ class SendThread extends ZooKeeperThread {
private final ClientCnxnSocket clientCnxnSocket;
private boolean isFirstConnect = true;
private volatile ZooKeeperSaslClient zooKeeperSaslClient;
private final AtomicReference<Login> loginRef = new AtomicReference<>();

private String stripChroot(String serverPath) {
if (serverPath.startsWith(chrootPath)) {
Expand Down Expand Up @@ -1151,10 +1153,8 @@ private void startConnect(InetSocketAddress addr) throws IOException {
setName(getName().replaceAll("\\(.*\\)", "(" + hostPort + ")"));
if (clientConfig.isSaslClientEnabled()) {
try {
if (zooKeeperSaslClient != null) {
zooKeeperSaslClient.shutdown();
}
zooKeeperSaslClient = new ZooKeeperSaslClient(SaslServerPrincipal.getServerPrincipal(addr, clientConfig), clientConfig);
zooKeeperSaslClient = new ZooKeeperSaslClient(
SaslServerPrincipal.getServerPrincipal(addr, clientConfig), clientConfig, loginRef);
} catch (LoginException e) {
// An authentication error occurred when the SASL client tried to initialize:
// for Kerberos this means that the client failed to authenticate with the KDC.
Expand Down Expand Up @@ -1322,8 +1322,9 @@ public void run() {
}
eventThread.queueEvent(new WatchedEvent(Event.EventType.None, Event.KeeperState.Closed, null));

if (zooKeeperSaslClient != null) {
zooKeeperSaslClient.shutdown();
Login l = loginRef.getAndSet(null);
if (l != null) {
l.shutdown();
}
ZooTrace.logTraceMessage(
LOG,
Expand Down Expand Up @@ -1501,6 +1502,11 @@ public void sendPacket(Packet p) throws IOException {
public ZooKeeperSaslClient getZooKeeperSaslClient() {
return zooKeeperSaslClient;
}

// VisibleForTesting
Login getLogin() {
return loginRef.get();
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import java.io.IOException;
import java.security.PrivilegedActionException;
import java.security.PrivilegedExceptionAction;
import java.util.concurrent.atomic.AtomicReference;
import javax.security.auth.Subject;
import javax.security.auth.login.AppConfigurationEntry;
import javax.security.auth.login.Configuration;
Expand Down Expand Up @@ -64,7 +65,6 @@ public class ZooKeeperSaslClient {
*/
@Deprecated
public static final String ENABLE_CLIENT_SASL_DEFAULT = "true";
private volatile boolean initializedLogin = false;

/**
* Returns true if the SASL client is enabled. By default, the client
Expand Down Expand Up @@ -112,7 +112,7 @@ public String getLoginContext() {
return null;
}

public ZooKeeperSaslClient(final String serverPrincipal, ZKClientConfig clientConfig) throws LoginException {
public ZooKeeperSaslClient(final String serverPrincipal, ZKClientConfig clientConfig, AtomicReference<Login> loginRef) throws LoginException {
/**
* ZOOKEEPER-1373: allow system property to specify the JAAS
* configuration section that the zookeeper client should use.
Expand All @@ -136,7 +136,8 @@ public ZooKeeperSaslClient(final String serverPrincipal, ZKClientConfig clientCo
}
if (entries != null) {
this.configStatus = "Will attempt to SASL-authenticate using Login Context section '" + clientSection + "'";
this.saslClient = createSaslClient(serverPrincipal, clientSection);
this.saslClient = createSaslClient(serverPrincipal, clientSection, loginRef);
this.login = loginRef.get();
} else {
// Handle situation of clientSection's being null: it might simply because the client does not intend to
// use SASL, so not necessarily an error.
Expand Down Expand Up @@ -234,26 +235,25 @@ public void processResult(int rc, String path, Object ctx, byte[] data, Stat sta

private SaslClient createSaslClient(
final String servicePrincipal,
final String loginContext) throws LoginException {
final String loginContext,
final AtomicReference<Login> loginRef) throws LoginException {
try {
if (!initializedLogin) {
synchronized (this) {
if (login == null) {
LOG.debug("JAAS loginContext is: {}", loginContext);
// note that the login object is static: it's shared amongst all zookeeper-related connections.
// in order to ensure the login is initialized only once, it must be synchronized the code snippet.
login = new Login(loginContext, new SaslClientCallbackHandler(null, "Client"), clientConfig);
login.startThreadIfNeeded();
initializedLogin = true;
}
if (loginRef.get() == null) {
LOG.debug("JAAS loginContext is: {}", loginContext);
// note that the login object is static: it's shared amongst all zookeeper-related connections.
// in order to ensure the login is initialized only once, it must be synchronized the code snippet.
Login l = new Login(loginContext, new SaslClientCallbackHandler(null, "Client"), clientConfig);
if (loginRef.compareAndSet(null, l)) {
l.startThreadIfNeeded();
}
}
return SecurityUtils.createSaslClient(login.getSubject(), servicePrincipal, "zookeeper", "zk-sasl-md5", LOG, "Client");
return SecurityUtils.createSaslClient(loginRef.get().getSubject(),
servicePrincipal, "zookeeper", "zk-sasl-md5", LOG, "Client");
} catch (LoginException e) {
// We throw LoginExceptions...
throw e;
} catch (Exception e) {
// ..but consume (with a log message) all other types of exceptions.
// ...but consume (with a log message) all other types of exceptions.
LOG.error("Exception while trying to create SASL client.", e);
return null;
}
Expand Down Expand Up @@ -451,15 +451,4 @@ public boolean clientTunneledAuthenticationInProgress() {
return false;
}
}

/**
* close login thread if running
*/
public void shutdown() {
if (null != login) {
login.shutdown();
login = null;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertSame;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
Expand Down Expand Up @@ -239,12 +240,9 @@ public void testThreadsShutdownOnAuthFailed() throws Exception {
assertNotNull(zooKeeperSaslClient);
sendThread.join(CONNECTION_TIMEOUT);
eventThread.join(CONNECTION_TIMEOUT);
Field loginField = zooKeeperSaslClient.getClass().getDeclaredField("login");
loginField.setAccessible(true);
Login login = (Login) loginField.get(zooKeeperSaslClient);
// If login is null, this means ZooKeeperSaslClient#shutdown method has been called which in turns
// means that Login#shutdown has been called.
assertNull(login);
assertNull(sendThread.getLogin());
assertFalse(sendThread.isAlive(), "SendThread did not shutdown after authFail");
assertFalse(eventThread.isAlive(), "EventThread did not shutdown after authFail");
} finally {
Expand All @@ -253,4 +251,37 @@ public void testThreadsShutdownOnAuthFailed() throws Exception {
}
}
}
}

@Test
public void testDisconnectNotCreatingLoginThread() throws Exception {
MyWatcher watcher = new MyWatcher();
ZooKeeper zk = null;
try {
zk = new ZooKeeper(hostPort, CONNECTION_TIMEOUT, watcher);
watcher.waitForConnected(CONNECTION_TIMEOUT);
zk.getData("/", false, null);

Field cnxnField = zk.getClass().getDeclaredField("cnxn");
cnxnField.setAccessible(true);
ClientCnxn clientCnxn = (ClientCnxn) cnxnField.get(zk);
Field sendThreadField = clientCnxn.getClass().getDeclaredField("sendThread");
sendThreadField.setAccessible(true);
SendThread sendThread = (SendThread) sendThreadField.get(clientCnxn);

Login l1 = sendThread.getLogin();
assertNotNull(l1);

stopServer();
watcher.waitForDisconnected(CONNECTION_TIMEOUT);
startServer();
watcher.waitForConnected(CONNECTION_TIMEOUT);
zk.getData("/", false, null);

assertSame("Login thread should not been recreated on disconnect", l1, sendThread.getLogin());
} finally {
if (zk != null) {
zk.close();
}
}
}
}

0 comments on commit 886ba84

Please sign in to comment.