Skip to content

Commit

Permalink
Add validity time option to the API (DIS-2749) (#299)
Browse files Browse the repository at this point in the history
Adds `valid_for_seconds` to the auth body request. We weren't actually
parsing the request body before, so we do that now.

Introduces a newtype to avoid mixing up the current time and expiration
times.

Closes #309
  • Loading branch information
paulgb authored Oct 12, 2024
1 parent 9e0856a commit c3ab7a1
Show file tree
Hide file tree
Showing 7 changed files with 165 additions and 40 deletions.
13 changes: 13 additions & 0 deletions crates/y-sweet-core/src/api_types.rs
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,19 @@ pub struct AuthDocRequest {
pub user_id: Option<String>,
#[serde(default)]
pub metadata: HashMap<String, Value>,
#[serde(skip_serializing_if = "Option::is_none", rename = "validForSeconds")]
pub valid_for_seconds: Option<u64>,
}

impl Default for AuthDocRequest {
fn default() -> Self {
Self {
authorization: Authorization::Full,
user_id: None,
metadata: HashMap::new(),
valid_for_seconds: None,
}
}
}

#[derive(Serialize)]
Expand Down
57 changes: 41 additions & 16 deletions crates/y-sweet-core/src/auth.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,22 @@ use sha2::{Digest, Sha256};
use std::fmt::Display;
use thiserror::Error;

const EXPIRATION_MILLIS: u64 = 1000 * 60 * 60; // 60 minutes
pub const DEFAULT_EXPIRATION_SECONDS: u64 = 60 * 60; // 60 minutes

/// This newtype is introduced to distinguish between a u64 meant to represent the current time
/// (currently passed as a raw u64), and a u64 meant to represent an expiration time.
/// We introduce this to intentonally break callers to `gen_doc_token` that do not explicitly
/// update to pass an expiration time, so that calls that use the old signature to pass a current
/// time do not compile.
/// Unit is milliseconds since Jan 1, 1970.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub struct ExpirationTimeEpochMillis(pub u64);

impl ExpirationTimeEpochMillis {
pub fn max() -> Self {
Self(u64::MAX)
}
}

/// This is a custom base64 encoder that is equivalent to BASE64URL_NOPAD for encoding,
/// but is tolerant when decoding of the “standard” alphabet and also of padding.
Expand Down Expand Up @@ -84,7 +99,7 @@ pub enum Permission {
#[derive(Serialize, Deserialize)]
pub struct Payload {
pub payload: Permission,
pub expiration_millis: Option<u64>,
pub expiration_millis: Option<ExpirationTimeEpochMillis>,
}

#[derive(Serialize, Deserialize)]
Expand Down Expand Up @@ -141,7 +156,10 @@ impl Payload {
}
}

pub fn new_with_expiration(payload: Permission, expiration_millis: u64) -> Self {
pub fn new_with_expiration(
payload: Permission,
expiration_millis: ExpirationTimeEpochMillis,
) -> Self {
Self {
payload,
expiration_millis: Some(expiration_millis),
Expand Down Expand Up @@ -256,7 +274,13 @@ impl Authenticator {

if expected_token != auth_req.token {
Err(AuthError::InvalidSignature)
} else if auth_req.payload.expiration_millis.unwrap_or(u64::MAX) < current_time {
} else if auth_req
.payload
.expiration_millis
.unwrap_or(ExpirationTimeEpochMillis::max())
.0
< current_time
{
Err(AuthError::Expired)
} else {
Ok(auth_req.payload)
Expand Down Expand Up @@ -289,12 +313,13 @@ impl Authenticator {
b64_encode(&self.private_key)
}

pub fn gen_doc_token(&self, doc_id: &str, current_time_epoch_millis: u64) -> String {
let expiration_time_epoch_millis = current_time_epoch_millis + EXPIRATION_MILLIS;
let payload = Payload::new_with_expiration(
Permission::Doc(doc_id.to_string()),
expiration_time_epoch_millis,
);
pub fn gen_doc_token(
&self,
doc_id: &str,
expiration_time: ExpirationTimeEpochMillis,
) -> String {
let payload =
Payload::new_with_expiration(Permission::Doc(doc_id.to_string()), expiration_time);
self.sign(payload)
}

Expand Down Expand Up @@ -361,10 +386,10 @@ mod tests {
#[test]
fn test_simple_auth() {
let authenticator = Authenticator::gen_key().unwrap();
let token = authenticator.gen_doc_token("doc123", 0);
let token = authenticator.gen_doc_token("doc123", ExpirationTimeEpochMillis(0));
assert_eq!(authenticator.verify_doc_token(&token, "doc123", 0), Ok(()));
assert_eq!(
authenticator.verify_doc_token(&token, "doc123", EXPIRATION_MILLIS + 1),
authenticator.verify_doc_token(&token, "doc123", DEFAULT_EXPIRATION_SECONDS + 1),
Err(AuthError::Expired)
);
assert_eq!(
Expand All @@ -378,7 +403,7 @@ mod tests {
let authenticator = Authenticator::gen_key()
.unwrap()
.with_key_id("myKeyId".try_into().unwrap());
let token = authenticator.gen_doc_token("doc123", 0);
let token = authenticator.gen_doc_token("doc123", ExpirationTimeEpochMillis(0));
assert!(
token.starts_with("myKeyId."),
"Token {} does not start with myKeyId.",
Expand Down Expand Up @@ -413,7 +438,7 @@ mod tests {
let authenticator = Authenticator::gen_key()
.unwrap()
.with_key_id("myKeyId".try_into().unwrap());
let token = authenticator.gen_doc_token("doc123", 0);
let token = authenticator.gen_doc_token("doc123", ExpirationTimeEpochMillis(0));
let token = token.replace("myKeyId.", "aDifferentKeyId.");
assert!(token.starts_with("aDifferentKeyId."));
assert_eq!(
Expand All @@ -427,7 +452,7 @@ mod tests {
let authenticator = Authenticator::gen_key()
.unwrap()
.with_key_id("myKeyId".try_into().unwrap());
let token = authenticator.gen_doc_token("doc123", 0);
let token = authenticator.gen_doc_token("doc123", ExpirationTimeEpochMillis(0));
let token = token.replace("myKeyId.", "");
assert_eq!(
authenticator.verify_doc_token(&token, "doc123", 0),
Expand All @@ -438,7 +463,7 @@ mod tests {
#[test]
fn test_unexpected_key_id() {
let authenticator = Authenticator::gen_key().unwrap();
let token = authenticator.gen_doc_token("doc123", 0);
let token = authenticator.gen_doc_token("doc123", ExpirationTimeEpochMillis(0));
let token = format!("unexpectedKeyId.{}", token);
assert_eq!(
authenticator.verify_doc_token(&token, "doc123", 0),
Expand Down
20 changes: 16 additions & 4 deletions crates/y-sweet-worker/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,10 @@ use std::collections::HashMap;
use worker::{event, Env};
use worker::{Date, Method, Request, Response, ResponseBody, Result, RouteContext, Router, Url};
use y_sweet_core::{
api_types::{validate_doc_name, ClientToken, DocCreationRequest, NewDocResponse},
auth::Authenticator,
api_types::{
validate_doc_name, AuthDocRequest, ClientToken, DocCreationRequest, NewDocResponse,
},
auth::{Authenticator, ExpirationTimeEpochMillis, DEFAULT_EXPIRATION_SECONDS},
doc_sync::DocWithSyncKv,
store::StoreError,
};
Expand Down Expand Up @@ -189,7 +191,7 @@ async fn auth_doc_handler(req: Request, ctx: RouteContext<ServerContext>) -> Res
}

async fn auth_doc(
req: Request,
mut req: Request,
mut ctx: RouteContext<ServerContext>,
) -> std::result::Result<ClientToken, Error> {
check_server_token(&req, ctx.data.auth()?)?;
Expand All @@ -205,10 +207,20 @@ async fn auth_doc(
return Err(Error::NoSuchDocument);
}

// Note: to preserve the existing behavior, we default to an empty request.
let body = req
.json::<AuthDocRequest>()
.await
.map_err(|_| Error::BadRequest)?;

let valid_time_seconds = body.valid_for_seconds.unwrap_or(DEFAULT_EXPIRATION_SECONDS);
let expiration_time =
ExpirationTimeEpochMillis(get_time_millis_since_epoch() + valid_time_seconds * 1000);

let token = ctx
.data
.auth()?
.map(|auth| auth.gen_doc_token(&doc_id, get_time_millis_since_epoch()));
.map(|auth| auth.gen_doc_token(&doc_id, expiration_time));

let url = if let Some(url_prefix) = &ctx.data.config.url_prefix {
let mut parsed = Url::parse(url_prefix).map_err(|_| Error::ConfigurationError {
Expand Down
23 changes: 13 additions & 10 deletions crates/y-sweet/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ use y_sweet_core::{
api_types::{
validate_doc_name, AuthDocRequest, ClientToken, DocCreationRequest, NewDocResponse,
},
auth::Authenticator,
auth::{Authenticator, ExpirationTimeEpochMillis, DEFAULT_EXPIRATION_SECONDS},
doc_connection::DocConnection,
doc_sync::DocWithSyncKv,
store::Store,
Expand Down Expand Up @@ -673,16 +673,22 @@ async fn auth_doc(
TypedHeader(host): TypedHeader<headers::Host>,
State(server_state): State<Arc<Server>>,
Path(doc_id): Path<String>,
Json(_body): Json<AuthDocRequest>,
body: Option<Json<AuthDocRequest>>,
) -> Result<Json<ClientToken>, AppError> {
server_state.check_auth(authorization)?;

let body = body.unwrap_or_default();

if !server_state.doc_exists(&doc_id).await {
Err((StatusCode::NOT_FOUND, anyhow!("Doc {} not found", doc_id)))?;
}

let valid_for_seconds = body.valid_for_seconds.unwrap_or(DEFAULT_EXPIRATION_SECONDS);
let expiration_time =
ExpirationTimeEpochMillis(current_time_epoch_millis() + valid_for_seconds * 1000);

let token = if let Some(auth) = &server_state.authenticator {
let token = auth.gen_doc_token(&doc_id, current_time_epoch_millis());
let token = auth.gen_doc_token(&doc_id, expiration_time);
Some(token)
} else {
None
Expand Down Expand Up @@ -740,11 +746,12 @@ mod test {
))),
State(Arc::new(server_state)),
Path(doc_id.clone()),
Json(AuthDocRequest {
Some(Json(AuthDocRequest {
authorization: Authorization::Full,
user_id: None,
metadata: HashMap::new(),
}),
valid_for_seconds: None,
})),
)
.await
.unwrap();
Expand Down Expand Up @@ -778,11 +785,7 @@ mod test {
))),
State(Arc::new(server_state)),
Path(doc_id.clone()),
Json(AuthDocRequest {
authorization: Authorization::Full,
user_id: None,
metadata: HashMap::new(),
}),
None,
)
.await
.unwrap();
Expand Down
33 changes: 24 additions & 9 deletions js-pkg/sdk/src/main.ts
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import { DocConnection } from './connection'
export { DocConnection } from './connection'
import { HttpClient } from './http'
import type { DocCreationResult, ClientToken, CheckStoreResult } from './types'
import type { DocCreationResult, ClientToken, CheckStoreResult, AuthDocRequest } from './types'
export type { DocCreationResult, ClientToken, CheckStoreResult } from './types'
export { type YSweetErrorPayload, YSweetError } from './error'
export { encodeClientToken, decodeClientToken } from './encoding'
Expand Down Expand Up @@ -70,14 +70,18 @@ export class DocumentManager {
* client.
*
* @param docId The ID of the document to get a token for.
* @param authDocRequest An optional {@link AuthDocRequest} providing options for the token request.
* @returns A {@link ClientToken} object containing the URL and token needed to connect to the document.
*/
public async getClientToken(docId: string | DocCreationResult): Promise<ClientToken> {
public async getClientToken(
docId: string | DocCreationResult,
authDocRequest?: AuthDocRequest,
): Promise<ClientToken> {
if (typeof docId !== 'string') {
docId = docId.docId
}

const result = await this.client.request(`doc/${docId}/auth`, 'POST', {})
const result = await this.client.request(`doc/${docId}/auth`, 'POST', authDocRequest ?? {})
if (!result.ok) {
throw new Error(`Failed to auth doc ${docId}: ${result.status} ${result.statusText}`)
}
Expand All @@ -91,11 +95,15 @@ export class DocumentManager {
* that one is created. If no docId is provided, a new document is created with a random ID.
*
* @param docId The ID of the document to get or create. If not provided, a new document with a random ID will be created.
* @param authDocRequest An optional {@link AuthDocRequest} providing options for the token request.
* @returns A {@link ClientToken} object containing the URL and token needed to connect to the document.
*/
public async getOrCreateDocAndToken(docId?: string): Promise<ClientToken> {
public async getOrCreateDocAndToken(
docId?: string,
authDocRequest?: AuthDocRequest,
): Promise<ClientToken> {
const result = await this.createDoc(docId)
return await this.getClientToken(result)
return await this.getClientToken(result, authDocRequest)
}

/**
Expand All @@ -120,8 +128,11 @@ export class DocumentManager {
return await connection.updateDoc(update)
}

public async getDocConnection(docId: string): Promise<DocConnection> {
const clientToken = await this.getClientToken(docId)
public async getDocConnection(
docId: string | DocCreationResult,
authDocRequest?: AuthDocRequest,
): Promise<DocConnection> {
const clientToken = await this.getClientToken(docId, authDocRequest)
return new DocConnection(clientToken)
}
}
Expand All @@ -132,29 +143,33 @@ export class DocumentManager {
*
* @param connectionString A connection string (starting with `ys://` or `yss://`) referring to a y-sweet server.
* @param docId The ID of the document to get or create. If not provided, a new document with a random ID will be created.
* @param authDocRequest An optional {@link AuthDocRequest} providing options for the token request.
* @returns A {@link ClientToken} object containing the URL and token needed to connect to the document.
*/
export async function getOrCreateDocAndToken(
connectionString: string,
docId?: string,
authDocRequest?: AuthDocRequest,
): Promise<ClientToken> {
const manager = new DocumentManager(connectionString)
return await manager.getOrCreateDocAndToken(docId)
return await manager.getOrCreateDocAndToken(docId, authDocRequest)
}

/**
* A convenience wrapper around {@link DocumentManager.getClientToken} for getting a client token for a document.
*
* @param connectionString A connection string (starting with `ys://` or `yss://`) referring to a y-sweet server.
* @param docId The ID of the document to get a token for.
* @param authDocRequest An optional {@link AuthDocRequest} providing options for the token request.
* @returns A {@link ClientToken} object containing the URL and token needed to connect to the document.
*/
export async function getClientToken(
connectionString: string,
docId: string | DocCreationResult,
authDocRequest?: AuthDocRequest,
): Promise<ClientToken> {
const manager = new DocumentManager(connectionString)
return await manager.getClientToken(docId)
return await manager.getClientToken(docId, authDocRequest)
}

/**
Expand Down
16 changes: 16 additions & 0 deletions js-pkg/sdk/src/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -28,3 +28,19 @@ export type ClientToken = {
}

export type CheckStoreResult = { ok: true } | { ok: false; error: string }

export type Authorization = 'full' | 'read-only'

export type AuthDocRequest = {
/** The authorization level to use for the document. Defaults to 'full' (not currently enforced). */
authorization?: Authorization

/** A user ID to associate with the token. Not currently used. */
userId?: string

/** Metadata to associate with the user accessing the document. Not currently used. */
metadata?: Record<string, any>

/** The number of seconds the token should be valid for. */
validForSeconds?: number
}
Loading

0 comments on commit c3ab7a1

Please sign in to comment.