diff --git a/crates/gpapi/src/auth.rs b/crates/gpapi/src/auth.rs index 2e3fb2a2..c64a8474 100644 --- a/crates/gpapi/src/auth.rs +++ b/crates/gpapi/src/auth.rs @@ -1,4 +1,4 @@ -use anyhow::bail; +use anyhow::anyhow; use regex::Regex; use serde::{Deserialize, Serialize}; @@ -35,7 +35,7 @@ impl SamlAuthData { } } - pub fn parse_html(html: &str) -> anyhow::Result { + pub fn from_html(html: &str) -> anyhow::Result { match parse_xml_tag(html, "saml-auth-status") { Some(saml_status) if saml_status == "1" => { let username = parse_xml_tag(html, "saml-username"); @@ -43,21 +43,17 @@ impl SamlAuthData { let portal_userauthcookie = parse_xml_tag(html, "portal-userauthcookie"); if SamlAuthData::check(&username, &prelogin_cookie, &portal_userauthcookie) { - return Ok(SamlAuthData::new( + Ok(SamlAuthData::new( username.unwrap(), prelogin_cookie, portal_userauthcookie, - )); + )) + } else { + Err(anyhow!("Found invalid auth data in HTML")) } - - bail!("Found invalid auth data in HTML"); - } - Some(status) => { - bail!("Found invalid SAML status {} in HTML", status); - } - None => { - bail!("No auth data found in HTML"); } + Some(status) => Err(anyhow!("Found invalid SAML status {} in HTML", status)), + None => Err(anyhow!("No auth data found in HTML")), } } diff --git a/crates/gpapi/src/credential.rs b/crates/gpapi/src/credential.rs index e1a1be56..cb47f3fb 100644 --- a/crates/gpapi/src/credential.rs +++ b/crates/gpapi/src/credential.rs @@ -1,5 +1,6 @@ use std::collections::HashMap; +use log::info; use serde::{Deserialize, Serialize}; use specta::Type; @@ -155,25 +156,52 @@ impl From for CachedCredential { } } +#[derive(Debug, Serialize, Deserialize, Type, Clone)] +pub struct TokenCredential { + #[serde(alias = "un")] + username: String, + token: String, +} + +impl TokenCredential { + pub fn username(&self) -> &str { + &self.username + } + + pub fn token(&self) -> &str { + &self.token + } +} + #[derive(Debug, Serialize, Deserialize, Type, Clone)] #[serde(tag = "type", rename_all = "camelCase")] pub enum Credential { Password(PasswordCredential), PreloginCookie(PreloginCookieCredential), AuthCookie(AuthCookieCredential), + TokenCredential(TokenCredential), CachedCredential(CachedCredential), } impl Credential { - /// Create a credential from a globalprotectcallback: - pub fn parse_gpcallback(auth_data: &str) -> anyhow::Result { + /// Create a credential from a globalprotectcallback:, + /// or globalprotectcallback:cas-as=1&un=user@xyz.com&token=very_long_string + pub fn from_gpcallback(auth_data: &str) -> anyhow::Result { // Remove the surrounding quotes let auth_data = auth_data.trim_matches('"'); let auth_data = auth_data.trim_start_matches("globalprotectcallback:"); - let auth_data = decode_to_string(auth_data)?; - let auth_data = SamlAuthData::parse_html(&auth_data)?; - Self::try_from(auth_data) + if auth_data.starts_with("cas-as") { + info!("Got token auth data: {}", auth_data); + let token_cred: TokenCredential = serde_urlencoded::from_str(auth_data)?; + Ok(Self::TokenCredential(token_cred)) + } else { + info!("Parsing SAML auth data..."); + let auth_data = decode_to_string(auth_data)?; + let auth_data = SamlAuthData::from_html(&auth_data)?; + + Self::try_from(auth_data) + } } pub fn username(&self) -> &str { @@ -181,6 +209,7 @@ impl Credential { Credential::Password(cred) => cred.username(), Credential::PreloginCookie(cred) => cred.username(), Credential::AuthCookie(cred) => cred.username(), + Credential::TokenCredential(cred) => cred.username(), Credential::CachedCredential(cred) => cred.username(), } } @@ -189,20 +218,23 @@ impl Credential { let mut params = HashMap::new(); params.insert("user", self.username()); - let (passwd, prelogin_cookie, portal_userauthcookie, portal_prelogonuserauthcookie) = match self { - Credential::Password(cred) => (Some(cred.password()), None, None, None), - Credential::PreloginCookie(cred) => (None, Some(cred.prelogin_cookie()), None, None), + let (passwd, prelogin_cookie, portal_userauthcookie, portal_prelogonuserauthcookie, token) = match self { + Credential::Password(cred) => (Some(cred.password()), None, None, None, None), + Credential::PreloginCookie(cred) => (None, Some(cred.prelogin_cookie()), None, None, None), Credential::AuthCookie(cred) => ( None, None, Some(cred.user_auth_cookie()), Some(cred.prelogon_user_auth_cookie()), + None, ), + Credential::TokenCredential(cred) => (None, None, None, None, Some(cred.token())), Credential::CachedCredential(cred) => ( cred.password(), None, Some(cred.auth_cookie.user_auth_cookie()), Some(cred.auth_cookie.prelogon_user_auth_cookie()), + None, ), }; @@ -214,6 +246,10 @@ impl Credential { portal_prelogonuserauthcookie.unwrap_or_default(), ); + if let Some(token) = token { + params.insert("token", token); + } + params } } @@ -245,3 +281,23 @@ impl From<&CachedCredential> for Credential { Self::CachedCredential(value.clone()) } } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn cred_from_gpcallback_cas() { + let auth_data = "globalprotectcallback:cas-as=1&un=xyz@email.com&token=very_long_string"; + + let cred = Credential::from_gpcallback(auth_data).unwrap(); + + match cred { + Credential::TokenCredential(token_cred) => { + assert_eq!(token_cred.username(), "xyz@email.com"); + assert_eq!(token_cred.token(), "very_long_string"); + } + _ => panic!("Expected TokenCredential"), + } + } +}