Skip to content

Commit

Permalink
Java keystore and Base64 support (#552)
Browse files Browse the repository at this point in the history
Add support for making a TlsContextOptions using a Java keystore, and bind Base64 encoding and decoding support.
  • Loading branch information
TwistedTwigleg authored Nov 22, 2022
1 parent dc195bb commit 54af471
Show file tree
Hide file tree
Showing 4 changed files with 263 additions and 1 deletion.
45 changes: 45 additions & 0 deletions src/main/java/software/amazon/awssdk/crt/io/TlsContextOptions.java
Original file line number Diff line number Diff line change
Expand Up @@ -356,6 +356,51 @@ public static TlsContextOptions createWithMtlsWindowsCertStorePath(String certif
return options;
}

/**
* Helper which creates mutual TLS (mTLS) options using a certificate and private key
* stored in a Java keystore.
* Will throw an exception if there is no certificate and key at the given certificate alias, or there is some other
* error accessing or using the passed-in Java keystore.
*
* Note: function assumes the passed keystore has already been loaded from a file by calling "keystore.load()" or similar.
*
* @param keyStore The Java keystore to use. Assumed to be loaded with the desired certificate and key
* @param certificateAlias The alias of the certificate and key to use.
* @param certificatePassword The password of the certificate and key to use.
* @throws CrtRuntimeException if the certificate alias does not exist or the certificate/key cannot be found in the certificate alias
* @return A set of options for setting up an mTLS connection
*/
public static TlsContextOptions createWithMtlsJavaKeystore(
java.security.KeyStore keyStore, String certificateAlias, String certificatePassword) {

TlsContextOptions options = new TlsContextOptions();
String certificate;
try {
java.security.cert.Certificate certificateData = keyStore.getCertificate(certificateAlias);
if (certificateData == null) {
throw new CrtRuntimeException("Certificate at given certificate alias does not exist or does not contain a certificate");
}
String certificateString = new String(StringUtils.base64Encode(certificateData.getEncoded()));
certificate = "-----BEGIN CERTIFICATE-----\n" + certificateString + "-----END CERTIFICATE-----\n";
} catch (java.security.KeyStoreException | java.security.cert.CertificateEncodingException ex) {
throw new RuntimeException("Failed to get certificate from Java keystore", ex);
}
String privateKey;
try {
java.security.Key keyData = keyStore.getKey(certificateAlias, certificatePassword.toCharArray());
if (keyData == null) {
throw new CrtRuntimeException("Private key at given certificate alias does not exist or does not identify a key-related entity");
}
String keyString = new String(StringUtils.base64Encode(keyData.getEncoded()));
privateKey = "-----BEGIN RSA PRIVATE KEY-----\n" + keyString + "-----END RSA PRIVATE KEY-----\n";
} catch (java.security.KeyStoreException | java.security.NoSuchAlgorithmException | java.security.UnrecoverableKeyException ex) {
throw new RuntimeException("Failed to get private key from Java keystore", ex);
}
options.initMtls(certificate, privateKey);
options.verifyPeer = true;
return options;
}

/*******************************************************************************
* .with() methods
******************************************************************************/
Expand Down
21 changes: 21 additions & 0 deletions src/main/java/software/amazon/awssdk/crt/utils/StringUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,25 @@ public static String join(CharSequence delimiter, Iterable<? extends CharSequenc
}
return sb.toString();
}

/**
* Encode a byte array into a Base64 byte array.
* @param data The byte array to encode
* @return The byte array encoded as Byte64
*/
public static byte[] base64Encode(byte[] data) {
return stringUtilsBase64Encode(data);
}

/**
* Decode a Base64 byte array into a non-Base64 byte array.
* @param data The byte array to decode.
* @return Byte array decoded from Base64.
*/
public static byte[] base64Decode(byte[] data) {
return stringUtilsBase64Decode(data);
}

private static native byte[] stringUtilsBase64Encode(byte[] data_to_encode);
private static native byte[] stringUtilsBase64Decode(byte[] data_to_decode);
}
92 changes: 92 additions & 0 deletions src/native/string_utils.c
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
/**
* Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
* SPDX-License-Identifier: Apache-2.0.
*/
#include <jni.h>

