From 54af47133f21e6bfe8eb1dc1aea9adced4911e3c Mon Sep 17 00:00:00 2001 From: TwistedTwigleg Date: Tue, 22 Nov 2022 10:19:01 -0500 Subject: [PATCH] Java keystore and Base64 support (#552) Add support for making a TlsContextOptions using a Java keystore, and bind Base64 encoding and decoding support. --- .../awssdk/crt/io/TlsContextOptions.java | 45 ++++++++ .../amazon/awssdk/crt/utils/StringUtils.java | 21 ++++ src/native/string_utils.c | 92 +++++++++++++++ .../awssdk/crt/test/StringUtilsTest.java | 106 +++++++++++++++++- 4 files changed, 263 insertions(+), 1 deletion(-) create mode 100644 src/native/string_utils.c diff --git a/src/main/java/software/amazon/awssdk/crt/io/TlsContextOptions.java b/src/main/java/software/amazon/awssdk/crt/io/TlsContextOptions.java index e8eee87ef..4e6817ab3 100644 --- a/src/main/java/software/amazon/awssdk/crt/io/TlsContextOptions.java +++ b/src/main/java/software/amazon/awssdk/crt/io/TlsContextOptions.java @@ -356,6 +356,51 @@ public static TlsContextOptions createWithMtlsWindowsCertStorePath(String certif return options; } + /** + * Helper which creates mutual TLS (mTLS) options using a certificate and private key + * stored in a Java keystore. + * Will throw an exception if there is no certificate and key at the given certificate alias, or there is some other + * error accessing or using the passed-in Java keystore. + * + * Note: function assumes the passed keystore has already been loaded from a file by calling "keystore.load()" or similar. + * + * @param keyStore The Java keystore to use. Assumed to be loaded with the desired certificate and key + * @param certificateAlias The alias of the certificate and key to use. + * @param certificatePassword The password of the certificate and key to use. + * @throws CrtRuntimeException if the certificate alias does not exist or the certificate/key cannot be found in the certificate alias + * @return A set of options for setting up an mTLS connection + */ + public static TlsContextOptions createWithMtlsJavaKeystore( + java.security.KeyStore keyStore, String certificateAlias, String certificatePassword) { + + TlsContextOptions options = new TlsContextOptions(); + String certificate; + try { + java.security.cert.Certificate certificateData = keyStore.getCertificate(certificateAlias); + if (certificateData == null) { + throw new CrtRuntimeException("Certificate at given certificate alias does not exist or does not contain a certificate"); + } + String certificateString = new String(StringUtils.base64Encode(certificateData.getEncoded())); + certificate = "-----BEGIN CERTIFICATE-----\n" + certificateString + "-----END CERTIFICATE-----\n"; + } catch (java.security.KeyStoreException | java.security.cert.CertificateEncodingException ex) { + throw new RuntimeException("Failed to get certificate from Java keystore", ex); + } + String privateKey; + try { + java.security.Key keyData = keyStore.getKey(certificateAlias, certificatePassword.toCharArray()); + if (keyData == null) { + throw new CrtRuntimeException("Private key at given certificate alias does not exist or does not identify a key-related entity"); + } + String keyString = new String(StringUtils.base64Encode(keyData.getEncoded())); + privateKey = "-----BEGIN RSA PRIVATE KEY-----\n" + keyString + "-----END RSA PRIVATE KEY-----\n"; + } catch (java.security.KeyStoreException | java.security.NoSuchAlgorithmException | java.security.UnrecoverableKeyException ex) { + throw new RuntimeException("Failed to get private key from Java keystore", ex); + } + options.initMtls(certificate, privateKey); + options.verifyPeer = true; + return options; + } + /******************************************************************************* * .with() methods ******************************************************************************/ diff --git a/src/main/java/software/amazon/awssdk/crt/utils/StringUtils.java b/src/main/java/software/amazon/awssdk/crt/utils/StringUtils.java index 3f5a0cfac..2ad753e0f 100644 --- a/src/main/java/software/amazon/awssdk/crt/utils/StringUtils.java +++ b/src/main/java/software/amazon/awssdk/crt/utils/StringUtils.java @@ -23,4 +23,25 @@ public static String join(CharSequence delimiter, Iterable + +#include +#include + +#include "crt.h" + +JNIEXPORT +jbyteArray JNICALL Java_software_amazon_awssdk_crt_utils_StringUtils_stringUtilsBase64Encode( + JNIEnv *env, + jclass jni_class, + jbyteArray jni_data) { + (void)jni_class; + + struct aws_byte_cursor data_cursor; + AWS_ZERO_STRUCT(data_cursor); + struct aws_byte_buf formatted_data; + AWS_ZERO_STRUCT(formatted_data); + jbyteArray return_data = NULL; + + data_cursor = aws_jni_byte_cursor_from_jbyteArray_acquire(env, jni_data); + if (data_cursor.ptr == NULL) { + return return_data; + } + + // Determine how much space we need + size_t terminated_length = 0; + if (aws_base64_compute_encoded_len(data_cursor.len, &terminated_length) != AWS_OP_SUCCESS) { + aws_jni_throw_runtime_exception(env, "StringUtils: Could not determine length for base64 encode"); + goto clean_up; + } + + aws_byte_buf_init(&formatted_data, aws_jni_get_allocator(), terminated_length); + int result = aws_base64_encode(&data_cursor, &formatted_data); + if (result != AWS_OP_SUCCESS) { + aws_jni_throw_runtime_exception(env, "StringUtils: Could not perform base64 encode"); + goto clean_up; + } + + struct aws_byte_cursor formatted_data_cursor = aws_byte_cursor_from_buf(&formatted_data); + return_data = aws_jni_byte_array_from_cursor(env, &formatted_data_cursor); + +clean_up: + aws_jni_byte_cursor_from_jbyteArray_release(env, jni_data, data_cursor); + aws_byte_buf_clean_up_secure(&formatted_data); + return return_data; +} + +JNIEXPORT +jbyteArray JNICALL Java_software_amazon_awssdk_crt_utils_StringUtils_stringUtilsBase64Decode( + JNIEnv *env, + jclass jni_class, + jbyteArray jni_data) { + (void)jni_class; + + struct aws_byte_cursor data_cursor; + AWS_ZERO_STRUCT(data_cursor); + struct aws_byte_buf formatted_data; + AWS_ZERO_STRUCT(formatted_data); + jbyteArray return_data = NULL; + + data_cursor = aws_jni_byte_cursor_from_jbyteArray_acquire(env, jni_data); + if (data_cursor.ptr == NULL) { + return NULL; + } + + // Determine how much space we need + size_t terminated_length = 0; + if (aws_base64_compute_decoded_len(&data_cursor, &terminated_length) != AWS_OP_SUCCESS) { + aws_jni_throw_runtime_exception(env, "StringUtils: Could not determine length for base64 decode"); + goto clean_up; + } + + aws_byte_buf_init(&formatted_data, aws_jni_get_allocator(), terminated_length); + int result = aws_base64_decode(&data_cursor, &formatted_data); + if (result != AWS_OP_SUCCESS) { + aws_jni_throw_runtime_exception(env, "StringUtils: Could not perform base64 decode"); + goto clean_up; + } + + struct aws_byte_cursor formatted_data_cursor = aws_byte_cursor_from_buf(&formatted_data); + return_data = aws_jni_byte_array_from_cursor(env, &formatted_data_cursor); + +clean_up: + aws_jni_byte_cursor_from_jbyteArray_release(env, jni_data, data_cursor); + aws_byte_buf_clean_up_secure(&formatted_data); + return return_data; +} diff --git a/src/test/java/software/amazon/awssdk/crt/test/StringUtilsTest.java b/src/test/java/software/amazon/awssdk/crt/test/StringUtilsTest.java index 4cadb8952..5084590de 100644 --- a/src/test/java/software/amazon/awssdk/crt/test/StringUtilsTest.java +++ b/src/test/java/software/amazon/awssdk/crt/test/StringUtilsTest.java @@ -1,8 +1,11 @@ package software.amazon.awssdk.crt.test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assert.assertTrue; import org.junit.Test; +import org.junit.function.ThrowingRunnable; import java.util.ArrayList; import java.util.List; @@ -10,7 +13,7 @@ import software.amazon.awssdk.crt.utils.StringUtils; -public class StringUtilsTest { +public class StringUtilsTest extends CrtTestFixture { @Test public void testJoin() { @@ -20,4 +23,105 @@ public void testJoin() { alpns.add("two"); assertEquals("one;two", StringUtils.join(";", alpns)); } + + @Test + public void testBase64EncodeEmpty() { + assertEquals("", new String(StringUtils.base64Encode("".getBytes()))); + } + + @Test + public void testBase64EncodeNull() { + ThrowingRunnable test_runnable = new ThrowingRunnable() { + public void run() { + StringUtils.base64Encode(null); + } + }; + assertThrows(NullPointerException.class, test_runnable); + } + + @Test + public void testBase64EncodeCaseFoobar() { + assertEquals("Zm9vYmFy", new String(StringUtils.base64Encode("foobar".getBytes()))); + } + + @Test + public void testBase64EncodeExtremelyLargeString() { + StringBuilder test_input = new StringBuilder(); + for (int i = 0; i < 50000; i++) { + test_input.append('A'); + } + byte[] output = StringUtils.base64Encode(test_input.toString().getBytes()); + assertTrue(output != null); + } + + @Test + public void testBase64EncodeCaseAllValues() { + byte[] data = new byte[255]; + for (int i = 0; i < 255; i++) { + data[i] = (byte)(i); + } + + String expected = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERU"; + expected += "ZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+AgYKDhIWGh4iJiouM"; + expected += "jY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/wMHCw8TFxsfIycrLzM3Oz9DR0t"; + expected += "PU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+"; + + assertEquals(expected, new String(StringUtils.base64Encode(data))); + } + + @Test + public void testBase64DecodeEmpty() { + assertEquals("", new String(StringUtils.base64Decode("".getBytes()))); + } + + @Test + public void testBase64DecodeNull() { + ThrowingRunnable test_runnable = new ThrowingRunnable() { + public void run() { + StringUtils.base64Decode(null); + } + }; + assertThrows(NullPointerException.class, test_runnable); + } + + @Test + public void testBase64DecodeCaseFoobar() { + assertEquals("foobar", new String(StringUtils.base64Decode("Zm9vYmFy".getBytes()))); + } + + @Test + public void testBase64DecodeExtremelyLargeString() { + StringBuilder test_input = new StringBuilder(); + for (int i = 0; i < 50000; i++) { + test_input.append('A'); + } + byte[] output = StringUtils.base64Decode(test_input.toString().getBytes()); + assertTrue(output != null); + } + + @Test + public void testBase64DecodeCaseAllValues() { + byte[] data = new byte[255]; + for (int i = 0; i < 255; i++) { + data[i] = (byte)(i); + } + + String input = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERU"; + input += "ZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+AgYKDhIWGh4iJiouM"; + input += "jY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/wMHCw8TFxsfIycrLzM3Oz9DR0t"; + input += "PU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+"; + + String expected = new String(data); + + assertEquals(expected, new String(StringUtils.base64Decode(input.getBytes()))); + } + + @Test + public void testBase64CaseFoobarRoundTrop() { + String data = "foobar"; + data = new String(StringUtils.base64Encode(data.getBytes())); + assertEquals("Zm9vYmFy", data); + data = new String(StringUtils.base64Decode(data.getBytes())); + assertEquals("foobar", data); + } }