@@ -26,11 +26,13 @@ use crate::{
2626 clients:: cloud:: { CloudClient , CloudClientInterface } ,
2727} ;
2828
29+ use super :: IdentityProvider ;
30+
2931#[ derive( Run , Parser , Collect , Clone ) ]
3032#[ cling( run = "run_login" ) ]
3133pub struct Login { }
3234
33- pub async fn run_login ( State ( env) : State < CliEnv > , opts : & Login ) -> Result < ( ) > {
35+ pub async fn run_login ( State ( env) : State < CliEnv > , _opts : & Login ) -> Result < ( ) > {
3436 let config_data = if env. config_file . is_file ( ) {
3537 std:: fs:: read_to_string ( env. config_file . as_path ( ) ) ?
3638 } else {
@@ -41,7 +43,8 @@ pub async fn run_login(State(env): State<CliEnv>, opts: &Login) -> Result<()> {
4143 . parse :: < DocumentMut > ( )
4244 . context ( "Failed to parse config file as TOML" ) ?;
4345
44- let access_token = auth_flow ( & env, opts) . await ?;
46+ let identity_provider = env. config . cloud . resolve_identity_provider ( ) . await ?;
47+ let access_token = auth_flow ( & env, & identity_provider) . await ?;
4548
4649 write_access_token ( & mut doc, & access_token) ?;
4750
@@ -72,7 +75,17 @@ pub async fn run_login(State(env): State<CliEnv>, opts: &Login) -> Result<()> {
7275 Ok ( ( ) )
7376}
7477
75- async fn auth_flow ( env : & CliEnv , _opts : & Login ) -> Result < String > {
78+ async fn auth_flow ( env : & CliEnv , identity_provider : & IdentityProvider ) -> Result < String > {
79+ match identity_provider {
80+ IdentityProvider :: Cognito {
81+ client_id,
82+ login_base_url,
83+ } => cognito_auth_flow ( env, client_id, login_base_url) . await ,
84+ IdentityProvider :: WorkOS { client_id } => workos_auth_flow ( client_id) . await ,
85+ }
86+ }
87+
88+ async fn cognito_auth_flow ( env : & CliEnv , client_id : & str , login_base_url : & Url ) -> Result < String > {
7689 let client = reqwest:: Client :: builder ( )
7790 . user_agent ( format ! (
7891 "{}/{} {}-{}" ,
@@ -86,19 +99,18 @@ async fn auth_flow(env: &CliEnv, _opts: &Login) -> Result<String> {
8699 . build ( )
87100 . context ( "Failed to build oauth token client" ) ?;
88101
102+ let redirect_ports = env. config . cloud . redirect_ports ( ) ;
89103 let mut i = 0 ;
90104 let listener = loop {
91- if i >= env . config . cloud . redirect_ports . len ( ) {
105+ if i >= redirect_ports. len ( ) {
92106 return Err ( anyhow ! (
93107 "Failed to bind oauth callback server to localhost. Tried ports: [{:?}]" ,
94- env . config . cloud . redirect_ports
108+ redirect_ports
95109 ) ) ;
96110 }
97- if let Ok ( listener) = tokio:: net:: TcpListener :: bind ( SocketAddr :: from ( (
98- [ 127 , 0 , 0 , 1 ] ,
99- env. config . cloud . redirect_ports [ i] ,
100- ) ) )
101- . await
111+ if let Ok ( listener) =
112+ tokio:: net:: TcpListener :: bind ( SocketAddr :: from ( ( [ 127 , 0 , 0 , 1 ] , redirect_ports[ i] ) ) )
113+ . await
102114 {
103115 break listener;
104116 }
@@ -115,12 +127,12 @@ async fn auth_flow(env: &CliEnv, _opts: &Login) -> Result<String> {
115127
116128 let state = uuid:: Uuid :: now_v7 ( ) . simple ( ) . to_string ( ) ;
117129
118- let mut login_uri = env . config . cloud . login_base_url . join ( "/login" ) ?;
130+ let mut login_uri = login_base_url. join ( "/login" ) ?;
119131 login_uri
120132 . query_pairs_mut ( )
121133 . clear ( )
122134 . append_pair ( "response_type" , "code" )
123- . append_pair ( "client_id" , & env . config . cloud . client_id )
135+ . append_pair ( "client_id" , client_id)
124136 . append_pair ( "redirect_uri" , & redirect_uri)
125137 . append_pair ( "state" , & state)
126138 . append_pair ( "scope" , "openid" ) ;
@@ -160,8 +172,8 @@ async fn auth_flow(env: &CliEnv, _opts: &Login) -> Result<String> {
160172 )
161173 . with_state ( RedirectState {
162174 client,
163- login_base_url : env . config . cloud . login_base_url . clone ( ) ,
164- client_id : env . config . cloud . client_id . clone ( ) ,
175+ login_base_url : login_base_url. clone ( ) ,
176+ client_id : client_id. to_string ( ) ,
165177 redirect_uri,
166178 result_send,
167179 state,
@@ -276,3 +288,116 @@ fn write_access_token(doc: &mut DocumentMut, access_token: &str) -> Result<()> {
276288
277289 Ok ( ( ) )
278290}
291+
292+ async fn workos_auth_flow ( client_id : & str ) -> Result < String > {
293+ let client = reqwest:: Client :: builder ( )
294+ . user_agent ( format ! (
295+ "{}/{} {}-{}" ,
296+ env!( "CARGO_PKG_NAME" ) ,
297+ build_info:: RESTATE_CLI_VERSION ,
298+ std:: env:: consts:: OS ,
299+ std:: env:: consts:: ARCH ,
300+ ) )
301+ . https_only ( true )
302+ . connect_timeout ( CliContext :: get ( ) . connect_timeout ( ) )
303+ . build ( )
304+ . context ( "Failed to build HTTP client" ) ?;
305+
306+ let device_auth_response: DeviceAuthorizationResponse = client
307+ . post ( "https://api.workos.com/user_management/authorize/device" )
308+ . form ( & [ ( "client_id" , client_id) ] )
309+ . send ( )
310+ . await
311+ . context ( "Failed to request device authorization" ) ?
312+ . error_for_status ( )
313+ . context ( "Bad status code from device authorization endpoint" ) ?
314+ . json ( )
315+ . await
316+ . context ( "Failed to decode device authorization response" ) ?;
317+
318+ c_println ! (
319+ "Please visit {} and enter code: {}" ,
320+ device_auth_response. verification_uri,
321+ device_auth_response. user_code
322+ ) ;
323+
324+ if let Err ( _err) = open:: that ( device_auth_response. verification_uri_complete . clone ( ) ) {
325+ c_println ! ( "Failed to open browser automatically. Please open the above URL manually." )
326+ }
327+
328+ let progress = ProgressBar :: new_spinner ( ) ;
329+ progress. set_style ( indicatif:: ProgressStyle :: with_template ( "{spinner} {msg}" ) . unwrap ( ) ) ;
330+ progress. enable_steady_tick ( std:: time:: Duration :: from_millis ( 120 ) ) ;
331+ progress. set_message ( "Waiting for authentication..." ) ;
332+
333+ let mut interval = std:: time:: Duration :: from_secs ( device_auth_response. interval ) ;
334+ let expires_at =
335+ std:: time:: Instant :: now ( ) + std:: time:: Duration :: from_secs ( device_auth_response. expires_in ) ;
336+
337+ loop {
338+ if std:: time:: Instant :: now ( ) > expires_at {
339+ progress. finish_and_clear ( ) ;
340+ return Err ( anyhow ! ( "Device authorization expired. Please try again." ) ) ;
341+ }
342+
343+ tokio:: time:: sleep ( interval) . await ;
344+
345+ let token_result: Result < WorkOSAuthenticateResponse , _ > = client
346+ . post ( "https://api.workos.com/user_management/authenticate" )
347+ . form ( & [
348+ ( "grant_type" , "urn:ietf:params:oauth:grant-type:device_code" ) ,
349+ ( "device_code" , & device_auth_response. device_code ) ,
350+ ( "client_id" , client_id) ,
351+ ] )
352+ . send ( )
353+ . await
354+ . context ( "Failed to poll for authentication" ) ?
355+ . json ( )
356+ . await ;
357+
358+ match token_result {
359+ Ok ( response) if response. access_token . is_some ( ) => {
360+ progress. finish_and_clear ( ) ;
361+ return Ok ( response. access_token . unwrap ( ) ) ;
362+ }
363+ Ok ( response) if response. error . as_deref ( ) == Some ( "authorization_pending" ) => {
364+ continue ;
365+ }
366+ Ok ( response) if response. error . as_deref ( ) == Some ( "slow_down" ) => {
367+ interval += std:: time:: Duration :: from_secs ( 1 ) ;
368+ continue ;
369+ }
370+ Ok ( response) if response. error . is_some ( ) => {
371+ progress. finish_and_clear ( ) ;
372+ return Err ( anyhow ! (
373+ "Authentication failed: {}" ,
374+ response. error. unwrap_or_else( || "unknown error" . into( ) )
375+ ) ) ;
376+ }
377+ Ok ( _) => {
378+ progress. finish_and_clear ( ) ;
379+ return Err ( anyhow ! ( "Unexpected response from authentication endpoint" ) ) ;
380+ }
381+ Err ( err) => {
382+ progress. finish_and_clear ( ) ;
383+ return Err ( err. into ( ) ) ;
384+ }
385+ }
386+ }
387+ }
388+
389+ #[ derive( Deserialize ) ]
390+ struct DeviceAuthorizationResponse {
391+ device_code : String ,
392+ user_code : String ,
393+ verification_uri : String ,
394+ verification_uri_complete : String ,
395+ expires_in : u64 ,
396+ interval : u64 ,
397+ }
398+
399+ #[ derive( Deserialize ) ]
400+ struct WorkOSAuthenticateResponse {
401+ access_token : Option < String > ,
402+ error : Option < String > ,
403+ }
0 commit comments