#include <aws/common/encoding.h>
#include <aws/common/string.h>

#include "crt.h"

JNIEXPORT
jbyteArray JNICALL Java_software_amazon_awssdk_crt_utils_StringUtils_stringUtilsBase64Encode(
JNIEnv *env,
jclass jni_class,
jbyteArray jni_data) {
(void)jni_class;

struct aws_byte_cursor data_cursor;
AWS_ZERO_STRUCT(data_cursor);
struct aws_byte_buf formatted_data;
AWS_ZERO_STRUCT(formatted_data);
jbyteArray return_data = NULL;

data_cursor = aws_jni_byte_cursor_from_jbyteArray_acquire(env, jni_data);
if (data_cursor.ptr == NULL) {
return return_data;
}

// Determine how much space we need
size_t terminated_length = 0;
if (aws_base64_compute_encoded_len(data_cursor.len, &terminated_length) != AWS_OP_SUCCESS) {
aws_jni_throw_runtime_exception(env, "StringUtils: Could not determine length for base64 encode");
goto clean_up;
}

aws_byte_buf_init(&formatted_data, aws_jni_get_allocator(), terminated_length);
int result = aws_base64_encode(&data_cursor, &formatted_data);
if (result != AWS_OP_SUCCESS) {
aws_jni_throw_runtime_exception(env, "StringUtils: Could not perform base64 encode");
goto clean_up;
}

struct aws_byte_cursor formatted_data_cursor = aws_byte_cursor_from_buf(&formatted_data);
return_data = aws_jni_byte_array_from_cursor(env, &formatted_data_cursor);

clean_up:
aws_jni_byte_cursor_from_jbyteArray_release(env, jni_data, data_cursor);
aws_byte_buf_clean_up_secure(&formatted_data);
return return_data;
}

JNIEXPORT
jbyteArray JNICALL Java_software_amazon_awssdk_crt_utils_StringUtils_stringUtilsBase64Decode(
JNIEnv *env,
jclass jni_class,
jbyteArray jni_data) {
(void)jni_class;

struct aws_byte_cursor data_cursor;
AWS_ZERO_STRUCT(data_cursor);
struct aws_byte_buf formatted_data;
AWS_ZERO_STRUCT(formatted_data);
jbyteArray return_data = NULL;

data_cursor = aws_jni_byte_cursor_from_jbyteArray_acquire(env, jni_data);
if (data_cursor.ptr == NULL) {
return NULL;
}

// Determine how much space we need
size_t terminated_length = 0;
if (aws_base64_compute_decoded_len(&data_cursor, &terminated_length) != AWS_OP_SUCCESS) {
aws_jni_throw_runtime_exception(env, "StringUtils: Could not determine length for base64 decode");
goto clean_up;
}

aws_byte_buf_init(&formatted_data, aws_jni_get_allocator(), terminated_length);
int result = aws_base64_decode(&data_cursor, &formatted_data);
if (result != AWS_OP_SUCCESS) {
aws_jni_throw_runtime_exception(env, "StringUtils: Could not perform base64 decode");
goto clean_up;
}

struct aws_byte_cursor formatted_data_cursor = aws_byte_cursor_from_buf(&formatted_data);
return_data = aws_jni_byte_array_from_cursor(env, &formatted_data_cursor);

clean_up:
aws_jni_byte_cursor_from_jbyteArray_release(env, jni_data, data_cursor);
aws_byte_buf_clean_up_secure(&formatted_data);
return return_data;
}
106 changes: 105 additions & 1 deletion src/test/java/software/amazon/awssdk/crt/test/StringUtilsTest.java
Original file line number Diff line number Diff line change
@@ -1,16 +1,19 @@
package software.amazon.awssdk.crt.test;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThrows;
import static org.junit.Assert.assertTrue;

import org.junit.Test;
import org.junit.function.ThrowingRunnable;

import java.util.ArrayList;
import java.util.List;

