diff --git a/src/main/java/com/basho/riak/client/core/RiakNode.java b/src/main/java/com/basho/riak/client/core/RiakNode.java index 42c73a9f2..a2873faa9 100644 --- a/src/main/java/com/basho/riak/client/core/RiakNode.java +++ b/src/main/java/com/basho/riak/client/core/RiakNode.java @@ -76,6 +76,7 @@ public enum State private volatile Bootstrap bootstrap; private volatile boolean ownsBootstrap; + private volatile RiakChannelInitializer riakChannelInitializer; private volatile ScheduledExecutorService executor; private volatile boolean ownsExecutor; private volatile State state; @@ -84,6 +85,7 @@ public enum State private volatile int minConnections; private volatile long idleTimeoutInNanos; private volatile int connectionTimeout; + private volatile int readTimeout; private volatile boolean blockOnMaxConnections; private HealthCheckFactory healthCheckFactory; @@ -175,6 +177,7 @@ private RiakNode(Builder builder) throws UnknownHostException this.executor = builder.executor; this.connectionTimeout = builder.connectionTimeout; this.idleTimeoutInNanos = TimeUnit.NANOSECONDS.convert(builder.idleTimeout, TimeUnit.MILLISECONDS); + this.readTimeout = builder.readTimeout; this.minConnections = builder.minConnections; this.port = builder.port; this.remoteAddress = builder.remoteAddress; @@ -242,7 +245,9 @@ public synchronized RiakNode start() ownsBootstrap = true; } - bootstrap.handler(new RiakChannelInitializer(this)) + + riakChannelInitializer = new RiakChannelInitializer(this, readTimeout); + bootstrap.handler(riakChannelInitializer) .remoteAddress(new InetSocketAddress(remoteAddress, port)); if (connectionTimeout > 0) @@ -516,6 +521,33 @@ public int getConnectionTimeout() return connectionTimeout; } + /** + * Sets the read timeout in milliseconds. + * + * @param readTimeoutInMillis the read timeout to set + * @return a reference to this RiakNode + * @see Builder#withReadTimeout(int) + */ + public RiakNode setReadTimeout(int readTimeoutInMillis) + { + stateCheck(State.CREATED, State.RUNNING, State.HEALTH_CHECKING); + this.readTimeout = readTimeoutInMillis; + riakChannelInitializer.setReadTimeout(readTimeout); + return this; + } + + /** + * Returns the read timeout in milliseconds. + * + * @return the readTimeout + * @see Builder#withReadTimeout(int) + */ + public int getReadTimeout() + { + stateCheck(State.CREATED, State.RUNNING, State.HEALTH_CHECKING); + return readTimeout; + } + /** * Returns the number of permits currently available. * The number of available permits indicates how many additional @@ -663,6 +695,7 @@ private Channel doGetConnection() throws ConnectionFailedException try { + logger.debug("Waiting for new connection from channel future to {}:{}", remoteAddress, port); f.await(); } catch (InterruptedException ex) @@ -680,12 +713,15 @@ private Channel doGetConnection() throws ConnectionFailedException consecutiveFailedConnectionAttempts.incrementAndGet(); throw new ConnectionFailedException(f.cause()); } + + logger.debug("Connection to {}:{} successful", remoteAddress, port); consecutiveFailedConnectionAttempts.set(0); Channel c = f.channel(); if (trustStore != null) { + logger.debug("trustStore set starting TLS"); SSLContext context; try { @@ -720,11 +756,12 @@ else if (protocols.contains("TLSv1.1")) } engine.setUseClientMode(true); - RiakSecurityDecoder decoder = new RiakSecurityDecoder(engine, username, password); + RiakSecurityDecoder decoder = new RiakSecurityDecoder(remoteAddress, port, engine, username, password); c.pipeline().addFirst(decoder); try { + logger.debug("Waiting for authentication to complete with {}:{}", remoteAddress, port); DefaultPromise promise = decoder.getPromise(); promise.await(); @@ -1199,6 +1236,12 @@ public static class Builder * @see #withConnectionTimeout(int) */ public final static int DEFAULT_CONNECTION_TIMEOUT = 0; + /** + * The default so timeout in milliseconds if not specified: {@value #DEFAULT_READ_TIMEOUT} + * + * @see #withReadTimeout(int) + */ + public final static int DEFAULT_READ_TIMEOUT = 0; /** * The default HealthCheckFactory. @@ -1216,6 +1259,7 @@ public static class Builder private int maxConnections = DEFAULT_MAX_CONNECTIONS; private int idleTimeout = DEFAULT_IDLE_TIMEOUT; private int connectionTimeout = DEFAULT_CONNECTION_TIMEOUT; + private int readTimeout = DEFAULT_READ_TIMEOUT; private HealthCheckFactory healthCheckFactory = DEFAULT_HEALTHCHECK_FACTORY; private Bootstrap bootstrap; private ScheduledExecutorService executor; @@ -1331,6 +1375,19 @@ public Builder withConnectionTimeout(int connectionTimeoutInMillis) return this; } + /** + * Set the read timeout used when waiting for a response on the underlying sockets + * + * @param readTimeoutMillis + * @return this + * @see #DEFAULT_READ_TIMEOUT + */ + public Builder withReadTimeout(int readTimeoutMillis) + { + this.readTimeout = readTimeoutMillis; + return this; + } + /** * Provides an executor for this node to use for internal maintenance tasks. * If not provided one will be created via diff --git a/src/main/java/com/basho/riak/client/core/netty/RiakChannelInitializer.java b/src/main/java/com/basho/riak/client/core/netty/RiakChannelInitializer.java index 728946699..6e26fcca2 100644 --- a/src/main/java/com/basho/riak/client/core/netty/RiakChannelInitializer.java +++ b/src/main/java/com/basho/riak/client/core/netty/RiakChannelInitializer.java @@ -20,6 +20,10 @@ import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelPipeline; import io.netty.channel.socket.SocketChannel; +import io.netty.handler.timeout.IdleStateHandler; +import io.netty.handler.timeout.ReadTimeoutHandler; + +import java.util.concurrent.TimeUnit; /** * @@ -29,10 +33,13 @@ public class RiakChannelInitializer extends ChannelInitializer { private final RiakResponseListener listener; - public RiakChannelInitializer(RiakResponseListener listener) + private volatile int readTimeout; + + public RiakChannelInitializer(RiakResponseListener listener, int readTimeoutMillis) { super(); this.listener = listener; + this.readTimeout = readTimeoutMillis; } @Override @@ -42,6 +49,16 @@ public void initChannel(SocketChannel ch) throws Exception p.addLast(Constants.MESSAGE_CODEC, new RiakMessageCodec()); p.addLast(Constants.OPERATION_ENCODER, new RiakOperationEncoder()); p.addLast(Constants.RESPONSE_HANDLER, new RiakResponseHandler(listener)); + p.addLast(Constants.READ_TIMEOUT_HANDLER, new ReadTimeoutHandler(readTimeout, TimeUnit.MILLISECONDS)); + } + + public int getReadTimeout() + { + return readTimeout; + } + + public void setReadTimeout(int readTimeoutMillis) + { + readTimeout = readTimeoutMillis; } - } diff --git a/src/main/java/com/basho/riak/client/core/netty/RiakSecurityDecoder.java b/src/main/java/com/basho/riak/client/core/netty/RiakSecurityDecoder.java index 069900081..3fa6306d2 100644 --- a/src/main/java/com/basho/riak/client/core/netty/RiakSecurityDecoder.java +++ b/src/main/java/com/basho/riak/client/core/netty/RiakSecurityDecoder.java @@ -49,14 +49,18 @@ public class RiakSecurityDecoder extends ByteToMessageDecoder private final String username; private final String password; private final Logger logger = LoggerFactory.getLogger(RiakSecurityDecoder.class); - private volatile DefaultPromise promise; + private final String remoteAddr; + private final int remotePort; + private volatile DefaultPromise promise; private enum State { TLS_START, TLS_WAIT, SSL_WAIT, AUTH_WAIT } private volatile State state = State.TLS_START; - public RiakSecurityDecoder(SSLEngine engine, String username, String password) + public RiakSecurityDecoder(String remoteAddress, int port, SSLEngine engine, String username, String password) { + this.remoteAddr = remoteAddress; + this.remotePort = port; this.sslEngine = engine; this.username = username; this.password = password; @@ -88,7 +92,7 @@ protected void decode(ChannelHandlerContext chc, ByteBuf in, List out) t switch(code) { case RiakMessageCodes.MSG_StartTls: - logger.debug("Received MSG_RpbStartTls reply"); + logger.debug("Received MSG_RpbStartTls reply from {}:{}", remoteAddr, remotePort); // change state this.state = State.SSL_WAIT; // insert SSLHandler @@ -101,10 +105,11 @@ protected void decode(ChannelHandlerContext chc, ByteBuf in, List out) t chc.channel().pipeline().addFirst(Constants.SSL_HANDLER, sslHandler); break; case RiakMessageCodes.MSG_ErrorResp: - logger.debug("Received MSG_ErrorResp reply to startTls"); + logger.debug("Received MSG_ErrorResp reply to startTls from {}:{}", remoteAddr, remotePort); promise.tryFailure((riakErrorToException(protobuf))); break; default: + logger.debug("Invalid return code during StartTLS from {}:{} code", remoteAddr, remotePort, code); promise.tryFailure(new RiakResponseException(0, "Invalid return code during StartTLS; " + code)); } @@ -114,21 +119,22 @@ protected void decode(ChannelHandlerContext chc, ByteBuf in, List out) t switch(code) { case RiakMessageCodes.MSG_AuthResp: - logger.debug("Received MSG_RpbAuthResp reply"); + logger.debug("Received MSG_RpbAuthResp reply from {}:{}", remoteAddr, remotePort); promise.trySuccess(null); break; case RiakMessageCodes.MSG_ErrorResp: - logger.debug("Received MSG_ErrorResp reply to auth"); + logger.debug("Received MSG_ErrorResp reply to Auth from {}:{}", remoteAddr, remotePort); promise.tryFailure(riakErrorToException(protobuf)); break; default: + logger.debug("Invalid return code during Auth from {}:{}", remoteAddr, remotePort); promise.tryFailure(new RiakResponseException(0, "Invalid return code during Auth; " + code)); } break; default: // WTF? - logger.error("Received message while not in TLS_WAIT or AUTH_WAIT"); + logger.error("Received message while not in TLS_WAIT or AUTH_WAIT from {}:{}", remoteAddr, remotePort); promise.tryFailure(new IllegalStateException("Received message while not in TLS_WAIT or AUTH_WAIT")); } } @@ -208,6 +214,7 @@ public void operationComplete(Future future) throws Exception { if (future.isSuccess()) { + logger.debug("SSLHandshake Completed with {}:{}. Authenticating.", remoteAddr, remotePort); Channel c = future.getNow(); state = State.AUTH_WAIT; RiakPB.RpbAuthReq authReq = @@ -221,6 +228,7 @@ public void operationComplete(Future future) throws Exception } else { + logger.warn("SSLHandshake Failed with {}:{}.", remoteAddr, remotePort, future.cause()); promise.tryFailure(future.cause()); } } diff --git a/src/main/java/com/basho/riak/client/core/util/Constants.java b/src/main/java/com/basho/riak/client/core/util/Constants.java index 117fad736..b081afce7 100644 --- a/src/main/java/com/basho/riak/client/core/util/Constants.java +++ b/src/main/java/com/basho/riak/client/core/util/Constants.java @@ -95,6 +95,7 @@ public interface Constants { public static final String MESSAGE_CODEC = "codec"; public static final String OPERATION_ENCODER = "operationEncoder"; public static final String RESPONSE_HANDLER = "responseHandler"; + public static final String READ_TIMEOUT_HANDLER = "readTimeoutHandler"; public static final String SSL_HANDLER = "sslHandler"; public static final String HEALTHCHECK_CODEC = "healthCheckCodec"; diff --git a/src/test/java/com/basho/riak/client/core/RiakNodeTest.java b/src/test/java/com/basho/riak/client/core/RiakNodeTest.java index cab00cbd6..cc4e494aa 100644 --- a/src/test/java/com/basho/riak/client/core/RiakNodeTest.java +++ b/src/test/java/com/basho/riak/client/core/RiakNodeTest.java @@ -63,6 +63,7 @@ public void builderProducesDefaultNode() throws UnknownHostException assertEquals(node.getMaxConnections(), Integer.MAX_VALUE); assertEquals(node.getConnectionTimeout(), RiakNode.Builder.DEFAULT_CONNECTION_TIMEOUT); assertEquals(node.getIdleTimeout(), RiakNode.Builder.DEFAULT_IDLE_TIMEOUT); + assertEquals(node.getReadTimeout(), RiakNode.Builder.DEFAULT_READ_TIMEOUT); assertEquals(node.getMinConnections(), RiakNode.Builder.DEFAULT_MIN_CONNECTIONS); assertEquals(node.availablePermits(), Integer.MAX_VALUE); } @@ -75,7 +76,7 @@ public void builderProducesCorrectNode() throws UnknownHostException final int MIN_CONNECTIONS = 2002; final int MAX_CONNECTIONS = 2003; final int PORT = 2004; - final int READ_TIMEOUT = 2005; + final int READ_TIMEOUT = 2006; final String REMOTE_ADDRESS = "localhost"; final ScheduledExecutorService EXECUTOR = Executors.newSingleThreadScheduledExecutor(); final Bootstrap BOOTSTRAP = PowerMockito.spy(new Bootstrap()); @@ -91,6 +92,7 @@ public void builderProducesCorrectNode() throws UnknownHostException .withRemoteAddress(REMOTE_ADDRESS) .withExecutor(EXECUTOR) .withBootstrap(BOOTSTRAP) + .withReadTimeout(READ_TIMEOUT) .build(); assertEquals(node.getRemoteAddress(), REMOTE_ADDRESS); @@ -103,6 +105,7 @@ public void builderProducesCorrectNode() throws UnknownHostException assertEquals(node.getRemoteAddress(), REMOTE_ADDRESS); assertEquals(node.availablePermits(), MAX_CONNECTIONS); assertEquals(node.getPort(), PORT); + assertEquals(node.getReadTimeout(), READ_TIMEOUT); }