@@ -17,10 +17,57 @@ use crate::signature::Signature;
17
17
use chrono:: prelude:: * ;
18
18
use serde:: Deserialize ;
19
19
20
- use quick_xml:: events:: { BytesEnd , BytesStart , BytesText , Event } ;
20
+ use quick_xml:: events:: { BytesDecl , BytesEnd , BytesStart , BytesText , Event } ;
21
21
use quick_xml:: Writer ;
22
22
23
23
use std:: io:: Cursor ;
24
+ use std:: str:: FromStr ;
25
+
26
+ use thiserror:: Error ;
27
+
28
+ #[ derive( Clone , Debug , Deserialize , Hash , Eq , PartialEq , Ord , PartialOrd ) ]
29
+ pub struct NameID {
30
+ #[ serde( rename = "Format" ) ]
31
+ pub format : Option < String > ,
32
+
33
+ #[ serde( rename = "$value" ) ]
34
+ pub value : String ,
35
+ }
36
+
37
+ impl NameID {
38
+ fn name ( ) -> & ' static str {
39
+ "saml2:NameID"
40
+ }
41
+
42
+ fn schema ( ) -> & ' static [ ( & ' static str , & ' static str ) ] {
43
+ & [ ( "xmlns:saml2" , "urn:oasis:names:tc:SAML:2.0:assertion" ) ]
44
+ }
45
+ }
46
+
47
+ impl TryFrom < & NameID > for Event < ' _ > {
48
+ type Error = Box < dyn std:: error:: Error > ;
49
+
50
+ fn try_from ( value : & NameID ) -> Result < Self , Self :: Error > {
51
+ let mut write_buf = Vec :: new ( ) ;
52
+ let mut writer = Writer :: new ( Cursor :: new ( & mut write_buf) ) ;
53
+ let mut root = BytesStart :: new ( NameID :: name ( ) ) ;
54
+
55
+ for attr in NameID :: schema ( ) {
56
+ root. push_attribute ( ( attr. 0 , attr. 1 ) ) ;
57
+ }
58
+
59
+ if let Some ( format) = & value. format {
60
+ root. push_attribute ( ( "Format" , format. as_ref ( ) ) ) ;
61
+ }
62
+
63
+ writer. write_event ( Event :: Start ( root) ) ?;
64
+ writer. write_event ( Event :: Text ( BytesText :: from_escaped ( value. value . as_str ( ) ) ) ) ?;
65
+ writer. write_event ( Event :: End ( BytesEnd :: new ( NameID :: name ( ) ) ) ) ?;
66
+ Ok ( Event :: Text ( BytesText :: from_escaped ( String :: from_utf8 (
67
+ write_buf,
68
+ ) ?) ) )
69
+ }
70
+ }
24
71
25
72
#[ derive( Clone , Debug , Deserialize , Hash , Eq , PartialEq , Ord , PartialOrd ) ]
26
73
pub struct LogoutRequest {
@@ -38,6 +85,81 @@ pub struct LogoutRequest {
38
85
pub signature : Option < Signature > ,
39
86
#[ serde( rename = "@SessionIndex" ) ]
40
87
pub session_index : Option < String > ,
88
+ #[ serde( rename = "NameID" ) ]
89
+ pub name_id : Option < NameID > ,
90
+ }
91
+
92
+ #[ derive( Debug , Error ) ]
93
+ pub enum LogoutRequestError {
94
+ #[ error( "Failed to deserialize LogoutRequest: {:?}" , source) ]
95
+ ParseError {
96
+ #[ from]
97
+ source : quick_xml:: DeError ,
98
+ } ,
99
+ }
100
+
101
+ impl FromStr for LogoutRequest {
102
+ type Err = LogoutRequestError ;
103
+
104
+ fn from_str ( s : & str ) -> Result < Self , Self :: Err > {
105
+ Ok ( quick_xml:: de:: from_str ( s) ?)
106
+ }
107
+ }
108
+
109
+ const LOGOUT_REQUEST_NAME : & str = "saml2p:LogoutRequest" ;
110
+ const SESSION_INDEX_NAME : & str = "saml2p:SessionIndex" ;
111
+ const PROTOCOL_SCHEMA : ( & str , & str ) = ( "xmlns:saml2p" , "urn:oasis:names:tc:SAML:2.0:protocol" ) ;
112
+
113
+ impl LogoutRequest {
114
+ pub fn to_xml ( & self ) -> Result < String , Box < dyn std:: error:: Error > > {
115
+ let mut write_buf = Vec :: new ( ) ;
116
+ let mut writer = Writer :: new ( Cursor :: new ( & mut write_buf) ) ;
117
+ writer. write_event ( Event :: Decl ( BytesDecl :: new ( "1.0" , Some ( "UTF-8" ) , None ) ) ) ?;
118
+
119
+ let mut root = BytesStart :: new ( LOGOUT_REQUEST_NAME ) ;
120
+ root. push_attribute ( PROTOCOL_SCHEMA ) ;
121
+ if let Some ( id) = & self . id {
122
+ root. push_attribute ( ( "ID" , id. as_ref ( ) ) ) ;
123
+ }
124
+ if let Some ( version) = & self . version {
125
+ root. push_attribute ( ( "Version" , version. as_ref ( ) ) ) ;
126
+ }
127
+ if let Some ( issue_instant) = & self . issue_instant {
128
+ root. push_attribute ( (
129
+ "IssueInstant" ,
130
+ issue_instant
131
+ . to_rfc3339_opts ( SecondsFormat :: Millis , true )
132
+ . as_ref ( ) ,
133
+ ) ) ;
134
+ }
135
+ if let Some ( destination) = & self . destination {
136
+ root. push_attribute ( ( "Destination" , destination. as_ref ( ) ) ) ;
137
+ }
138
+
139
+ writer. write_event ( Event :: Start ( root) ) ?;
140
+
141
+ if let Some ( issuer) = & self . issuer {
142
+ let event: Event < ' _ > = issuer. try_into ( ) ?;
143
+ writer. write_event ( event) ?;
144
+ }
145
+ if let Some ( signature) = & self . signature {
146
+ let event: Event < ' _ > = signature. try_into ( ) ?;
147
+ writer. write_event ( event) ?;
148
+ }
149
+
150
+ if let Some ( session) = & self . session_index {
151
+ writer. write_event ( Event :: Start ( BytesStart :: new ( SESSION_INDEX_NAME ) ) ) ?;
152
+ writer. write_event ( Event :: Text ( BytesText :: new ( session) ) ) ?;
153
+ writer. write_event ( Event :: End ( BytesEnd :: new ( SESSION_INDEX_NAME ) ) ) ?;
154
+ }
155
+ if let Some ( name_id) = & self . name_id {
156
+ let event: Event < ' _ > = name_id. try_into ( ) ?;
157
+ writer. write_event ( event) ?;
158
+ }
159
+
160
+ writer. write_event ( Event :: End ( BytesEnd :: new ( LOGOUT_REQUEST_NAME ) ) ) ?;
161
+ Ok ( String :: from_utf8 ( write_buf) ?)
162
+ }
41
163
}
42
164
43
165
#[ derive( Clone , Debug , Deserialize , Hash , Eq , PartialEq , Ord , PartialOrd ) ]
@@ -475,3 +597,120 @@ pub struct LogoutResponse {
475
597
#[ serde( rename = "Status" ) ]
476
598
pub status : Option < Status > ,
477
599
}
600
+
601
+ #[ derive( Debug , Error ) ]
602
+ pub enum LogoutResponseError {
603
+ #[ error( "Failed to deserialize LogoutResponse: {:?}" , source) ]
604
+ ParseError {
605
+ #[ from]
606
+ source : quick_xml:: DeError ,
607
+ } ,
608
+ }
609
+
610
+ impl FromStr for LogoutResponse {
611
+ type Err = LogoutResponseError ;
612
+
613
+ fn from_str ( s : & str ) -> Result < Self , Self :: Err > {
614
+ Ok ( quick_xml:: de:: from_str ( s) ?)
615
+ }
616
+ }
617
+
618
+ const LOGOUT_RESPONSE_NAME : & str = "saml2p:LogoutResponse" ;
619
+
620
+ impl LogoutResponse {
621
+ pub fn to_xml ( & self ) -> Result < String , Box < dyn std:: error:: Error > > {
622
+ let mut write_buf = Vec :: new ( ) ;
623
+ let mut writer = Writer :: new ( Cursor :: new ( & mut write_buf) ) ;
624
+ writer. write_event ( Event :: Decl ( BytesDecl :: new ( "1.0" , Some ( "UTF-8" ) , None ) ) ) ?;
625
+
626
+ let mut root = BytesStart :: new ( LOGOUT_RESPONSE_NAME ) ;
627
+ root. push_attribute ( PROTOCOL_SCHEMA ) ;
628
+ if let Some ( id) = & self . id {
629
+ root. push_attribute ( ( "ID" , id. as_ref ( ) ) ) ;
630
+ }
631
+ if let Some ( resp_to) = & self . in_response_to {
632
+ root. push_attribute ( ( "InResponseTo" , resp_to. as_ref ( ) ) ) ;
633
+ }
634
+ if let Some ( version) = & self . version {
635
+ root. push_attribute ( ( "Version" , version. as_ref ( ) ) ) ;
636
+ }
637
+ if let Some ( issue_instant) = & self . issue_instant {
638
+ root. push_attribute ( (
639
+ "IssueInstant" ,
640
+ issue_instant
641
+ . to_rfc3339_opts ( SecondsFormat :: Millis , true )
642
+ . as_ref ( ) ,
643
+ ) ) ;
644
+ }
645
+ if let Some ( destination) = & self . destination {
646
+ root. push_attribute ( ( "Destination" , destination. as_ref ( ) ) ) ;
647
+ }
648
+ if let Some ( consent) = & self . consent {
649
+ root. push_attribute ( ( "Consent" , consent. as_ref ( ) ) ) ;
650
+ }
651
+
652
+ writer. write_event ( Event :: Start ( root) ) ?;
653
+
654
+ if let Some ( issuer) = & self . issuer {
655
+ let event: Event < ' _ > = issuer. try_into ( ) ?;
656
+ writer. write_event ( event) ?;
657
+ }
658
+ if let Some ( signature) = & self . signature {
659
+ let event: Event < ' _ > = signature. try_into ( ) ?;
660
+ writer. write_event ( event) ?;
661
+ }
662
+
663
+ if let Some ( status) = & self . status {
664
+ let event: Event < ' _ > = status. try_into ( ) ?;
665
+ writer. write_event ( event) ?;
666
+ }
667
+
668
+ writer. write_event ( Event :: End ( BytesEnd :: new ( LOGOUT_RESPONSE_NAME ) ) ) ?;
669
+ Ok ( String :: from_utf8 ( write_buf) ?)
670
+ }
671
+ }
672
+
673
+ #[ cfg( test) ]
674
+ mod test {
675
+ use super :: issuer:: Issuer ;
676
+ use super :: { LogoutRequest , LogoutResponse , NameID , Status , StatusCode } ;
677
+ use chrono:: TimeZone ;
678
+
679
+ #[ test]
680
+ fn test_deserialize_serialize_logout_request ( ) {
681
+ let request_xml = include_str ! ( concat!(
682
+ env!( "CARGO_MANIFEST_DIR" ) ,
683
+ "/test_vectors/logout_request.xml" ,
684
+ ) ) ;
685
+ let expected_request: LogoutRequest = request_xml
686
+ . parse ( )
687
+ . expect ( "failed to parse logout_request.xml" ) ;
688
+ let serialized_request = expected_request
689
+ . to_xml ( )
690
+ . expect ( "failed to convert request to xml" ) ;
691
+ let actual_request: LogoutRequest = serialized_request
692
+ . parse ( )
693
+ . expect ( "failed to re-parse request" ) ;
694
+
695
+ assert_eq ! ( expected_request, actual_request) ;
696
+ }
697
+
698
+ #[ test]
699
+ fn test_deserialize_serialize_logout_response ( ) {
700
+ let response_xml = include_str ! ( concat!(
701
+ env!( "CARGO_MANIFEST_DIR" ) ,
702
+ "/test_vectors/logout_response.xml" ,
703
+ ) ) ;
704
+ let expected_response: LogoutResponse = response_xml
705
+ . parse ( )
706
+ . expect ( "failed to parse logout_response.xml" ) ;
707
+ let serialized_response = expected_response
708
+ . to_xml ( )
709
+ . expect ( "failed to convert Response to xml" ) ;
710
+ let actual_response: LogoutResponse = serialized_response
711
+ . parse ( )
712
+ . expect ( "failed to re-parse Response" ) ;
713
+
714
+ assert_eq ! ( expected_response, actual_response) ;
715
+ }
716
+ }
0 commit comments