diff --git a/Cargo.toml b/Cargo.toml index ce5431d1d..4a4909c76 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,5 @@ members = [ "oauth2-types", "matrix-authentication-service", ] + +resolver = "2" diff --git a/matrix-authentication-service/src/config/oauth2.rs b/matrix-authentication-service/src/config/oauth2.rs index d83e35930..015411b50 100644 --- a/matrix-authentication-service/src/config/oauth2.rs +++ b/matrix-authentication-service/src/config/oauth2.rs @@ -289,7 +289,7 @@ pub struct OAuth2ClientConfig { pub client_secret: Option, #[serde(default)] - pub redirect_uris: Option>, + pub redirect_uris: Vec, } #[derive(Debug, Error)] @@ -301,19 +301,20 @@ impl OAuth2ClientConfig { &'a self, suggested_uri: &'a Option, ) -> Result<&'a Url, InvalidRedirectUriError> { - match (suggested_uri, &self.redirect_uris) { - (None, None) => Err(InvalidRedirectUriError), - (None, Some(redirect_uris)) => { - redirect_uris.iter().next().ok_or(InvalidRedirectUriError) - } - (Some(suggested_uri), None) => Ok(suggested_uri), - (Some(suggested_uri), Some(redirect_uris)) => { - if redirect_uris.contains(suggested_uri) { - Ok(suggested_uri) - } else { - Err(InvalidRedirectUriError) - } - } + suggested_uri.as_ref().map_or_else( + || self.redirect_uris.get(0).ok_or(InvalidRedirectUriError), + |suggested_uri| self.check_redirect_uri(suggested_uri), + ) + } + + pub fn check_redirect_uri<'a>( + &self, + redirect_uri: &'a Url, + ) -> Result<&'a Url, InvalidRedirectUriError> { + if self.redirect_uris.contains(redirect_uri) { + Ok(redirect_uri) + } else { + Err(InvalidRedirectUriError) } } } @@ -464,11 +465,11 @@ mod tests { assert_eq!(config.clients[0].client_id, "hello"); assert_eq!( config.clients[0].redirect_uris, - Some(vec!["https://exemple.fr/callback".parse().unwrap()]) + vec!["https://exemple.fr/callback".parse().unwrap()] ); assert_eq!(config.clients[1].client_id, "world"); - assert_eq!(config.clients[1].redirect_uris, None); + assert_eq!(config.clients[1].redirect_uris, Vec::new()); Ok(()) }); diff --git a/matrix-authentication-service/src/handlers/oauth2/authorization.rs b/matrix-authentication-service/src/handlers/oauth2/authorization.rs index deea237ea..c4c2c1c4c 100644 --- a/matrix-authentication-service/src/handlers/oauth2/authorization.rs +++ b/matrix-authentication-service/src/handlers/oauth2/authorization.rs @@ -18,7 +18,6 @@ use std::{ }; use chrono::Duration; -use data_encoding::BASE64URL_NOPAD; use hyper::{ header::LOCATION, http::uri::{Parts, PathAndQuery, Uri}, @@ -26,18 +25,21 @@ use hyper::{ }; use itertools::Itertools; use oauth2_types::{ + errors::{ErrorResponse, InvalidRequest, OAuth2Error}, pkce, requests::{ AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, ResponseMode, ResponseType, }, }; -use rand::thread_rng; +use rand::{distributions::Alphanumeric, thread_rng, Rng}; use serde::{Deserialize, Serialize}; +use serde_json::Value; use sqlx::{PgPool, Postgres, Transaction}; use url::Url; use warp::{ redirect::see_other, + reject::InvalidQuery, reply::{html, with_header}, Filter, Rejection, Reply, }; @@ -63,6 +65,26 @@ use crate::{ tokens, }; +#[derive(Deserialize)] +struct PartialParams { + client_id: Option, + redirect_uri: Option, + /* + response_type: Option, + response_mode: Option, + */ +} + +enum ReplyOrBackToClient { + Reply(Box), + BackToClient { + params: Value, + redirect_uri: Url, + response_mode: ResponseMode, + }, + Error(Box), +} + fn back_to_client( mut redirect_uri: Url, response_mode: ResponseMode, @@ -173,7 +195,6 @@ pub fn filter( .and(warp::query()) .and(with_optional_session(pool, cookies_config)) .and(with_transaction(pool)) - .and(with_templates(templates)) .and_then(get); let step = warp::path!("oauth2" / "authorize" / "step") @@ -181,10 +202,81 @@ pub fn filter( .and(warp::query().map(|s: StepRequest| s.id)) .and(with_session(pool, cookies_config)) .and(with_transaction(pool)) - .and(with_templates(templates)) .and_then(step); - authorize.or(step) + let clients = oauth2_config.clients.clone(); + authorize + .or(step) + .unify() + .recover(recover) + .unify() + .and(warp::query()) + .and(warp::any().map(move || clients.clone())) + .and(with_templates(templates)) + .and_then(actually_reply) +} + +async fn recover(rejection: Rejection) -> Result { + if rejection.find::().is_some() { + Ok(ReplyOrBackToClient::Error(Box::new(InvalidRequest))) + } else { + Err(rejection) + } +} + +async fn actually_reply( + rep: ReplyOrBackToClient, + q: PartialParams, + clients: Vec, + templates: Templates, +) -> Result { + let (redirect_uri, response_mode, params) = match rep { + ReplyOrBackToClient::Reply(r) => return Ok(r), + ReplyOrBackToClient::BackToClient { + redirect_uri, + response_mode, + params, + } => (redirect_uri, response_mode, params), + ReplyOrBackToClient::Error(error) => { + let PartialParams { + client_id, + redirect_uri, + .. + } = q; + + // First, disover the client + let client = client_id.and_then(|client_id| { + clients + .into_iter() + .find(|client| client.client_id == client_id) + }); + + let client = match client { + Some(client) => client, + None => return Ok(Box::new(html(templates.render_error(&error.into())?))), + }; + + let redirect_uri: Result, _> = redirect_uri.map(|r| r.parse()).transpose(); + let redirect_uri = match redirect_uri { + Ok(r) => r, + Err(_) => return Ok(Box::new(html(templates.render_error(&error.into())?))), + }; + + let redirect_uri = client.resolve_redirect_uri(&redirect_uri); + let redirect_uri = match redirect_uri { + Ok(r) => r, + Err(_) => return Ok(Box::new(html(templates.render_error(&error.into())?))), + }; + + let reply: ErrorResponse = error.into(); + let reply = serde_json::to_value(&reply).wrap_error()?; + // TODO: resolve response mode + (redirect_uri.clone(), ResponseMode::Query, reply) + } + }; + + // TODO: we should include the state param in errors + back_to_client(redirect_uri, response_mode, params, &templates).wrap_error() } async fn get( @@ -192,8 +284,7 @@ async fn get( params: Params, maybe_session: Option, mut txn: Transaction<'_, Postgres>, - templates: Templates, -) -> Result, Rejection> { +) -> Result { // First, find out what client it is let client = clients .into_iter() @@ -201,11 +292,6 @@ async fn get( .ok_or_else(|| anyhow::anyhow!("could not find client")) .wrap_error()?; - // Then, figure out the redirect URI - let redirect_uri = client - .resolve_redirect_uri(¶ms.auth.redirect_uri) - .wrap_error()?; - let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key); let scope: String = { @@ -213,6 +299,9 @@ async fn get( Itertools::intersperse(it, " ".to_string()).collect() }; + let redirect_uri = client + .resolve_redirect_uri(¶ms.auth.redirect_uri) + .wrap_error()?; let response_type = ¶ms.auth.response_type; let response_mode = resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?; @@ -234,9 +323,13 @@ async fn get( // Generate the code at this stage, since we have the PKCE params ready if response_type.contains(&ResponseType::Code) { - // 192bit random bytes encoded in base64, which gives a 32 character code - let code: [u8; 24] = rand::random(); - let code = BASE64URL_NOPAD.encode(&code); + // 32 random alphanumeric characters, about 190bit of entropy + let code: String = thread_rng() + .sample_iter(&Alphanumeric) + .take(32) + .map(char::from) + .collect(); + oauth2_session .add_code(&mut txn, &code, ¶ms.pkce) .await @@ -247,7 +340,7 @@ async fn get( let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?; if let Some(user_session) = user_session { - step(oauth2_session.id, user_session, txn, templates).await + step(oauth2_session.id, user_session, txn).await } else { // If not, redirect the user to the login page txn.commit().await.wrap_error()?; @@ -258,7 +351,7 @@ async fn get( .to_string(); let destination = LoginRequest::new(Some(next)).build_uri().wrap_error()?; - Ok(Box::new(see_other(destination))) + Ok(ReplyOrBackToClient::Reply(Box::new(see_other(destination)))) } } @@ -288,8 +381,7 @@ async fn step( oauth2_session_id: i64, user_session: SessionInfo, mut txn: Transaction<'_, Postgres>, - templates: Templates, -) -> Result, Rejection> { +) -> Result { let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id) .await .wrap_error()?; @@ -350,11 +442,16 @@ async fn step( todo!("id tokens are not implemented yet"); } - back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()? + let params = serde_json::to_value(¶ms).unwrap(); + ReplyOrBackToClient::BackToClient { + redirect_uri, + response_mode, + params, + } } else { // Ask for a reauth // TODO: have the OAuth2 session ID in there - Box::new(see_other(Uri::from_static("/reauth"))) + ReplyOrBackToClient::Reply(Box::new(see_other(Uri::from_static("/reauth")))) }; txn.commit().await.wrap_error()?; diff --git a/matrix-authentication-service/src/storage/oauth2/session.rs b/matrix-authentication-service/src/storage/oauth2/session.rs index d620e7f38..18b30bc0d 100644 --- a/matrix-authentication-service/src/storage/oauth2/session.rs +++ b/matrix-authentication-service/src/storage/oauth2/session.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashSet, convert::TryFrom, str::FromStr}; +use std::{collections::HashSet, convert::TryFrom, str::FromStr, string::ToString}; use anyhow::Context; use chrono::{DateTime, Duration, Utc}; @@ -34,7 +34,7 @@ use super::{ pub struct OAuth2Session { pub id: i64, user_session_id: Option, - client_id: String, + pub client_id: String, redirect_uri: String, scope: String, pub state: Option, @@ -143,6 +143,7 @@ pub async fn start_session( // Checked convertion of duration to i32, maxing at i32::MAX let max_age = max_age.map(|d| i32::try_from(d.num_seconds()).unwrap_or(i32::MAX)); let response_mode = response_mode.to_string(); + let redirect_uri = redirect_uri.to_string(); let response_type: String = { let it = response_type.iter().map(ToString::to_string); Itertools::intersperse(it, " ".to_string()).collect() @@ -162,7 +163,7 @@ pub async fn start_session( "#, optional_session_id, client_id, - redirect_uri.as_str(), + redirect_uri, scope, state, nonce, diff --git a/matrix-authentication-service/src/templates.rs b/matrix-authentication-service/src/templates.rs index 399e9392f..754ad9961 100644 --- a/matrix-authentication-service/src/templates.rs +++ b/matrix-authentication-service/src/templates.rs @@ -14,6 +14,7 @@ use std::{collections::HashSet, string::ToString, sync::Arc}; +use oauth2_types::errors::OAuth2Error; use serde::Serialize; use tera::{Context, Error as TeraError, Tera}; use thiserror::Error; @@ -146,6 +147,9 @@ register_templates! { /// Render the form used by the form_post response mode pub fn render_form_post(FormPostContext) { "form_post.html" } + + /// Render the HTML error page + pub fn render_error(ErrorContext) { "error.html" } } /// Helper trait to construct context wrappers @@ -254,3 +258,42 @@ impl FormPostContext { } } } + +#[derive(Default, Serialize)] +pub struct ErrorContext { + code: Option<&'static str>, + description: Option, + details: Option, +} + +impl ErrorContext { + pub fn new() -> Self { + Self::default() + } + + pub fn with_code(mut self, code: &'static str) -> Self { + self.code = Some(code); + self + } + + pub fn with_description(mut self, description: String) -> Self { + self.description = Some(description); + self + } + + #[allow(dead_code)] + pub fn with_details(mut self, details: String) -> Self { + self.details = Some(details); + self + } +} + +impl From> for ErrorContext { + fn from(err: Box) -> Self { + let mut ctx = ErrorContext::new().with_code(err.error()); + if let Some(desc) = err.description() { + ctx = ctx.with_description(desc); + } + ctx + } +} diff --git a/matrix-authentication-service/src/tokens.rs b/matrix-authentication-service/src/tokens.rs index 5c189212f..7e88c864c 100644 --- a/matrix-authentication-service/src/tokens.rs +++ b/matrix-authentication-service/src/tokens.rs @@ -91,7 +91,6 @@ pub enum TokenFormatError { InvalidCrc { expected: String, got: String }, } -#[allow(dead_code)] pub fn check(token: &str) -> Result { let split: Vec<&str> = token.split('_').collect(); let [prefix, random_part, crc]: [&str; 3] = split diff --git a/oauth2-types/src/errors.rs b/oauth2-types/src/errors.rs index 92d588510..a7d6cf2f4 100644 --- a/oauth2-types/src/errors.rs +++ b/oauth2-types/src/errors.rs @@ -16,7 +16,7 @@ use http::status::StatusCode; use serde::ser::{Serialize, SerializeMap}; use url::Url; -pub trait OAuth2Error: std::fmt::Debug { +pub trait OAuth2Error: std::fmt::Debug + Send + Sync { /// A single ASCII error code. /// /// Maps to the required "error" field. @@ -41,11 +41,11 @@ pub trait OAuth2Error: std::fmt::Debug { } /// Wraps the error with an `ErrorResponse` to help serializing. - fn into_response(self) -> ErrorResponse + fn into_response(self) -> ErrorResponse where - Self: Sized, + Self: Sized + 'static, { - ErrorResponse(self) + ErrorResponse(Box::new(self)) } } @@ -55,15 +55,15 @@ trait OAuth2ErrorCode: OAuth2Error { } #[derive(Debug)] -pub struct ErrorResponse(T); +pub struct ErrorResponse(Box); -impl OAuth2ErrorCode for ErrorResponse { - fn status(&self) -> StatusCode { - self.0.status() +impl From> for ErrorResponse { + fn from(b: Box) -> Self { + Self(b) } } -impl OAuth2Error for ErrorResponse { +impl OAuth2Error for ErrorResponse { fn error(&self) -> &'static str { self.0.error() } @@ -77,7 +77,7 @@ impl OAuth2Error for ErrorResponse { } } -impl Serialize for ErrorResponse { +impl Serialize for ErrorResponse { fn serialize(&self, serializer: S) -> Result where S: serde::Serializer,