diff --git a/utils/src/main/java/com/cloud/utils/nio/Link.java b/utils/src/main/java/com/cloud/utils/nio/Link.java index 4e68554eb499..a4fd27ecd800 100644 --- a/utils/src/main/java/com/cloud/utils/nio/Link.java +++ b/utils/src/main/java/com/cloud/utils/nio/Link.java @@ -64,23 +64,22 @@ public class Link { private final NioConnection _connection; private SelectionKey _key; private final ConcurrentLinkedQueue _writeQueue; - private ByteBuffer _readBuffer; - private ByteBuffer _plaintextBuffer; + private final ByteBuffer headerBuffer = ByteBuffer.allocate(4); // accumulates length header inside TLS + private ByteBuffer netBuffer; + private ByteBuffer appBuffer; + private ByteBuffer plainTextBuffer; + private int frameRemaining = -1; // remaining bytes for current frame (inside TLS) private Object _attach; - private boolean _readHeader; - private boolean _gotFollowingPacket; private SSLEngine _sslEngine; public Link(InetSocketAddress addr, NioConnection connection) { _addr = addr; _connection = connection; - _readBuffer = ByteBuffer.allocate(2048); _attach = null; _key = null; _writeQueue = new ConcurrentLinkedQueue(); - _readHeader = true; - _gotFollowingPacket = false; + plainTextBuffer = null; } public Link(Link link) { @@ -103,58 +102,82 @@ public void setKey(SelectionKey key) { public void setSSLEngine(SSLEngine sslEngine) { _sslEngine = sslEngine; + if (_sslEngine == null) { + netBuffer = null; + appBuffer = null; + headerBuffer.clear(); + frameRemaining = -1; + plainTextBuffer = null; + return; + } + final SSLSession s = _sslEngine.getSession(); + netBuffer = ByteBuffer.allocate(Math.max(s.getPacketBufferSize(), 16 * 1024)); + appBuffer = ByteBuffer.allocate(Math.max(s.getApplicationBufferSize(), 16 * 1024)); + headerBuffer.clear(); + frameRemaining = -1; + plainTextBuffer = null; } private static void doWrite(SocketChannel ch, ByteBuffer[] buffers, SSLEngine sslEngine) throws IOException { - SSLSession sslSession = sslEngine.getSession(); - ByteBuffer pkgBuf = ByteBuffer.allocate(sslSession.getPacketBufferSize() + 40); - SSLEngineResult engResult; - - ByteBuffer headBuf = ByteBuffer.allocate(4); + if (sslEngine == null) { + throw new IOException("SSLEngine not set"); + } + final SSLSession session = sslEngine.getSession(); + ByteBuffer netBuf = ByteBuffer.allocate(session.getPacketBufferSize()); + // Build app sequence: 4-byte length header + payload buffers int totalLen = 0; - for (ByteBuffer buffer : buffers) { - totalLen += buffer.limit(); - } - - int processedLen = 0; - while (processedLen < totalLen) { - headBuf.clear(); - pkgBuf.clear(); - engResult = sslEngine.wrap(buffers, pkgBuf); - if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED && engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING && - engResult.getStatus() != SSLEngineResult.Status.OK) { - throw new IOException("SSL: SSLEngine return bad result! " + engResult); + for (ByteBuffer b : buffers) totalLen += b.remaining(); + ByteBuffer header = ByteBuffer.allocate(4); + header.putInt(totalLen).flip(); + + ByteBuffer[] appSeq = new ByteBuffer[buffers.length + 1]; + appSeq[0] = header; + for (int i = 0; i < buffers.length; i++) { + appSeq[i + 1] = buffers[i].duplicate(); + } + + while (true) { + // Check if all app buffers are fully consumed + boolean allDone = true; + for (ByteBuffer b : appSeq) { + if (b.hasRemaining()) { + allDone = false; + break; + } } + if (allDone) break; - processedLen = 0; - for (ByteBuffer buffer : buffers) { - processedLen += buffer.position(); + netBuf.clear(); + SSLEngineResult res; + try { + res = sslEngine.wrap(appSeq, netBuf); + } catch (SSLException e) { + throw new IOException("SSL wrap failed: " + e.getMessage(), e); } - - int dataRemaining = pkgBuf.position(); - int header = dataRemaining; - int headRemaining = 4; - pkgBuf.flip(); - if (processedLen < totalLen) { - header = header | HEADER_FLAG_FOLLOWING; + switch (res.getStatus()) { + case OK: + netBuf.flip(); + while (netBuf.hasRemaining()) { + ch.write(netBuf); // may be partial, loop until drained + } + break; + case BUFFER_OVERFLOW: + netBuf = enlargeBuffer(netBuf, session.getPacketBufferSize()); + break; + case CLOSED: + throw new IOException("SSLEngine is CLOSED during write"); + default: + throw new IOException("Unexpected SSLEngineResult status on wrap: " + res.getStatus()); } - headBuf.putInt(header); - headBuf.flip(); - - while (headRemaining > 0) { - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("Writing Header " + headRemaining); - } - long count = ch.write(headBuf); - headRemaining -= count; + // Drain delegated tasks if any + if (res.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + Runnable task; + while ((task = sslEngine.getDelegatedTask()) != null) task.run(); } - while (dataRemaining > 0) { - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("Writing Data " + dataRemaining); - } - long count = ch.write(pkgBuf); - dataRemaining -= count; + if (res.getHandshakeStatus() == HandshakeStatus.NEED_UNWRAP) { + // Unusual during application writes; upper layer should drive handshake + break; } } } @@ -174,116 +197,88 @@ public static void write(SocketChannel ch, ByteBuffer[] buffers, SSLEngine sslEn } } - /* SSL has limitation of 16k, we may need to split packets. 18000 is 16k + some extra SSL informations */ - protected static final int MAX_SIZE_PER_PACKET = 18000; - protected static final int HEADER_FLAG_FOLLOWING = 0x10000; - public byte[] read(SocketChannel ch) throws IOException { - if (_readHeader) { // Start of a packet - if (_readBuffer.position() == 0) { - _readBuffer.limit(4); - } - - if (ch.read(_readBuffer) == -1) { - throw new IOException("Connection closed with -1 on reading size."); - } - - if (_readBuffer.hasRemaining()) { - LOGGER.trace("Need to read the rest of the packet length"); - return null; - } - _readBuffer.flip(); - int header = _readBuffer.getInt(); - int readSize = (short)header; - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("Packet length is " + readSize); - } - - if (readSize > MAX_SIZE_PER_PACKET) { - throw new IOException("Wrong packet size: " + readSize); - } - - if (!_gotFollowingPacket) { - _plaintextBuffer = ByteBuffer.allocate(2000); - } - - if ((header & HEADER_FLAG_FOLLOWING) != 0) { - _gotFollowingPacket = true; - } else { - _gotFollowingPacket = false; - } - - _readBuffer.clear(); - _readHeader = false; - - if (_readBuffer.capacity() < readSize) { - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("Resizing the byte buffer from " + _readBuffer.capacity()); - } - _readBuffer = ByteBuffer.allocate(readSize); - } - _readBuffer.limit(readSize); + if (_sslEngine == null) { + throw new IOException("SSLEngine not set"); } - - if (ch.read(_readBuffer) == -1) { + if (ch.read(netBuffer) == -1) { throw new IOException("Connection closed with -1 on read."); } - - if (_readBuffer.hasRemaining()) { // We're not done yet. - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("Still has " + _readBuffer.remaining()); - } - return null; - } - - _readBuffer.flip(); - - ByteBuffer appBuf; - - SSLSession sslSession = _sslEngine.getSession(); - SSLEngineResult engResult; - int remaining = 0; - - while (_readBuffer.hasRemaining()) { - remaining = _readBuffer.remaining(); - appBuf = ByteBuffer.allocate(sslSession.getApplicationBufferSize() + 40); - engResult = _sslEngine.unwrap(_readBuffer, appBuf); - if (engResult.getHandshakeStatus() != HandshakeStatus.FINISHED && engResult.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING && - engResult.getStatus() != SSLEngineResult.Status.OK) { - throw new IOException("SSL: SSLEngine return bad result! " + engResult); - } - if (remaining == _readBuffer.remaining()) { - throw new IOException("SSL: Unable to unwrap received data! still remaining " + remaining + "bytes!"); - } - - appBuf.flip(); - if (_plaintextBuffer.remaining() < appBuf.limit()) { - // We need to expand _plaintextBuffer for more data - ByteBuffer newBuffer = ByteBuffer.allocate(_plaintextBuffer.capacity() + appBuf.limit() * 5); - _plaintextBuffer.flip(); - newBuffer.put(_plaintextBuffer); - _plaintextBuffer = newBuffer; + netBuffer.flip(); + while (netBuffer.hasRemaining()) { + SSLEngineResult res; + try { + res = _sslEngine.unwrap(netBuffer, appBuffer); + } catch (SSLException e) { + throw new IOException("SSL unwrap failed: " + e.getMessage(), e); } - _plaintextBuffer.put(appBuf); - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("Done with packet: " + appBuf.limit()); + switch (res.getStatus()) { + case OK: + appBuffer.flip(); + while (appBuffer.hasRemaining()) { + if (frameRemaining < 0) { + int need = 4 - headerBuffer.position(); + int take = Math.min(need, appBuffer.remaining()); + int oldLimit = appBuffer.limit(); + appBuffer.limit(appBuffer.position() + take); + headerBuffer.put(appBuffer); + appBuffer.limit(oldLimit); + if (headerBuffer.position() < 4) break; + headerBuffer.flip(); + frameRemaining = headerBuffer.getInt(); + headerBuffer.clear(); + if (frameRemaining < 0) { + throw new IOException("Negative frame length"); + } + if (plainTextBuffer == null || plainTextBuffer.capacity() < frameRemaining) { + plainTextBuffer = ByteBuffer.allocate(Math.max(frameRemaining, 2048)); + } + plainTextBuffer.clear(); + } else { + int toCopy = Math.min(frameRemaining, appBuffer.remaining()); + if (plainTextBuffer.remaining() < toCopy) { + ByteBuffer newBuffer = ByteBuffer.allocate(plainTextBuffer.capacity() + Math.max(toCopy, 4096)); + plainTextBuffer.flip(); + newBuffer.put(plainTextBuffer); + plainTextBuffer = newBuffer; + } + int oldLimit = appBuffer.limit(); + appBuffer.limit(appBuffer.position() + toCopy); + plainTextBuffer.put(appBuffer); + appBuffer.limit(oldLimit); + frameRemaining -= toCopy; + if (frameRemaining == 0) { + plainTextBuffer.flip(); + byte[] result = new byte[plainTextBuffer.remaining()]; + plainTextBuffer.get(result); + appBuffer.compact(); + netBuffer.compact(); + frameRemaining = -1; + return result; + } + } + } + appBuffer.compact(); + break; + case BUFFER_OVERFLOW: + appBuffer = enlargeBuffer(appBuffer, _sslEngine.getSession().getApplicationBufferSize()); + break; + case BUFFER_UNDERFLOW: + netBuffer = handleBufferUnderflow(_sslEngine, netBuffer); + netBuffer.compact(); + return null; + case CLOSED: + throw new IOException("SSLEngine closed during read"); + default: + throw new IOException("Unexpected SSLEngineResult status on unwrap: " + res.getStatus()); } - } - - _readBuffer.clear(); - _readHeader = true; - - if (!_gotFollowingPacket) { - _plaintextBuffer.flip(); - byte[] result = new byte[_plaintextBuffer.limit()]; - _plaintextBuffer.get(result); - return result; - } else { - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("Waiting for more packets"); + if (res.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + Runnable task; + while ((task = _sslEngine.getDelegatedTask()) != null) task.run(); } - return null; } + netBuffer.compact(); + return null; } public void send(byte[] data) throws ClosedChannelException { @@ -295,19 +290,14 @@ public void send(byte[] data, boolean close) throws ClosedChannelException { } public void send(ByteBuffer[] data, boolean close) throws ClosedChannelException { - ByteBuffer[] item = new ByteBuffer[data.length + 1]; + ByteBuffer[] item = new ByteBuffer[data.length]; int remaining = 0; for (int i = 0; i < data.length; i++) { remaining += data[i].remaining(); - item[i + 1] = data[i]; + item[i] = data[i]; } - - item[0] = ByteBuffer.allocate(4); - item[0].putInt(remaining); - item[0].flip(); - if (LOGGER.isTraceEnabled()) { - LOGGER.trace("Sending packet of length " + remaining); + LOGGER.trace("Sending framed message of length " + remaining); } _writeQueue.add(item); @@ -341,11 +331,7 @@ public boolean write(SocketChannel ch) throws IOException { } return true; } - - ByteBuffer[] raw_data = new ByteBuffer[data.length - 1]; - System.arraycopy(data, 1, raw_data, 0, data.length - 1); - - doWrite(ch, raw_data, _sslEngine); + doWrite(ch, data, _sslEngine); } return false; } @@ -376,7 +362,7 @@ public static KeyStore loadKeyStore(final InputStream stream, final char[] passp } public static SSLEngine initServerSSLEngine(final CAService caService, final String clientAddress) throws GeneralSecurityException, IOException { - final SSLContext sslContext = SSLUtils.getSSLContext(); + final SSLContext sslContext = SSLUtils.getSSLContextWithLatestProtocolVersion(); if (caService != null) { return caService.createSSLEngine(sslContext, clientAddress); } @@ -405,7 +391,7 @@ public static SSLContext initManagementSSLContext(final CAService caService) thr final KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); kmf.init(ks, passphrase); - final SSLContext sslContext = SSLUtils.getSSLContext(); + final SSLContext sslContext = SSLUtils.getSSLContextWithLatestProtocolVersion(); sslContext.init(kmf.getKeyManagers(), tms, new SecureRandom()); return sslContext; } @@ -449,7 +435,7 @@ public static SSLContext initClientSSLContext() throws GeneralSecurityException, final KeyManagerFactory kmf = KeyManagerFactory.getInstance("SunX509"); kmf.init(ks, passphrase); - final SSLContext sslContext = SSLUtils.getSSLContext(); + final SSLContext sslContext = SSLUtils.getSSLContextWithLatestProtocolVersion(); sslContext.init(kmf.getKeyManagers(), tms, new SecureRandom()); return sslContext; } diff --git a/utils/src/main/java/com/cloud/utils/nio/NioClient.java b/utils/src/main/java/com/cloud/utils/nio/NioClient.java index d274973a6584..85f65318237f 100644 --- a/utils/src/main/java/com/cloud/utils/nio/NioClient.java +++ b/utils/src/main/java/com/cloud/utils/nio/NioClient.java @@ -74,7 +74,8 @@ protected void init() throws IOException { if (!Link.doHandshake(clientConnection, sslEngine, getSslHandshakeTimeout())) { throw new IOException(String.format("SSL Handshake failed while connecting to host: %s", hostLog)); } - logger.info("SSL: Handshake done"); + logger.info("SSL: Handshake done with {} protocol: {}, cipher suite: {}", + serverAddress, sslEngine.getSession().getProtocol(), sslEngine.getSession().getCipherSuite()); final Link link = new Link(serverAddress, this); link.setSSLEngine(sslEngine); diff --git a/utils/src/main/java/com/cloud/utils/nio/NioConnection.java b/utils/src/main/java/com/cloud/utils/nio/NioConnection.java index 8e1a208a164e..7bf895f1d1f5 100644 --- a/utils/src/main/java/com/cloud/utils/nio/NioConnection.java +++ b/utils/src/main/java/com/cloud/utils/nio/NioConnection.java @@ -274,7 +274,9 @@ protected void accept(final SelectionKey key) throws IOException { if (!Link.doHandshake(socketChannel, sslEngine, getSslHandshakeTimeout())) { throw new IOException("SSL handshake timed out with " + socketAddress); } - logger.trace("SSL: Handshake done"); + logger.trace("SSL: Handshake done with {} protocol: {}, cipher suite: {}", + socketAddress, sslEngine.getSession().getProtocol(), + sslEngine.getSession().getCipherSuite()); final Link link = new Link(socketAddress, nioConnection); link.setSSLEngine(sslEngine); link.setKey(socketChannel.register(key.selector(), SelectionKey.OP_READ, link)); diff --git a/utils/src/main/java/org/apache/cloudstack/utils/security/SSLUtils.java b/utils/src/main/java/org/apache/cloudstack/utils/security/SSLUtils.java index eeebefab2219..f0741ce27184 100644 --- a/utils/src/main/java/org/apache/cloudstack/utils/security/SSLUtils.java +++ b/utils/src/main/java/org/apache/cloudstack/utils/security/SSLUtils.java @@ -70,6 +70,10 @@ public static SSLContext getSSLContext() throws NoSuchAlgorithmException { return SSLContext.getInstance("TLSv1.2"); } + public static SSLContext getSSLContextWithLatestProtocolVersion() throws NoSuchAlgorithmException { + return SSLContext.getInstance("TLSv1.3"); + } + public static SSLContext getSSLContext(String provider) throws NoSuchAlgorithmException, NoSuchProviderException { return SSLContext.getInstance("TLSv1.2", provider); }