Skip to content

Commit a6ac4ca

Browse files
committed
Refactors
1 parent 09c4dda commit a6ac4ca

File tree

4 files changed

+105
-73
lines changed

4 files changed

+105
-73
lines changed

src/crypto.rs

Lines changed: 7 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@ use std::ffi::CString;
55
use std::str::FromStr;
66
use thiserror::Error;
77

8+
use crate::signature::SignatureAlgorithm;
89
#[cfg(feature = "xmlsec")]
910
use crate::xmlsec::{self, XmlSecKey, XmlSecKeyFormat, XmlSecSignatureContext};
1011
#[cfg(feature = "xmlsec")]
@@ -486,24 +487,6 @@ pub fn gen_saml_assertion_id() -> String {
486487
format!("_{}", uuid::Uuid::new_v4())
487488
}
488489

489-
#[derive(Debug, PartialEq)]
490-
enum SigAlg {
491-
Unimplemented,
492-
RsaSha256,
493-
EcdsaSha256,
494-
}
495-
496-
impl FromStr for SigAlg {
497-
type Err = Box<dyn std::error::Error>;
498-
fn from_str(s: &str) -> Result<Self, Self::Err> {
499-
match s {
500-
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" => Ok(SigAlg::RsaSha256),
501-
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256" => Ok(SigAlg::EcdsaSha256),
502-
_ => Ok(SigAlg::Unimplemented),
503-
}
504-
}
505-
}
506-
507490
#[derive(Debug, Error, Clone)]
508491
pub enum UrlVerifierError {
509492
#[error("Unimplemented SigAlg: {:?}", sigalg)]
@@ -621,11 +604,9 @@ impl UrlVerifier {
621604
.collect::<HashMap<String, String>>();
622605

623606
// Match against implemented SigAlg
624-
let sig_alg: SigAlg = SigAlg::from_str(&query_params["SigAlg"])?;
625-
if sig_alg == SigAlg::Unimplemented {
626-
return Err(Box::new(UrlVerifierError::SigAlgUnimplemented {
627-
sigalg: query_params["SigAlg"].clone(),
628-
}));
607+
let sig_alg = SignatureAlgorithm::from_str(&query_params["SigAlg"])?;
608+
if let SignatureAlgorithm::Unsupported(sigalg) = sig_alg {
609+
return Err(Box::new(UrlVerifierError::SigAlgUnimplemented { sigalg }));
629610
}
630611

631612
// Construct a Url so that percent encoded query can be easily
@@ -668,13 +649,13 @@ impl UrlVerifier {
668649
fn verify_signature(
669650
&self,
670651
data: &[u8],
671-
sig_alg: SigAlg,
652+
sig_alg: SignatureAlgorithm,
672653
signature: &[u8],
673654
) -> Result<bool, Box<dyn std::error::Error>> {
674655
let mut verifier = openssl::sign::Verifier::new(
675656
match sig_alg {
676-
SigAlg::RsaSha256 => openssl::hash::MessageDigest::sha256(),
677-
SigAlg::EcdsaSha256 => openssl::hash::MessageDigest::sha256(),
657+
SignatureAlgorithm::RsaSha256 => openssl::hash::MessageDigest::sha256(),
658+
SignatureAlgorithm::EcdsaSha256 => openssl::hash::MessageDigest::sha256(),
678659
_ => panic!("sig_alg is bad!"),
679660
},
680661
&self.public_key,

src/idp/mod.rs

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -9,9 +9,10 @@ pub mod verified_request;
99
mod tests;
1010

1111
use openssl::bn::{BigNum, MsbOption};
12+
use openssl::ec::{EcGroup, EcKey};
1213
use openssl::nid::Nid;
1314
use openssl::pkey::Private;
14-
use openssl::{asn1::Asn1Time, pkey, rsa::Rsa, x509};
15+
use openssl::{asn1::Asn1Time, pkey, x509};
1516
use std::str::FromStr;
1617

1718
use crate::crypto::{self};
@@ -24,22 +25,31 @@ pub struct IdentityProvider {
2425
private_key: pkey::PKey<Private>,
2526
}
2627

27-
pub enum KeyType {
28+
pub enum Rsa {
2829
Rsa2048,
2930
Rsa3072,
3031
Rsa4096,
3132
}
3233

33-
impl KeyType {
34+
impl Rsa {
3435
fn bit_length(&self) -> u32 {
3536
match &self {
36-
KeyType::Rsa2048 => 2048,
37-
KeyType::Rsa3072 => 3072,
38-
KeyType::Rsa4096 => 4096,
37+
Rsa::Rsa2048 => 2048,
38+
Rsa::Rsa3072 => 3072,
39+
Rsa::Rsa4096 => 4096,
3940
}
4041
}
4142
}
4243

44+
pub enum Eliptic {
45+
NISTP256,
46+
}
47+
48+
pub enum KeyType {
49+
Rsa(Rsa),
50+
Eliptic(Eliptic),
51+
}
52+
4353
pub struct CertificateParams<'a> {
4454
pub common_name: &'a str,
4555
pub issuer_name: &'a str,
@@ -48,22 +58,40 @@ pub struct CertificateParams<'a> {
4858

4959
impl IdentityProvider {
5060
pub fn generate_new(key_type: KeyType) -> Result<Self, Error> {
51-
let rsa = Rsa::generate(key_type.bit_length())?;
52-
let private_key = pkey::PKey::from_rsa(rsa)?;
61+
let private_key = match key_type {
62+
KeyType::Rsa(rsa) => {
63+
let bit_length = rsa.bit_length();
64+
let rsa = openssl::rsa::Rsa::generate(bit_length)?;
65+
pkey::PKey::from_rsa(rsa)?
66+
}
67+
KeyType::Eliptic(ecc) => {
68+
let nid = match ecc {
69+
Eliptic::NISTP256 => Nid::X9_62_PRIME256V1,
70+
};
71+
let group = EcGroup::from_curve_name(nid)?;
72+
let private_key: EcKey<Private> = EcKey::generate(&group)?;
73+
pkey::PKey::from_ec_key(private_key)?
74+
}
75+
};
5376

5477
Ok(IdentityProvider { private_key })
5578
}
5679

57-
pub fn from_private_key_der(der_bytes: &[u8]) -> Result<Self, Error> {
58-
let rsa = Rsa::private_key_from_der(der_bytes)?;
80+
pub fn from_rsa_private_key_der(der_bytes: &[u8]) -> Result<Self, Error> {
81+
let rsa = openssl::rsa::Rsa::private_key_from_der(der_bytes)?;
5982
let private_key = pkey::PKey::from_rsa(rsa)?;
6083

6184
Ok(IdentityProvider { private_key })
6285
}
6386

6487
pub fn export_private_key_der(&self) -> Result<Vec<u8>, Error> {
65-
let rsa: Rsa<Private> = self.private_key.rsa()?;
66-
Ok(rsa.private_key_to_der()?)
88+
if let Ok(ec_key) = self.private_key.ec_key() {
89+
Ok(ec_key.private_key_to_der()?)
90+
} else if let Ok(rsa) = self.private_key.rsa() {
91+
Ok(rsa.private_key_to_der()?)
92+
} else {
93+
Err(Error::UnexpectedError)?
94+
}
6795
}
6896

6997
pub fn create_certificate(&self, params: &CertificateParams) -> Result<Vec<u8>, Error> {

src/idp/tests.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ fn test_extract_sp() {
3737
#[test]
3838
fn test_signed_response() {
3939
// init our IdP
40-
let idp = IdentityProvider::from_private_key_der(include_bytes!(
40+
let idp = IdentityProvider::from_rsa_private_key_der(include_bytes!(
4141
"../../test_vectors/idp_private_key.der"
4242
))
4343
.expect("failed to create idp");
@@ -135,7 +135,7 @@ fn test_signed_response_threads() {
135135

136136
#[test]
137137
fn test_signed_response_fingerprint() {
138-
let idp = IdentityProvider::from_private_key_der(include_bytes!(
138+
let idp = IdentityProvider::from_rsa_private_key_der(include_bytes!(
139139
"../../test_vectors/idp_private_key.der"
140140
))
141141
.expect("failed to create idp");

src/signature.rs

Lines changed: 56 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ use quick_xml::events::{BytesEnd, BytesStart, BytesText, Event};
44
use quick_xml::Writer;
55
use serde::Deserialize;
66
use std::io::Cursor;
7+
use std::str::FromStr;
78

89
const NAME: &str = "ds:Signature";
910
const SCHEMA: (&str, &str) = ("xmlns:ds", "http://www.w3.org/2000/09/xmldsig#");
@@ -33,32 +34,30 @@ impl Signature {
3334
algorithm: SignatureAlgorithm::RsaSha256,
3435
hmac_output_length: None,
3536
},
36-
reference: vec![
37-
Reference {
38-
transforms: Some(Transforms {
39-
transforms: vec![
40-
Transform {
41-
algorithm: "http://www.w3.org/2000/09/xmldsig#enveloped-signature"
42-
.to_string(),
43-
xpath: None,
44-
},
45-
Transform {
46-
algorithm: "http://www.w3.org/2001/10/xml-exc-c14n#".to_string(),
47-
xpath: None,
48-
},
49-
],
50-
}),
51-
digest_method: DigestMethod {
52-
algorithm: DigestAlgorithm::Sha1,
53-
},
54-
digest_value: Some(DigestValue {
55-
base64_content: Some("".to_string()),
56-
}),
57-
uri: Some(format!("#{}", ref_id)),
58-
reference_type: None,
59-
id: None,
60-
}
61-
],
37+
reference: vec![Reference {
38+
transforms: Some(Transforms {
39+
transforms: vec![
40+
Transform {
41+
algorithm: "http://www.w3.org/2000/09/xmldsig#enveloped-signature"
42+
.to_string(),
43+
xpath: None,
44+
},
45+
Transform {
46+
algorithm: "http://www.w3.org/2001/10/xml-exc-c14n#".to_string(),
47+
xpath: None,
48+
},
49+
],
50+
}),
51+
digest_method: DigestMethod {
52+
algorithm: DigestAlgorithm::Sha1,
53+
},
54+
digest_value: Some(DigestValue {
55+
base64_content: Some("".to_string()),
56+
}),
57+
uri: Some(format!("#{}", ref_id)),
58+
reference_type: None,
59+
id: None,
60+
}],
6261
},
6362
signature_value: SignatureValue {
6463
id: None,
@@ -294,22 +293,43 @@ impl TryFrom<&SignatureMethod> for Event<'_> {
294293

295294
#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
296295
pub enum SignatureAlgorithm {
297-
#[serde(rename="http://www.w3.org/2001/04/xmldsig-more#rsa-sha256")]
296+
#[serde(rename = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256")]
298297
RsaSha256,
299-
#[serde(rename="http://www.w3.org/2007/05/xmldsig-more#sha256-rsa-MGF1")]
298+
#[serde(rename = "http://www.w3.org/2007/05/xmldsig-more#sha256-rsa-MGF1")]
300299
Sha256RsaMGF1,
300+
#[serde(rename = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256")]
301+
EcdsaSha256,
301302
#[serde(untagged)]
302303
Unsupported(String),
303304
}
304305

306+
impl FromStr for SignatureAlgorithm {
307+
type Err = Box<dyn std::error::Error>;
308+
309+
fn from_str(s: &str) -> Result<Self, Self::Err> {
310+
Ok(match s {
311+
"http://www.w3.org/2001/04/xmldsig-more#rsa-sha256" => SignatureAlgorithm::RsaSha256,
312+
"http://www.w3.org/2007/05/xmldsig-more#sha256-rsa-MGF1" => {
313+
SignatureAlgorithm::Sha256RsaMGF1
314+
}
315+
"http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256" => {
316+
SignatureAlgorithm::EcdsaSha256
317+
}
318+
i => SignatureAlgorithm::Unsupported(i.to_string()),
319+
})
320+
}
321+
}
322+
305323
impl SignatureAlgorithm {
306324
const RSA_SHA256: &'static str = "http://www.w3.org/2001/04/xmldsig-more#rsa-sha256";
307325
const SHA256_RSA_MGF1: &'static str = "http://www.w3.org/2007/05/xmldsig-more#sha256-rsa-MGF1";
326+
const SHA256_ECDSA: &'static str = "http://www.w3.org/2001/04/xmldsig-more#ecdsa-sha256";
308327

309328
pub fn value(&self) -> &str {
310329
match self {
311330
SignatureAlgorithm::RsaSha256 => Self::RSA_SHA256,
312331
SignatureAlgorithm::Sha256RsaMGF1 => Self::SHA256_RSA_MGF1,
332+
SignatureAlgorithm::EcdsaSha256 => Self::SHA256_ECDSA,
313333
SignatureAlgorithm::Unsupported(algo) => algo,
314334
}
315335
}
@@ -430,9 +450,9 @@ impl TryFrom<&DigestMethod> for Event<'_> {
430450

431451
#[derive(Clone, Debug, Deserialize, Hash, Eq, PartialEq, Ord, PartialOrd)]
432452
pub enum DigestAlgorithm {
433-
#[serde(rename="http://www.w3.org/2000/09/xmldsig#sha1")]
453+
#[serde(rename = "http://www.w3.org/2000/09/xmldsig#sha1")]
434454
Sha1,
435-
#[serde(rename="http://www.w3.org/2001/04/xmlenc#sha256")]
455+
#[serde(rename = "http://www.w3.org/2001/04/xmlenc#sha256")]
436456
Sha256,
437457
#[serde(untagged)]
438458
Unsupported(String),
@@ -588,8 +608,10 @@ mod test {
588608

589609
#[test]
590610
pub fn test_canonicalizationmethod_deserialization() -> Result<(), Box<dyn std::error::Error>> {
591-
let canonicalization_method = r#"<ds:CanonicalizationMethod Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"/>"#;
592-
let deserialized: CanonicalizationMethod = quick_xml::de::from_str(canonicalization_method)?;
611+
let canonicalization_method =
612+
r#"<ds:CanonicalizationMethod Algorithm="http://www.w3.org/2001/10/xml-exc-c14n#"/>"#;
613+
let deserialized: CanonicalizationMethod =
614+
quick_xml::de::from_str(canonicalization_method)?;
593615
let serialized = deserialized.to_xml()?;
594616
let re_deserialized: CanonicalizationMethod = quick_xml::de::from_str(&serialized)?;
595617
assert_eq!(deserialized, re_deserialized);
@@ -627,7 +649,8 @@ mod test {
627649

628650
#[test]
629651
pub fn test_digestmethod_deserialization() -> Result<(), Box<dyn std::error::Error>> {
630-
let digest_method = r#"<ds:DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1" />"#;
652+
let digest_method =
653+
r#"<ds:DigestMethod Algorithm="http://www.w3.org/2000/09/xmldsig#sha1" />"#;
631654
let deserialized: DigestMethod = quick_xml::de::from_str(digest_method)?;
632655
let serialized = deserialized.to_xml()?;
633656
let re_deserialized: DigestMethod = quick_xml::de::from_str(&serialized)?;

0 commit comments

Comments
 (0)