import software.amazon.awssdk.crt.utils.StringUtils;


public class StringUtilsTest {
public class StringUtilsTest extends CrtTestFixture {

@Test
public void testJoin() {
Expand All @@ -20,4 +23,105 @@ public void testJoin() {
alpns.add("two");
assertEquals("one;two", StringUtils.join(";", alpns));
}

@Test
public void testBase64EncodeEmpty() {
assertEquals("", new String(StringUtils.base64Encode("".getBytes())));
}

@Test
public void testBase64EncodeNull() {
ThrowingRunnable test_runnable = new ThrowingRunnable() {
public void run() {
StringUtils.base64Encode(null);
}
};
assertThrows(NullPointerException.class, test_runnable);
}

@Test
public void testBase64EncodeCaseFoobar() {
assertEquals("Zm9vYmFy", new String(StringUtils.base64Encode("foobar".getBytes())));
}

@Test
public void testBase64EncodeExtremelyLargeString() {
StringBuilder test_input = new StringBuilder();
for (int i = 0; i < 50000; i++) {
test_input.append('A');
}
byte[] output = StringUtils.base64Encode(test_input.toString().getBytes());
assertTrue(output != null);
}

@Test
public void testBase64EncodeCaseAllValues() {
byte[] data = new byte[255];
for (int i = 0; i < 255; i++) {
data[i] = (byte)(i);
}

String expected = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERU";
expected += "ZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+AgYKDhIWGh4iJiouM";
expected += "jY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/wMHCw8TFxsfIycrLzM3Oz9DR0t";
expected += "PU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+";

assertEquals(expected, new String(StringUtils.base64Encode(data)));
}

@Test
public void testBase64DecodeEmpty() {
assertEquals("", new String(StringUtils.base64Decode("".getBytes())));
}

@Test
public void testBase64DecodeNull() {
ThrowingRunnable test_runnable = new ThrowingRunnable() {
public void run() {
StringUtils.base64Decode(null);
}
};
assertThrows(NullPointerException.class, test_runnable);
}

@Test
public void testBase64DecodeCaseFoobar() {
assertEquals("foobar", new String(StringUtils.base64Decode("Zm9vYmFy".getBytes())));
}

@Test
public void testBase64DecodeExtremelyLargeString() {
StringBuilder test_input = new StringBuilder();
for (int i = 0; i < 50000; i++) {
test_input.append('A');
}
byte[] output = StringUtils.base64Decode(test_input.toString().getBytes());
assertTrue(output != null);
}

@Test
public void testBase64DecodeCaseAllValues() {
byte[] data = new byte[255];
for (int i = 0; i < 255; i++) {
data[i] = (byte)(i);
}

String input = "AAECAwQFBgcICQoLDA0ODxAREhMUFRYXGBkaGxwdHh8gISIjJCUmJygpKissLS4vMDEyMzQ1Njc4OTo7PD0+P0BBQkNERU";
input += "ZHSElKS0xNTk9QUVJTVFVWV1hZWltcXV5fYGFiY2RlZmdoaWprbG1ub3BxcnN0dXZ3eHl6e3x9fn+AgYKDhIWGh4iJiouM";
input += "jY6PkJGSk5SVlpeYmZqbnJ2en6ChoqOkpaanqKmqq6ytrq+wsbKztLW2t7i5uru8vb6/wMHCw8TFxsfIycrLzM3Oz9DR0t";
input += "PU1dbX2Nna29zd3t/g4eLj5OXm5+jp6uvs7e7v8PHy8/T19vf4+fr7/P3+";

String expected = new String(data);

assertEquals(expected, new String(StringUtils.base64Decode(input.getBytes())));
}

@Test
public void testBase64CaseFoobarRoundTrop() {
String data = "foobar";
data = new String(StringUtils.base64Encode(data.getBytes()));
assertEquals("Zm9vYmFy", data);
data = new String(StringUtils.base64Decode(data.getBytes()));
assertEquals("foobar", data);
}
}

0 comments on commit 54af471

Please sign in to comment.