Skip to content

Commit

Permalink
Error handling cleanups (#112)
Browse files Browse the repository at this point in the history
* Replace .map_err(e.to_string()) with a thiserror::Error (#109)

* Replace .map_err(e.to_string()) with a thiserror::Error

* Include more detail in errors

* Ignore jdtls-created files (#107)

* Catch additional exception types in sample app (#110)

* Clippy

---------

Co-authored-by: David Cook <[email protected]>
  • Loading branch information
inahga and divergentdave authored Mar 7, 2024
1 parent d9b5dd6 commit a1d8bfc
Show file tree
Hide file tree
Showing 5 changed files with 95 additions and 67 deletions.
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,6 @@
.externalNativeBuild
.cxx
local.properties
.project
.settings
.classpath
9 changes: 5 additions & 4 deletions divviup/rust/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions divviup/rust/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@ janus_messages = "0.6.18"
jni = "0.21.1"
prio = "0.15.4"
rand = "0.8.5"
thiserror = "1.0.57"
127 changes: 68 additions & 59 deletions divviup/rust/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@ use janus_messages::{
ReportId, ReportMetadata, Role, TaskId, Time,
};
use jni::{
descriptors::Desc,
objects::{JByteArray, JClass, JLongArray, JThrowable, ReleaseMode},
objects::{JByteArray, JClass, JLongArray, ReleaseMode},
sys::{jboolean, jbyteArray, jlong, jobject},
JNIEnv,
};
Expand Down Expand Up @@ -147,18 +146,37 @@ pub extern "system" fn Java_org_divviup_android_Client_00024Prio3HistogramReport
})
}

#[derive(Debug, thiserror::Error)]
enum Error {
#[error("message encoding failed: {0}")]
Codec(#[from] prio::codec::CodecError),
#[error(transparent)]
Hpke(#[from] janus_core::hpke::Error),
#[error("invalid parameter: {0}")]
InvalidParameter(&'static str),
#[error("JNI error: {0}")]
Jni(#[from] jni::errors::Error),
#[error("task id decoding failed: {0}")]
TaskId(&'static str),
#[error("aggregator provided empty HPKE config list")]
MissingHpkeConfigs,
#[error("generated report is too large")]
ReportTooLarge,
#[error("VDAF error: {0}")]
Vdaf(#[from] prio::vdaf::VdafError),
}

/// Runs a fallible closure that returns a jobject, and transforms an error result into a thrown
/// exception, with a message provided from the error.
fn jni_try<'local, F, E>(env: &mut JNIEnv<'local>, mut f: F) -> jobject
fn jni_try<'local, F>(env: &mut JNIEnv<'local>, mut f: F) -> jobject
where
F: FnMut(&mut JNIEnv<'local>) -> Result<jobject, E>,
E: Desc<'local, JThrowable<'local>>,
F: FnMut(&mut JNIEnv<'local>) -> Result<jobject, Error>,
{
let result = f(env);
match result {
Ok(object) => object,
Err(error) => {
let _ = env.throw(error);
let _ = env.throw(error.to_string());
ptr::null_mut()
}
}
Expand All @@ -175,8 +193,8 @@ fn prepare_report_prio3count_inner<'local, 'a>(
timestamp: jlong,
measurement: jboolean,
env: &'a mut JNIEnv<'local>,
) -> Result<Vec<u8>, String> {
let vdaf = Prio3::new_count(2).map_err(|e| e.to_string())?;
) -> Result<Vec<u8>, Error> {
let vdaf = Prio3::new_count(2)?;
let measurement = measurement as u64;
prepare_report_generic(
task_id_byte_array,
Expand All @@ -201,14 +219,14 @@ fn prepare_report_prio3sum_inner<'local, 'a>(
bits: jlong,
measurement: jlong,
env: &'a mut JNIEnv<'local>,
) -> Result<Vec<u8>, String> {
) -> Result<Vec<u8>, Error> {
let bits = bits
.try_into()
.map_err(|_| "invalid bits parameter".to_string())?;
let vdaf = Prio3::new_sum(2, bits).map_err(|e| e.to_string())?;
.map_err(|_| Error::InvalidParameter("bits"))?;
let vdaf = Prio3::new_sum(2, bits)?;
let measurement = measurement
.try_into()
.map_err(|_| "invalid measurement".to_string())?;
.map_err(|_| Error::InvalidParameter("measurement"))?;
prepare_report_generic(
task_id_byte_array,
leader_hpke_config_list_byte_array,
Expand All @@ -235,17 +253,17 @@ fn prepare_report_prio3sumvec_inner<'local, 'a>(
chunk_length: jlong,
measurement: &'a JLongArray<'local>,
env: &'a mut JNIEnv<'local>,
) -> Result<Vec<u8>, String> {
let length = length
.try_into()
.map_err(|_| "invalid length parameter".to_string())?;
) -> Result<Vec<u8>, Error> {
let bits = bits
.try_into()
.map_err(|_| "invalid bits parameter".to_string())?;
.map_err(|_| Error::InvalidParameter("bits"))?;
let length = length
.try_into()
.map_err(|_| Error::InvalidParameter("length"))?;
let chunk_length = chunk_length
.try_into()
.map_err(|_| "invalid chunk_length parameter".to_string())?;
let vdaf = Prio3::new_sum_vec(2, bits, length, chunk_length).map_err(|e| e.to_string())?;
.map_err(|_| Error::InvalidParameter("chunk_length"))?;
let vdaf = Prio3::new_sum_vec(2, bits, length, chunk_length)?;

// Safety: The copy of the measurement array is not mutated again from the Java side once it is
// passed in. Only one `AutoElements` is constructed from it, in this call.
Expand Down Expand Up @@ -276,17 +294,17 @@ fn prepare_report_prio3histogram_inner<'local, 'a>(
chunk_length: jlong,
measurement: jlong,
env: &'a mut JNIEnv<'local>,
) -> Result<Vec<u8>, String> {
) -> Result<Vec<u8>, Error> {
let length = length
.try_into()
.map_err(|_| "invalid length parameter".to_string())?;
.map_err(|_| Error::InvalidParameter("length"))?;
let chunk_length = chunk_length
.try_into()
.map_err(|_| "invalid chunk_length parameter".to_string())?;
let vdaf = Prio3::new_histogram(2, length, chunk_length).map_err(|e| e.to_string())?;
.map_err(|_| Error::InvalidParameter("chunk_length"))?;
let vdaf = Prio3::new_histogram(2, length, chunk_length)?;
let measurement = measurement
.try_into()
.map_err(|_| "invalid measurement".to_string())?;
.map_err(|_| Error::InvalidParameter("measurement"))?;
prepare_report_generic(
task_id_byte_array,
leader_hpke_config_list_byte_array,
Expand All @@ -310,14 +328,12 @@ fn prepare_report_generic<'local, 'a, V>(
vdaf: V,
measurement: &V::Measurement,
env: &'a mut JNIEnv<'local>,
) -> Result<Vec<u8>, String>
) -> Result<Vec<u8>, Error>
where
V: vdaf::Client<16>,
{
let report_id: ReportId = random();
let (public_share, input_shares) = vdaf
.shard(measurement, report_id.as_ref())
.map_err(|e| e.to_string())?;
let (public_share, input_shares) = vdaf.shard(measurement, report_id.as_ref())?;
let encoded_leader_input_share = input_shares[0].get_encoded();
let encoded_helper_input_share = input_shares[1].get_encoded();
let encoded_public_share = public_share.get_encoded();
Expand Down Expand Up @@ -349,7 +365,7 @@ fn assemble_report<'local, 'a>(
encoded_public_share: Vec<u8>,
encoded_leader_input_share: Vec<u8>,
encoded_helper_input_share: Vec<u8>,
) -> Result<Vec<u8>, String> {
) -> Result<Vec<u8>, Error> {
// Safety: These byte arrays are not mutated again from the Java side once they are passed in.
// Only one `AutoElements` is constructed from each, in these calls.
let task_id = unsafe { parse_task_id(task_id_byte_array, env)? };
Expand All @@ -361,7 +377,9 @@ fn assemble_report<'local, 'a>(
let leader_hpke_config = select_hpke_config(&leader_hpke_config_list)?;
let helper_hpke_config = select_hpke_config(&helper_hpke_config_list)?;

let time = Time::from_seconds_since_epoch(u64::try_from(timestamp).map_err(|e| e.to_string())?);
let time = Time::from_seconds_since_epoch(
u64::try_from(timestamp).map_err(|_| Error::InvalidParameter("timestamp"))?,
);
let report_metadata = ReportMetadata::new(report_id, time);

let leader_encrypted_input_share = encrypt_input_share(
Expand All @@ -371,17 +389,15 @@ fn assemble_report<'local, 'a>(
&leader_hpke_config,
encoded_leader_input_share,
encoded_public_share.clone(),
)
.map_err(|e| e.to_string())?;
)?;
let helper_encrypted_input_share = encrypt_input_share(
task_id,
&report_metadata,
&Role::Helper,
&helper_hpke_config,
encoded_helper_input_share,
encoded_public_share.clone(),
)
.map_err(|e| e.to_string())?;
)?;

let report = Report::new(
report_metadata,
Expand All @@ -406,18 +422,17 @@ fn assemble_report<'local, 'a>(
unsafe fn parse_task_id<'local, 'a>(
array: &'a JByteArray<'local>,
env: &'a mut JNIEnv<'local>,
) -> Result<TaskId, String> {
) -> Result<TaskId, Error> {
// Safety: All safety requirements of get_array_elements() are imposed on the caller.
let elements_res = unsafe { env.get_array_elements(array, ReleaseMode::NoCopyBack) };
let elements = elements_res.map_err(|e| e.to_string())?;
let elements = unsafe { env.get_array_elements(array, ReleaseMode::NoCopyBack) }?;
let signed_slice: &[i8] = &elements[..];
// Safety: The [u8] slice aliases a [i8] slice, and the two have the same memory layout. The
// backing memory is managed by the JVM. The memory is valid for long enough because it is only
// released when the `AutoElements` struct is dropped, which happens after the last use of the
// slices.
let bytes: &[u8] =
unsafe { slice::from_raw_parts(signed_slice.as_ptr() as *const u8, signed_slice.len()) };
TaskId::try_from(bytes).map_err(|e| e.to_string())
TaskId::try_from(bytes).map_err(Error::TaskId)
}

/// Read from a Java byte[] array, and parse an [`HpkeConfigList`] from it. This returns an error if
Expand All @@ -433,18 +448,17 @@ unsafe fn parse_task_id<'local, 'a>(
unsafe fn decode_hpke_config_list<'local, 'a>(
array: &'a JByteArray<'local>,
env: &'a mut JNIEnv<'local>,
) -> Result<HpkeConfigList, String> {
) -> Result<HpkeConfigList, Error> {
// Safety: All safety requirements of get_array_elements() are imposed on the caller.
let elements_res = unsafe { env.get_array_elements(array, ReleaseMode::NoCopyBack) };
let elements = elements_res.map_err(|e| e.to_string())?;
let elements = unsafe { env.get_array_elements(array, ReleaseMode::NoCopyBack) }?;
let signed_slice: &[i8] = &elements[..];
// Safety: The [u8] slice aliases a [i8] slice, and the two have the same memory layout. The
// backing memory is managed by the JVM. The memory is valid for long enough because it is only
// released when the `AutoElements` struct is dropped, which happens after the last use of the
// slices.
let bytes: &[u8] =
unsafe { slice::from_raw_parts(signed_slice.as_ptr() as *const u8, signed_slice.len()) };
HpkeConfigList::get_decoded(bytes).map_err(|e| e.to_string())
Ok(HpkeConfigList::get_decoded(bytes)?)
}

/// Read from a Java long[] array, and convert each element to a `u128`. This returns an error if
Expand All @@ -460,38 +474,33 @@ unsafe fn decode_hpke_config_list<'local, 'a>(
unsafe fn convert_sumvec_measurement<'local, 'a>(
array: &'a JLongArray<'local>,
env: &'a mut JNIEnv<'local>,
) -> Result<Vec<u128>, String> {
) -> Result<Vec<u128>, Error> {
// Safety: All safety requirements of get_array_elements() are imposed on the caller.
let elements_res = unsafe { env.get_array_elements(array, ReleaseMode::NoCopyBack) };
let elements = elements_res.map_err(|e| e.to_string())?;
let elements = unsafe { env.get_array_elements(array, ReleaseMode::NoCopyBack) }?;
elements
.iter()
.map(|value| u128::try_from(*value))
.collect::<Result<Vec<_>, _>>()
.map_err(|_| "invalid measurement".to_string())
.map_err(|_| Error::InvalidParameter("measurement: negative value not allowed in sumvec"))
}

/// Creates a new byte[] array, copies the provided data into it, and returns a raw JNI pointer to
/// the array. This pointer is intended to be returned from a JNI method.
///
/// This returns an error if the byte slice's length cannot fit in an i32, or if the JVM fails to
/// create or update the array.
fn return_new_byte_array(data: &[u8], env: &mut JNIEnv<'_>) -> Result<jbyteArray, String> {
let length = data
.len()
.try_into()
.map_err(|_| "length overflow".to_string())?;
fn return_new_byte_array(data: &[u8], env: &mut JNIEnv<'_>) -> Result<jbyteArray, Error> {
let length = data.len().try_into().map_err(|_| Error::ReportTooLarge)?;

let byte_array = env.new_byte_array(length).map_err(|e| e.to_string())?;
let byte_array = env.new_byte_array(length)?;

// Start a new scope for the AutoElements. We need to drop it before calling into_raw() on the
// byte array, as it borrows the byte array.
{
// Safety: There are no races on this array, and it will not be aliased, because it is newly
// created. The `AutoElements` will release its reference before the array is returned to
// Java code.
let mut elements = unsafe { env.get_array_elements(&byte_array, ReleaseMode::CopyBack) }
.map_err(|e| e.to_string())?;
let mut elements = unsafe { env.get_array_elements(&byte_array, ReleaseMode::CopyBack) }?;
let signed_slice: &mut [i8] = &mut elements[..];
let len = signed_slice.len();
// Safety: The [u8] mutable slice points to the same memory as the [i8] slice, and the two
Expand All @@ -502,7 +511,7 @@ fn return_new_byte_array(data: &[u8], env: &mut JNIEnv<'_>) -> Result<jbyteArray
let mut_slice: &mut [u8] =
unsafe { slice::from_raw_parts_mut(signed_slice.as_ptr() as *mut u8, len) };
mut_slice.copy_from_slice(data);
elements.commit().map_err(|e| e.to_string())?;
elements.commit()?;
}

Ok(byte_array.into_raw())
Expand All @@ -511,9 +520,9 @@ fn return_new_byte_array(data: &[u8], env: &mut JNIEnv<'_>) -> Result<jbyteArray
/// Select an [`HpkeConfig`] from an [`HpkeConfigList`] that uses a supported set of algorithms.
///
/// Returns an error if the list is empty, or if all sets of algorithms are unsupported.
fn select_hpke_config(list: &HpkeConfigList) -> Result<HpkeConfig, String> {
fn select_hpke_config(list: &HpkeConfigList) -> Result<HpkeConfig, Error> {
if list.hpke_configs().is_empty() {
return Err("aggregator provided empty HpkeConfigList".to_string());
return Err(Error::MissingHpkeConfigs);
}

// Take the first supported HpkeConfig from the list. Return the first error otherwise.
Expand All @@ -523,14 +532,14 @@ fn select_hpke_config(list: &HpkeConfigList) -> Result<HpkeConfig, String> {
Ok(()) => return Ok(config.clone()),
Err(e) => {
if first_error.is_none() {
first_error = Some(e.to_string());
first_error = Some(e);
}
}
}
}
// Unwrap safety: we checked that the list is nonempty, and if we fell through to here, we must
// have seen at least one error.
Err(first_error.unwrap())
Err(first_error.unwrap().into())
}

/// Convenience method to construct a [`PlaintextInputShare`], encode it, and encrypt it.
Expand Down
22 changes: 18 additions & 4 deletions sampleapp/src/main/java/org/divviup/sampleapp/MainActivity.java
Original file line number Diff line number Diff line change
Expand Up @@ -160,15 +160,29 @@ public void run() {
client.sendMeasurement(measurement);
handler.post(
() -> Snackbar.make(view, "Success!", Snackbar.LENGTH_LONG)
.setAnchorView(R.id.button)
.show()
.setAnchorView(R.id.button)
.show()
);
} catch (IOException e) {
Log.e(TAG, "upload failed", e);
handler.post(
() -> Snackbar.make(view, "Error uploading report", Snackbar.LENGTH_LONG)
.setAnchorView(R.id.button)
.show()
.setAnchorView(R.id.button)
.show()
);
} catch (IllegalArgumentException e) {
Log.e(TAG, "illegal argument", e);
handler.post(
() -> Snackbar.make(view, "Invalid task parameters", Snackbar.LENGTH_LONG)
.setAnchorView(R.id.button)
.show()
);
} catch (RuntimeException e) {
Log.e(TAG, "sharding error", e);
handler.post(
() -> Snackbar.make(view, "Error preparing report", Snackbar.LENGTH_LONG)
.setAnchorView(R.id.button)
.show()
);
}
}
Expand Down

0 comments on commit a1d8bfc

Please sign in to comment.