Skip to content

Commit 8d8ef40

Browse files
authored
Refactors (#59)
1 parent 8ccbcd2 commit 8d8ef40

File tree

6 files changed

+194
-95
lines changed

6 files changed

+194
-95
lines changed

src/crypto.rs

Lines changed: 8 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::ffi::CString;
55
use std::str::FromStr;
66
use thiserror::Error;
77

8+
use crate::signature::SignatureAlgorithm;
89
#[cfg(feature = "xmlsec")]
910
use crate::xmlsec::{self, XmlSecKey, XmlSecKeyFormat, XmlSecSignatureContext};
1011
#[cfg(feature = "xmlsec")]
@@ -223,6 +224,7 @@ fn get_elements_by_predicate<F: FnMut(&libxml::tree::Node) -> bool>(
223224
/// Searches for and returns the element with the given value of the `ID` attribute from the subtree
224225
/// rooted at the given node.
225226
#[cfg(feature = "xmlsec")]
227+
#[allow(unused)]
226228
fn get_element_by_id(elem: &libxml::tree::Node, id: &str) -> Option<libxml::tree::Node> {
227229
let mut elems = get_elements_by_predicate(elem, |node| {
228230
node.get_attribute("ID")
@@ -486,24 +488,6 @@ pub fn gen_saml_assertion_id() -> String {
486488
format!("_{}", uuid::Uuid::new_v4())
487489
}
488490

489-
#[derive(Debug, PartialEq)]
490-
enum SigAlg {
491-
Unimplemented,
492-
RsaSha256,
493-
EcdsaSha256,
494-
}
495-
496-
impl FromStr for SigAlg {
497-
type Err = Box<dyn std::error::Error>;
498-
fn from_str(s: &str) -> Result<Self, Self::Err> {
499-
match s {
500-
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" => Ok(SigAlg::RsaSha256),
501-
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256" => Ok(SigAlg::EcdsaSha256),
502-
_ => Ok(SigAlg::Unimplemented),
503-
}
504-
}
505-
}
506-
507491
#[derive(Debug, Error, Clone)]
508492
pub enum UrlVerifierError {
509493
#[error("Unimplemented SigAlg: {:?}", sigalg)]
@@ -621,11 +605,9 @@ impl UrlVerifier {
621605
.collect::<HashMap<String, String>>();
622606

623607
// Match against implemented SigAlg
624-
let sig_alg: SigAlg = SigAlg::from_str(&query_params["SigAlg"])?;
625-
if sig_alg == SigAlg::Unimplemented {
626-
return Err(Box::new(UrlVerifierError::SigAlgUnimplemented {
627-
sigalg: query_params["SigAlg"].clone(),
628-
}));
608+
let sig_alg = SignatureAlgorithm::from_str(&query_params["SigAlg"])?;
609+
if let SignatureAlgorithm::Unsupported(sigalg) = sig_alg {
610+
return Err(Box::new(UrlVerifierError::SigAlgUnimplemented { sigalg }));
629611
}
630612

631613
// Construct a Url so that percent encoded query can be easily
@@ -668,13 +650,13 @@ impl UrlVerifier {
668650
fn verify_signature(
669651
&self,
670652
data: &[u8],
671-
sig_alg: SigAlg,
653+
sig_alg: SignatureAlgorithm,
672654
signature: &[u8],
673655
) -> Result<bool, Box<dyn std::error::Error>> {
674656
let mut verifier = openssl::sign::Verifier::new(
675657
match sig_alg {
676-
SigAlg::RsaSha256 => openssl::hash::MessageDigest::sha256(),
677-
SigAlg::EcdsaSha256 => openssl::hash::MessageDigest::sha256(),
658+
SignatureAlgorithm::RsaSha256 => openssl::hash::MessageDigest::sha256(),
659+
SignatureAlgorithm::EcdsaSha256 => openssl::hash::MessageDigest::sha256(),
678660
_ => panic!("sig_alg is bad!"),
679661
},
680662
&self.public_key,

src/idp/mod.rs

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ pub mod verified_request;
99
mod tests;
1010

1111
use openssl::bn::{BigNum, MsbOption};
12+
use openssl::ec::{EcGroup, EcKey};
1213
use openssl::nid::Nid;
1314
use openssl::pkey::Private;
14-
use openssl::{asn1::Asn1Time, pkey, rsa::Rsa, x509};
15+
use openssl::{asn1::Asn1Time, pkey, x509};
1516
use std::str::FromStr;
1617

1718
use crate::crypto::{self};
@@ -24,22 +25,31 @@ pub struct IdentityProvider {
2425
private_key: pkey::PKey<Private>,
2526
}
2627

27-
pub enum KeyType {
28+
pub enum Rsa {
2829
Rsa2048,
2930
Rsa3072,
3031
Rsa4096,
3132
}
3233

33-
impl KeyType {
34+
impl Rsa {
3435
fn bit_length(&self) -> u32 {
3536
match &self {
36-
KeyType::Rsa2048 => 2048,
37-
KeyType::Rsa3072 => 3072,
38-
KeyType::Rsa4096 => 4096,
37+
Rsa::Rsa2048 => 2048,
38+
Rsa::Rsa3072 => 3072,
39+
Rsa::Rsa4096 => 4096,
3940
}
4041
}
4142
}
4243

44+
pub enum Elliptic {
45+
NISTP256,
46+
}
47+
48+
pub enum KeyType {
49+
Rsa(Rsa),
50+
Elliptic(Elliptic),
51+
}
52+
4353
pub struct CertificateParams<'a> {
4454
pub common_name: &'a str,
4555
pub issuer_name: &'a str,
@@ -48,22 +58,40 @@ pub struct CertificateParams<'a> {
4858

4959
impl IdentityProvider {
5060
pub fn generate_new(key_type: KeyType) -> Result<Self, Error> {
51-
let rsa = Rsa::generate(key_type.bit_length())?;
52-
let private_key = pkey::PKey::from_rsa(rsa)?;
61+
let private_key = match key_type {
62+
KeyType::Rsa(rsa) => {
63+
let bit_length = rsa.bit_length();
64+
let rsa = openssl::rsa::Rsa::generate(bit_length)?;
65+
pkey::PKey::from_rsa(rsa)?
66+
}
67+
KeyType::Elliptic(ecc) => {
68+
let nid = match ecc {
69+
Elliptic::NISTP256 => Nid::X9_62_PRIME256V1,
70+
};
71+
let group = EcGroup::from_curve_name(nid)?;
72+
let private_key: EcKey<Private> = EcKey::generate(&group)?;
73+
pkey::PKey::from_ec_key(private_key)?
74+
}
75+
};
5376

5477
Ok(IdentityProvider { private_key })
5578
}
5679

57-
pub fn from_private_key_der(der_bytes: &[u8]) -> Result<Self, Error> {
58-
let rsa = Rsa::private_key_from_der(der_bytes)?;
80+
pub fn from_rsa_private_key_der(der_bytes: &[u8]) -> Result<Self, Error> {
81+
let rsa = openssl::rsa::Rsa::private_key_from_der(der_bytes)?;
5982
let private_key = pkey::PKey::from_rsa(rsa)?;
6083

6184
Ok(IdentityProvider { private_key })
6285
}
6386

6487
pub fn export_private_key_der(&self) -> Result<Vec<u8>, Error> {
65-
let rsa: Rsa<Private> = self.private_key.rsa()?;
66-
Ok(rsa.private_key_to_der()?)
88+
if let Ok(ec_key) = self.private_key.ec_key() {
89+
Ok(ec_key.private_key_to_der()?)
90+
} else if let Ok(rsa) = self.private_key.rsa() {
91+
Ok(rsa.private_key_to_der()?)
92+
} else {
93+
Err(Error::UnexpectedError)?
94+
}
6795
}
6896

6997
pub fn create_certificate(&self, params: &CertificateParams) -> Result<Vec<u8>, Error> {

src/idp/tests.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ fn test_extract_sp() {
3737
#[test]
3838
fn test_signed_response() {
3939
// init our IdP
40-
let idp = IdentityProvider::from_private_key_der(include_bytes!(
40+
let idp = IdentityProvider::from_rsa_private_key_der(include_bytes!(
4141
"../../test_vectors/idp_private_key.der"
4242
))
4343
.expect("failed to create idp");
@@ -135,7 +135,7 @@ fn test_signed_response_threads() {
135135

136136
#[test]
137137
fn test_signed_response_fingerprint() {
138-
let idp = IdentityProvider::from_private_key_der(include_bytes!(
138+
let idp = IdentityProvider::from_rsa_private_key_der(include_bytes!(
139139
"../../test_vectors/idp_private_key.der"
140140
))
141141
.expect("failed to create idp");

src/metadata/entity_descriptor.rs

Lines changed: 45 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ use chrono::prelude::*;
77
use quick_xml::events::{BytesDecl, BytesEnd, BytesStart, BytesText, Event};
88
use quick_xml::Writer;
99
use serde::Deserialize;
10+
use std::collections::VecDeque;
1011
use std::io::Cursor;
1112
use std::str::FromStr;
1213
use thiserror::Error;
@@ -29,18 +30,8 @@ pub enum EntityDescriptorType {
2930
}
3031

3132
impl EntityDescriptorType {
32-
pub fn take_first(self) -> Option<EntityDescriptor> {
33-
match self {
34-
EntityDescriptorType::EntitiesDescriptor(descriptor) => descriptor
35-
.descriptors
36-
.into_iter()
37-
.next()
38-
.and_then(|descriptor_type| match descriptor_type {
39-
EntityDescriptorType::EntitiesDescriptor(_) => None,
40-
EntityDescriptorType::EntityDescriptor(descriptor) => Some(descriptor),
41-
}),
42-
EntityDescriptorType::EntityDescriptor(descriptor) => Some(descriptor),
43-
}
33+
pub fn iter(&self) -> EntityDescriptorIterator {
34+
EntityDescriptorIterator::new(self)
4435
}
4536
}
4637

@@ -284,6 +275,39 @@ impl TryFrom<&EntityDescriptor> for Event<'_> {
284275
}
285276
}
286277

278+
#[derive(Clone)]
279+
pub struct EntityDescriptorIterator<'a> {
280+
queue: VecDeque<&'a EntityDescriptorType>,
281+
}
282+
283+
impl<'a> EntityDescriptorIterator<'a> {
284+
pub fn new(root: &'a EntityDescriptorType) -> Self {
285+
let mut queue = VecDeque::new();
286+
queue.push_back(root);
287+
EntityDescriptorIterator { queue }
288+
}
289+
}
290+
291+
impl<'a> Iterator for EntityDescriptorIterator<'a> {
292+
type Item = &'a EntityDescriptor;
293+
294+
fn next(&mut self) -> Option<Self::Item> {
295+
while let Some(current) = self.queue.pop_front() {
296+
match current {
297+
EntityDescriptorType::EntitiesDescriptor(entities_descriptor) => {
298+
for descriptor in &entities_descriptor.descriptors {
299+
self.queue.push_back(descriptor);
300+
}
301+
}
302+
EntityDescriptorType::EntityDescriptor(entity_descriptor) => {
303+
return Some(entity_descriptor);
304+
}
305+
}
306+
}
307+
None
308+
}
309+
}
310+
287311
#[cfg(test)]
288312
mod test {
289313
use crate::traits::ToXml;
@@ -345,6 +369,7 @@ mod test {
345369
.parse()
346370
.expect("Failed to parse EntitiesDescriptor");
347371

372+
assert_eq!(2, reparsed_entities_descriptor.descriptors.len());
348373
assert_eq!(reparsed_entities_descriptor, entities_descriptor);
349374
}
350375

@@ -369,11 +394,12 @@ mod test {
369394
let expected_entity_descriptor: EntityDescriptor = input_xml
370395
.parse()
371396
.expect("Failed to parse idp_metadata.xml into an EntityDescriptor");
372-
let entity_descriptor: EntityDescriptor = entity_descriptor_type
373-
.take_first()
397+
let entity_descriptor = entity_descriptor_type
398+
.iter()
399+
.next()
374400
.expect("Failed to take first EntityDescriptor from EntityDescriptorType");
375401

376-
assert_eq!(expected_entity_descriptor, entity_descriptor);
402+
assert_eq!(&expected_entity_descriptor, entity_descriptor);
377403
}
378404

379405
#[test]
@@ -401,11 +427,12 @@ mod test {
401427
let expected_entity_descriptor: EntityDescriptor = input_xml
402428
.parse()
403429
.expect("Failed to parse idp_metadata.xml into an EntityDescriptor");
404-
let entity_descriptor: EntityDescriptor = entity_descriptor_type
405-
.take_first()
430+
let entity_descriptor = entity_descriptor_type
431+
.iter()
432+
.next()
406433
.expect("Failed to take first EntityDescriptor from EntityDescriptorType");
407434
println!("{entity_descriptor:#?}");
408435

409-
assert_eq!(expected_entity_descriptor, entity_descriptor);
436+
assert_eq!(&expected_entity_descriptor, entity_descriptor);
410437
}
411438
}

0 commit comments

Comments
 (0)