Skip to content

Commit ea36f66

Browse files
committed
[cli] Add WorkOS identity provider support for Restate Cloud login
1 parent c457828 commit ea36f66

File tree

4 files changed

+332
-29
lines changed

4 files changed

+332
-29
lines changed

cli/src/clients/cloud/client.rs

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ impl CloudClient {
114114

115115
Ok(Self {
116116
inner: raw_client,
117-
base_url: env.config.cloud.api_base_url.clone(),
117+
base_url: env.config.cloud.api_base_url(),
118118
access_token,
119119
request_timeout: CliContext::get().request_timeout(),
120120
})
Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
// Copyright (c) 2023 - 2025 Restate Software, Inc., Restate GmbH.
2+
// All rights reserved.
3+
//
4+
// Use of this software is governed by the Business Source License
5+
// included in the LICENSE file.
6+
//
7+
// As of the Change Date specified in that file, in accordance with
8+
// the Business Source License, use of this software will be governed
9+
// by the Apache License, Version 2.0.
10+
11+
use std::time::Duration;
12+
13+
use anyhow::{Context, Result};
14+
use serde::Deserialize;
15+
use tracing::debug;
16+
use url::Url;
17+
18+
#[derive(Debug, Deserialize)]
19+
pub struct DiscoveredAuthConfig {
20+
pub provider: String,
21+
pub client_id: String,
22+
pub login_base_url: Option<Url>,
23+
}
24+
25+
pub async fn fetch_auth_config(discovery_url: &Url) -> Result<DiscoveredAuthConfig> {
26+
debug!(%discovery_url, "Fetching authentication config");
27+
let client = reqwest::Client::builder()
28+
.timeout(Duration::from_secs(5))
29+
.build()
30+
.context("Failed to build HTTP client for config discovery")?;
31+
32+
let response = client
33+
.get(discovery_url.clone())
34+
.send()
35+
.await
36+
.with_context(|| {
37+
format!("Failed to discover authentication configuration at {discovery_url}")
38+
})?;
39+
40+
if !response.status().is_success() {
41+
anyhow::bail!(
42+
"Auth config discovery returned non-success status {} from {discovery_url}",
43+
response.status()
44+
);
45+
}
46+
47+
response
48+
.json::<DiscoveredAuthConfig>()
49+
.await
50+
.with_context(|| format!("Failed to parse auth config from {discovery_url}"))
51+
}

cli/src/commands/cloud/login.rs

Lines changed: 139 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -26,11 +26,13 @@ use crate::{
2626
clients::cloud::{CloudClient, CloudClientInterface},
2727
};
2828

29+
use super::IdentityProvider;
30+
2931
#[derive(Run, Parser, Collect, Clone)]
3032
#[cling(run = "run_login")]
3133
pub struct Login {}
3234

33-
pub async fn run_login(State(env): State<CliEnv>, opts: &Login) -> Result<()> {
35+
pub async fn run_login(State(env): State<CliEnv>, _opts: &Login) -> Result<()> {
3436
let config_data = if env.config_file.is_file() {
3537
std::fs::read_to_string(env.config_file.as_path())?
3638
} else {
@@ -41,7 +43,8 @@ pub async fn run_login(State(env): State<CliEnv>, opts: &Login) -> Result<()> {
4143
.parse::<DocumentMut>()
4244
.context("Failed to parse config file as TOML")?;
4345

44-
let access_token = auth_flow(&env, opts).await?;
46+
let identity_provider = env.config.cloud.resolve_identity_provider().await?;
47+
let access_token = auth_flow(&env, &identity_provider).await?;
4548

4649
write_access_token(&mut doc, &access_token)?;
4750

@@ -72,7 +75,17 @@ pub async fn run_login(State(env): State<CliEnv>, opts: &Login) -> Result<()> {
7275
Ok(())
7376
}
7477

75-
async fn auth_flow(env: &CliEnv, _opts: &Login) -> Result<String> {
78+
async fn auth_flow(env: &CliEnv, identity_provider: &IdentityProvider) -> Result<String> {
79+
match identity_provider {
80+
IdentityProvider::Cognito {
81+
client_id,
82+
login_base_url,
83+
} => cognito_auth_flow(env, client_id, login_base_url).await,
84+
IdentityProvider::WorkOS { client_id } => workos_auth_flow(client_id).await,
85+
}
86+
}
87+
88+
async fn cognito_auth_flow(env: &CliEnv, client_id: &str, login_base_url: &Url) -> Result<String> {
7689
let client = reqwest::Client::builder()
7790
.user_agent(format!(
7891
"{}/{} {}-{}",
@@ -86,19 +99,18 @@ async fn auth_flow(env: &CliEnv, _opts: &Login) -> Result<String> {
8699
.build()
87100
.context("Failed to build oauth token client")?;
88101

102+
let redirect_ports = env.config.cloud.redirect_ports();
89103
let mut i = 0;
90104
let listener = loop {
91-
if i >= env.config.cloud.redirect_ports.len() {
105+
if i >= redirect_ports.len() {
92106
return Err(anyhow!(
93107
"Failed to bind oauth callback server to localhost. Tried ports: [{:?}]",
94-
env.config.cloud.redirect_ports
108+
redirect_ports
95109
));
96110
}
97-
if let Ok(listener) = tokio::net::TcpListener::bind(SocketAddr::from((
98-
[127, 0, 0, 1],
99-
env.config.cloud.redirect_ports[i],
100-
)))
101-
.await
111+
if let Ok(listener) =
112+
tokio::net::TcpListener::bind(SocketAddr::from(([127, 0, 0, 1], redirect_ports[i])))
113+
.await
102114
{
103115
break listener;
104116
}
@@ -115,12 +127,12 @@ async fn auth_flow(env: &CliEnv, _opts: &Login) -> Result<String> {
115127

116128
let state = uuid::Uuid::now_v7().simple().to_string();
117129

118-
let mut login_uri = env.config.cloud.login_base_url.join("/login")?;
130+
let mut login_uri = login_base_url.join("/login")?;
119131
login_uri
120132
.query_pairs_mut()
121133
.clear()
122134
.append_pair("response_type", "code")
123-
.append_pair("client_id", &env.config.cloud.client_id)
135+
.append_pair("client_id", client_id)
124136
.append_pair("redirect_uri", &redirect_uri)
125137
.append_pair("state", &state)
126138
.append_pair("scope", "openid");
@@ -160,8 +172,8 @@ async fn auth_flow(env: &CliEnv, _opts: &Login) -> Result<String> {
160172
)
161173
.with_state(RedirectState {
162174
client,
163-
login_base_url: env.config.cloud.login_base_url.clone(),
164-
client_id: env.config.cloud.client_id.clone(),
175+
login_base_url: login_base_url.clone(),
176+
client_id: client_id.to_string(),
165177
redirect_uri,
166178
result_send,
167179
state,
@@ -276,3 +288,116 @@ fn write_access_token(doc: &mut DocumentMut, access_token: &str) -> Result<()> {
276288

277289
Ok(())
278290
}
291+
292+
async fn workos_auth_flow(client_id: &str) -> Result<String> {
293+
let client = reqwest::Client::builder()
294+
.user_agent(format!(
295+
"{}/{} {}-{}",
296+
env!("CARGO_PKG_NAME"),
297+
build_info::RESTATE_CLI_VERSION,
298+
std::env::consts::OS,
299+
std::env::consts::ARCH,
300+
))
301+
.https_only(true)
302+
.connect_timeout(CliContext::get().connect_timeout())
303+
.build()
304+
.context("Failed to build HTTP client")?;
305+
306+
let device_auth_response: DeviceAuthorizationResponse = client
307+
.post("https://api.workos.com/user_management/authorize/device")
308+
.form(&[("client_id", client_id)])
309+
.send()
310+
.await
311+
.context("Failed to request device authorization")?
312+
.error_for_status()
313+
.context("Bad status code from device authorization endpoint")?
314+
.json()
315+
.await
316+
.context("Failed to decode device authorization response")?;
317+
318+
c_println!(
319+
"Please visit {} and enter code: {}",
320+
device_auth_response.verification_uri,
321+
device_auth_response.user_code
322+
);
323+
324+
if let Err(_err) = open::that(device_auth_response.verification_uri_complete.clone()) {
325+
c_println!("Failed to open browser automatically. Please open the above URL manually.")
326+
}
327+
328+
let progress = ProgressBar::new_spinner();
329+
progress.set_style(indicatif::ProgressStyle::with_template("{spinner} {msg}").unwrap());
330+
progress.enable_steady_tick(std::time::Duration::from_millis(120));
331+
progress.set_message("Waiting for authentication...");
332+
333+
let mut interval = std::time::Duration::from_secs(device_auth_response.interval);
334+
let expires_at =
335+
std::time::Instant::now() + std::time::Duration::from_secs(device_auth_response.expires_in);
336+
337+
loop {
338+
if std::time::Instant::now() > expires_at {
339+
progress.finish_and_clear();
340+
return Err(anyhow!("Device authorization expired. Please try again."));
341+
}
342+
343+
tokio::time::sleep(interval).await;
344+
345+
let token_result: Result<WorkOSAuthenticateResponse, _> = client
346+
.post("https://api.workos.com/user_management/authenticate")
347+
.form(&[
348+
("grant_type", "urn:ietf:params:oauth:grant-type:device_code"),
349+
("device_code", &device_auth_response.device_code),
350+
("client_id", client_id),
351+
])
352+
.send()
353+
.await
354+
.context("Failed to poll for authentication")?
355+
.json()
356+
.await;
357+
358+
match token_result {
359+
Ok(response) if response.access_token.is_some() => {
360+
progress.finish_and_clear();
361+
return Ok(response.access_token.unwrap());
362+
}
363+
Ok(response) if response.error.as_deref() == Some("authorization_pending") => {
364+
continue;
365+
}
366+
Ok(response) if response.error.as_deref() == Some("slow_down") => {
367+
interval += std::time::Duration::from_secs(1);
368+
continue;
369+
}
370+
Ok(response) if response.error.is_some() => {
371+
progress.finish_and_clear();
372+
return Err(anyhow!(
373+
"Authentication failed: {}",
374+
response.error.unwrap_or_else(|| "unknown error".into())
375+
));
376+
}
377+
Ok(_) => {
378+
progress.finish_and_clear();
379+
return Err(anyhow!("Unexpected response from authentication endpoint"));
380+
}
381+
Err(err) => {
382+
progress.finish_and_clear();
383+
return Err(err.into());
384+
}
385+
}
386+
}
387+
}
388+
389+
#[derive(Deserialize)]
390+
struct DeviceAuthorizationResponse {
391+
device_code: String,
392+
user_code: String,
393+
verification_uri: String,
394+
verification_uri_complete: String,
395+
expires_in: u64,
396+
interval: u64,
397+
}
398+
399+
#[derive(Deserialize)]
400+
struct WorkOSAuthenticateResponse {
401+
access_token: Option<String>,
402+
error: Option<String>,
403+
}

0 commit comments

Comments
 (0)