From 1e464b40d912ba992fd91986ed3dcc8d784f0de1 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 5 Apr 2022 12:08:56 +0200 Subject: [PATCH] Axum migration: /oauth2/token --- crates/handlers/src/lib.rs | 1 + crates/handlers/src/oauth2/authorization.rs | 2 +- crates/handlers/src/oauth2/mod.rs | 2 +- crates/handlers/src/oauth2/token.rs | 315 +++++++++----------- crates/oauth2-types/src/errors.rs | 50 +++- 5 files changed, 188 insertions(+), 182 deletions(-) diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index b42acf588..23f44a727 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -94,6 +94,7 @@ where "/oauth2/introspect", post(self::oauth2::introspection::post), ) + .route("/oauth2/token", post(self::oauth2::token::post)) .fallback(mas_static_files::Assets) .layer(Extension(pool.clone())) .layer(Extension(templates.clone())) diff --git a/crates/handlers/src/oauth2/authorization.rs b/crates/handlers/src/oauth2/authorization.rs index ced27600d..4de344f23 100644 --- a/crates/handlers/src/oauth2/authorization.rs +++ b/crates/handlers/src/oauth2/authorization.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. diff --git a/crates/handlers/src/oauth2/mod.rs b/crates/handlers/src/oauth2/mod.rs index a225b50bf..e32553a3a 100644 --- a/crates/handlers/src/oauth2/mod.rs +++ b/crates/handlers/src/oauth2/mod.rs @@ -16,7 +16,7 @@ pub mod discovery; pub mod introspection; pub mod keys; -// pub mod token; +pub mod token; pub mod userinfo; use hyper::{ diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 1a1a294fb..ff0eb47f9 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -1,4 +1,4 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. +// Copyright 2021, 2022 The Matrix.org Foundation C.I.C. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. @@ -12,21 +12,30 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashMap, convert::Infallible, sync::Arc}; +use std::{collections::HashMap, sync::Arc}; use anyhow::Context; +use axum::{extract::Extension, response::IntoResponse, Json}; use chrono::{DateTime, Duration, Utc}; use data_encoding::BASE64URL_NOPAD; -use headers::{CacheControl, Pragma}; +use headers::{CacheControl, HeaderMap, HeaderMapExt, Pragma}; use hyper::StatusCode; -use mas_config::{Encrypter, HttpConfig}; +use mas_axum_utils::{ + client_authorization::{ClientAuthorization, CredentialsVerificationError}, + UrlBuilder, +}; +use mas_config::Encrypter; use mas_data_model::{AuthorizationGrantStage, Client, TokenType}; -use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; -use mas_jose::{claims, DecodedJsonWebToken, SigningKeystore, StaticKeystore}; +use mas_iana::jose::JsonWebSignatureAlg; +use mas_jose::{ + claims::{self, ClaimError}, + DecodedJsonWebToken, SigningKeystore, StaticKeystore, +}; use mas_storage::{ oauth2::{ access_token::{add_access_token, revoke_access_token}, authorization_grant::{exchange_grant, lookup_grant_by_code}, + client::ClientFetchError, end_oauth_session, refresh_token::{ add_refresh_token, lookup_active_refresh_token, replace_refresh_token, @@ -35,15 +44,7 @@ use mas_storage::{ }, DatabaseInconsistencyError, PostgresqlBackend, }; -use mas_warp_utils::{ - errors::WrapError, - filters::{self, client::client_authentication, database::connection, url_builder::UrlBuilder}, - reply::with_typed_header, -}; use oauth2_types::{ - errors::{ - InvalidGrant, InvalidRequest, OAuth2Error, OAuth2ErrorCode, ServerError, UnauthorizedClient, - }, requests::{ AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, RefreshTokenGrant, }, @@ -53,15 +54,9 @@ use rand::thread_rng; use serde::Serialize; use serde_with::{serde_as, skip_serializing_none}; use sha2::{Digest, Sha256}; -use sqlx::{pool::PoolConnection, Acquire, PgPool, Postgres}; +use sqlx::{PgPool, Postgres, Transaction}; use tracing::debug; use url::Url; -use warp::{ - filters::BoxedFilter, - reject::Reject, - reply::{json, with_status}, - Filter, Rejection, Reply, -}; #[serde_as] #[skip_serializing_none] @@ -80,96 +75,107 @@ struct CustomClaims { c_hash: String, } -#[derive(Debug)] -struct Error { - json: serde_json::Value, - status: StatusCode, +pub(crate) enum RouteError { + Internal(Box), + Anyhow(anyhow::Error), + BadRequest, + ClientNotFound, + ClientNotAllowed, + ClientCredentialsVerification(CredentialsVerificationError), + InvalidGrant, + UnauthorizedClient, } -impl Reject for Error {} - -fn error(e: E) -> Result -where - E: OAuth2ErrorCode + 'static, -{ - let status = e.status(); - let json = serde_json::to_value(e.into_response()).wrap_error()?; - Err(Error { json, status }.into()) -} - -pub fn filter( - pool: &PgPool, - encrypter: &Encrypter, - key_store: &Arc, - http_config: &HttpConfig, -) -> BoxedFilter<(Box,)> { - let key_store = key_store.clone(); - let builder = UrlBuilder::from(http_config); - let audience = builder.oauth_token_endpoint().to_string(); - - let issuer = builder.oidc_issuer(); - - warp::path!("oauth2" / "token") - .and(filters::trace::name("POST /oauth2/token")) - .and( - warp::post() - .and(client_authentication(pool, encrypter, audience)) - .and(warp::any().map(move || key_store.clone())) - .and(warp::any().map(move || issuer.clone())) - .and(connection(pool)) - .and_then(token) - .recover(recover) - .unify(), - ) - .boxed() -} - -async fn recover(rejection: Rejection) -> Result, Infallible> { - fn reply(err: E) -> Box { - let status = err.status(); - Box::new(with_status(warp::reply::json(&err.into_response()), status)) - } - - if let Some(Error { json, status }) = rejection.find::() { - return Ok(Box::new(with_status(warp::reply::json(json), *status))); - } - - if let Some(e) = rejection.find::() { +impl From for RouteError { + fn from(e: ClientFetchError) -> Self { if e.not_found() { - return Ok(reply(InvalidGrant)); + Self::ClientNotFound + } else { + Self::Internal(Box::new(e)) } - }; - - Ok(reply(ServerError)) + } } -async fn token( - _auth: OAuthClientAuthenticationMethod, - client: Client, - req: AccessTokenRequest, - key_store: Arc, - issuer: Url, - mut conn: PoolConnection, -) -> Result, Rejection> { - let reply = match req { +impl From for RouteError { + fn from(e: RefreshTokenLookupError) -> Self { + if e.not_found() { + Self::InvalidGrant + } else { + Self::Internal(Box::new(e)) + } + } +} + +impl IntoResponse for RouteError { + fn into_response(self) -> axum::response::Response { + // TODO + StatusCode::INTERNAL_SERVER_ERROR.into_response() + } +} + +impl From for RouteError { + fn from(e: sqlx::Error) -> Self { + Self::Internal(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: ClaimError) -> Self { + Self::Internal(Box::new(e)) + } +} + +impl From for RouteError { + fn from(e: anyhow::Error) -> Self { + Self::Anyhow(e) + } +} + +impl From for RouteError { + fn from(e: CredentialsVerificationError) -> Self { + Self::ClientCredentialsVerification(e) + } +} + +pub(crate) async fn post( + client_authorization: ClientAuthorization, + Extension(key_store): Extension>, + Extension(url_builder): Extension, + Extension(pool): Extension, + Extension(encrypter): Extension, +) -> Result { + let mut txn = pool.begin().await?; + + let client = client_authorization.credentials.fetch(&mut txn).await?; + + let method = client + .token_endpoint_auth_method + .ok_or(RouteError::ClientNotAllowed)?; + + client_authorization + .credentials + .verify(&encrypter, method, &client) + .await?; + + let form = client_authorization.form.ok_or(RouteError::BadRequest)?; + + let reply = match form { AccessTokenRequest::AuthorizationCode(grant) => { - let reply = - authorization_code_grant(&grant, &client, &key_store, issuer, &mut conn).await?; - json(&reply) + authorization_code_grant(&grant, &client, &key_store, &url_builder, txn).await? } AccessTokenRequest::RefreshToken(grant) => { - let reply = refresh_token_grant(&grant, &client, &mut conn).await?; - json(&reply) + refresh_token_grant(&grant, &client, txn).await? } _ => { - let reply = InvalidGrant.into_response(); - json(&reply) + return Err(RouteError::InvalidGrant); } }; - let reply = with_typed_header(CacheControl::new().with_no_store(), reply); - let reply = with_typed_header(Pragma::no_cache(), reply); - Ok(Box::new(reply)) + let mut headers = HeaderMap::new(); + headers.typed_insert(CacheControl::new().with_no_store()); + headers.typed_insert(Pragma::no_cache()); + + Ok((StatusCode::OK, headers, Json(reply))) } fn hash(mut hasher: H, token: &str) -> anyhow::Result { @@ -187,16 +193,12 @@ async fn authorization_code_grant( grant: &AuthorizationCodeGrant, client: &Client, key_store: &StaticKeystore, - issuer: Url, - conn: &mut PoolConnection, -) -> Result { + url_builder: &UrlBuilder, + mut txn: Transaction<'_, Postgres>, +) -> Result { // TODO: there is a bunch of unnecessary cloning here - let mut txn = conn.begin().await.wrap_error()?; - // TODO: handle "not found" cases - let authz_grant = lookup_grant_by_code(&mut txn, &grant.code) - .await - .wrap_error()?; + let authz_grant = lookup_grant_by_code(&mut txn, &grant.code).await?; // TODO: that's not a timestamp from the DB. Let's assume they are in sync let now = Utc::now(); @@ -204,7 +206,7 @@ async fn authorization_code_grant( let session = match authz_grant.stage { AuthorizationGrantStage::Cancelled { cancelled_at } => { debug!(%cancelled_at, "Authorization grant was cancelled"); - return error(InvalidGrant); + return Err(RouteError::InvalidGrant); } AuthorizationGrantStage::Exchanged { exchanged_at, @@ -216,15 +218,15 @@ async fn authorization_code_grant( // Ending the session if the token was already exchanged more than 20s ago if now - exchanged_at > Duration::seconds(20) { debug!("Ending potentially compromised session"); - end_oauth_session(&mut txn, session).await.wrap_error()?; - txn.commit().await.wrap_error()?; + end_oauth_session(&mut txn, session).await?; + txn.commit().await?; } - return error(InvalidGrant); + return Err(RouteError::InvalidGrant); } AuthorizationGrantStage::Pending => { debug!("Authorization grant has not been fulfilled yet"); - return error(InvalidGrant); + return Err(RouteError::InvalidGrant); } AuthorizationGrantStage::Fulfilled { ref session, @@ -232,7 +234,7 @@ async fn authorization_code_grant( } => { if now - fulfilled_at > Duration::minutes(10) { debug!("Code exchange took more than 10 minutes"); - return error(InvalidGrant); + return Err(RouteError::InvalidGrant); } session @@ -243,21 +245,20 @@ async fn authorization_code_grant( let code = authz_grant .code .as_ref() - .ok_or(DatabaseInconsistencyError) - .wrap_error()?; + .ok_or_else(|| anyhow::anyhow!(DatabaseInconsistencyError))?; if client.client_id != session.client.client_id { - return error(UnauthorizedClient); + return Err(RouteError::UnauthorizedClient); } match (code.pkce.as_ref(), grant.code_verifier.as_ref()) { (None, None) => {} // We have a challenge but no verifier (or vice-versa)? Bad request. - (Some(_), None) | (None, Some(_)) => return error(InvalidRequest), + (Some(_), None) | (None, Some(_)) => return Err(RouteError::BadRequest), // If we have both, we need to check the code validity (Some(pkce), Some(verifier)) => { if !pkce.verify(verifier) { - return error(InvalidRequest); + return Err(RouteError::BadRequest); } } }; @@ -273,58 +274,33 @@ async fn authorization_code_grant( ) }; - let access_token = add_access_token(&mut txn, session, &access_token_str, ttl) - .await - .wrap_error()?; + let access_token = add_access_token(&mut txn, session, &access_token_str, ttl).await?; - let _refresh_token = add_refresh_token(&mut txn, session, access_token, &refresh_token_str) - .await - .wrap_error()?; + let _refresh_token = + add_refresh_token(&mut txn, session, access_token, &refresh_token_str).await?; let id_token = if session.scope.contains(&scope::OPENID) { let mut claims = HashMap::new(); let now = Utc::now(); - claims::ISS - .insert(&mut claims, issuer.to_string()) - .wrap_error()?; - claims::SUB - .insert(&mut claims, &browser_session.user.sub) - .wrap_error()?; - claims::AUD - .insert(&mut claims, client.client_id.clone()) - .wrap_error()?; - claims::IAT.insert(&mut claims, now).wrap_error()?; - claims::EXP - .insert(&mut claims, now + Duration::hours(1)) - .wrap_error()?; + claims::ISS.insert(&mut claims, url_builder.oidc_issuer().to_string())?; + claims::SUB.insert(&mut claims, &browser_session.user.sub)?; + claims::AUD.insert(&mut claims, client.client_id.clone())?; + claims::IAT.insert(&mut claims, now)?; + claims::EXP.insert(&mut claims, now + Duration::hours(1))?; if let Some(ref nonce) = authz_grant.nonce { - claims::NONCE - .insert(&mut claims, nonce.clone()) - .wrap_error()?; + claims::NONCE.insert(&mut claims, nonce.clone())?; } if let Some(ref last_authentication) = browser_session.last_authentication { - claims::AUTH_TIME - .insert(&mut claims, last_authentication.created_at) - .wrap_error()?; + claims::AUTH_TIME.insert(&mut claims, last_authentication.created_at)?; } - claims::AT_HASH - .insert( - &mut claims, - hash(Sha256::new(), &access_token_str).wrap_error()?, - ) - .wrap_error()?; - claims::C_HASH - .insert(&mut claims, hash(Sha256::new(), &grant.code).wrap_error()?) - .wrap_error()?; + claims::AT_HASH.insert(&mut claims, hash(Sha256::new(), &access_token_str)?)?; + claims::C_HASH.insert(&mut claims, hash(Sha256::new(), &grant.code)?)?; - let header = key_store - .prepare_header(JsonWebSignatureAlg::Rs256) - .await - .wrap_error()?; + let header = key_store.prepare_header(JsonWebSignatureAlg::Rs256).await?; let id_token = DecodedJsonWebToken::new(header, claims); - let id_token = id_token.sign(key_store).await.wrap_error()?; + let id_token = id_token.sign(key_store).await?; Some(id_token.serialize()) } else { @@ -340,9 +316,9 @@ async fn authorization_code_grant( params = params.with_id_token(id_token); } - exchange_grant(&mut txn, authz_grant).await.wrap_error()?; + exchange_grant(&mut txn, authz_grant).await?; - txn.commit().await.wrap_error()?; + txn.commit().await?; Ok(params) } @@ -350,15 +326,14 @@ async fn authorization_code_grant( async fn refresh_token_grant( grant: &RefreshTokenGrant, client: &Client, - conn: &mut PoolConnection, -) -> Result { - let mut txn = conn.begin().await.wrap_error()?; + mut txn: Transaction<'_, Postgres>, +) -> Result { let (refresh_token, session) = lookup_active_refresh_token(&mut txn, &grant.refresh_token).await?; if client.client_id != session.client.client_id { // As per https://datatracker.ietf.org/doc/html/rfc6749#section-5.2 - return error(InvalidGrant); + return Err(RouteError::InvalidGrant); } let ttl = Duration::minutes(5); @@ -370,23 +345,15 @@ async fn refresh_token_grant( ) }; - let new_access_token = add_access_token(&mut txn, &session, &access_token_str, ttl) - .await - .wrap_error()?; + let new_access_token = add_access_token(&mut txn, &session, &access_token_str, ttl).await?; let new_refresh_token = - add_refresh_token(&mut txn, &session, new_access_token, &refresh_token_str) - .await - .wrap_error()?; + add_refresh_token(&mut txn, &session, new_access_token, &refresh_token_str).await?; - replace_refresh_token(&mut txn, &refresh_token, &new_refresh_token) - .await - .wrap_error()?; + replace_refresh_token(&mut txn, &refresh_token, &new_refresh_token).await?; if let Some(access_token) = refresh_token.access_token { - revoke_access_token(&mut txn, &access_token) - .await - .wrap_error()?; + revoke_access_token(&mut txn, &access_token).await?; } let params = AccessTokenResponse::new(access_token_str) @@ -394,7 +361,7 @@ async fn refresh_token_grant( .with_refresh_token(refresh_token_str) .with_scope(session.scope); - txn.commit().await.wrap_error()?; + txn.commit().await?; Ok(params) } diff --git a/crates/oauth2-types/src/errors.rs b/crates/oauth2-types/src/errors.rs index 4ba4692e0..1d213fbbf 100644 --- a/crates/oauth2-types/src/errors.rs +++ b/crates/oauth2-types/src/errors.rs @@ -16,6 +16,11 @@ use http::status::StatusCode; use serde::ser::{Serialize, SerializeMap}; use url::Url; +pub struct ClientError { + pub error: &'static str, + pub error_description: &'static str, +} + pub trait OAuth2Error: std::fmt::Debug + Send + Sync { /// A single ASCII error code. /// @@ -148,6 +153,15 @@ macro_rules! oauth2_error_error { }; } +macro_rules! oauth2_error_const { + ($const:ident, $err:literal, $description:expr) => { + pub const $const: ClientError = ClientError { + error: $err, + error_description: $description, + }; + }; +} + macro_rules! oauth2_error_description { ($description:expr) => { fn description(&self) -> Option { @@ -157,32 +171,36 @@ macro_rules! oauth2_error_description { } macro_rules! oauth2_error { - ($name:ident, $err:literal => $description:expr) => { + ($name:ident, $const:ident, $err:literal => $description:expr) => { + oauth2_error_const!($const, $err, $description); oauth2_error_def!($name); impl $crate::errors::OAuth2Error for $name { oauth2_error_error!($err); oauth2_error_description!(indoc::indoc! {$description}); } }; - ($name:ident, $err:literal) => { + ($name:ident, $const:ident, $err:literal) => { oauth2_error_def!($name); impl $crate::errors::OAuth2Error for $name { oauth2_error_error!($err); } }; - ($name:ident, code: $code:ident, $err:literal => $description:expr) => { - oauth2_error!($name, $err => $description); + ($name:ident, $const:ident, code: $code:ident, $err:literal => $description:expr) => { + oauth2_error!($name, $const, $err => $description); oauth2_error_status!($name, $code); }; - ($name:ident, code: $code:ident, $err:literal) => { - oauth2_error!($name, $err); + ($name:ident, $const:ident, code: $code:ident, $err:literal) => { + oauth2_error!($name, $const, $err); oauth2_error_status!($name, $code); }; } pub mod rfc6749 { + use super::ClientError; + oauth2_error! { InvalidRequest, + INVALID_REQUEST, code: BAD_REQUEST, "invalid_request" => "The request is missing a required parameter, includes an invalid parameter value, \ @@ -191,6 +209,7 @@ pub mod rfc6749 { oauth2_error! { InvalidClient, + INVALID_CLIENT, code: BAD_REQUEST, "invalid_client" => "Client authentication failed." @@ -198,12 +217,14 @@ pub mod rfc6749 { oauth2_error! { InvalidGrant, + INVALID_GRANT, code: BAD_REQUEST, "invalid_grant" } oauth2_error! { UnauthorizedClient, + UNAUTHORIZED_CLIENT, code: BAD_REQUEST, "unauthorized_client" => "The client is not authorized to request an access token using this method." @@ -211,6 +232,7 @@ pub mod rfc6749 { oauth2_error! { UnsupportedGrantType, + UNSUPPORTED_GRANT_TYPE, code: BAD_REQUEST, "unsupported_grant_type" => "The authorization grant type is not supported by the authorization server." @@ -218,18 +240,21 @@ pub mod rfc6749 { oauth2_error! { AccessDenied, + ACCESS_DENIED, "access_denied" => "The resource owner or authorization server denied the request." } oauth2_error! { UnsupportedResponseType, + UNSUPPORTED_RESPONSE_TYPE, "unsupported_response_type" => "The authorization server does not support obtaining an access token using this method." } oauth2_error! { InvalidScope, + INVALID_SCOPE, code: BAD_REQUEST, "invalid_scope" => "The requested scope is invalid, unknown, or malformed." @@ -237,6 +262,7 @@ pub mod rfc6749 { oauth2_error! { ServerError, + SERVER_ERROR, code: INTERNAL_SERVER_ERROR, "server_error" => "The authorization server encountered an unexpected \ @@ -245,6 +271,7 @@ pub mod rfc6749 { oauth2_error! { TemporarilyUnavailable, + TEMPORARILY_UNAVAILABLE, "temporarily_unavailable" => "The authorization server is currently unable to handle \ the request due to a temporary overloading or maintenance \ @@ -253,54 +280,65 @@ pub mod rfc6749 { } pub mod oidc_core { + use super::ClientError; + oauth2_error! { InteractionRequired, + INTERACTION_REQUIRED, "interaction_required" => "The Authorization Server requires End-User interaction of some form to proceed." } oauth2_error! { LoginRequired, + LOGIN_REQUIRED, "login_required" => "The Authorization Server requires End-User authentication." } oauth2_error! { AccountSelectionRequired, + ACCOUNT_SELECTION_REQUIRED, "account_selection_required" } oauth2_error! { ConsentRequired, + CONSENT_REQUIRED, "consent_required" } oauth2_error! { InvalidRequestUri, + INVALID_REQUEST_URI, "invalid_request_uri" => "The request_uri in the Authorization Request returns an error or contains invalid data. " } oauth2_error! { InvalidRequestObject, + INVALID_REQUEST_OBJECT, "invalid_request_object" => "The request parameter contains an invalid Request Object." } oauth2_error! { RequestNotSupported, + REQUEST_NOT_SUPPORTED, "request_not_supported" => "The provider does not support use of the request parameter." } oauth2_error! { RequestUriNotSupported, + REQUEST_URI_NOT_SUPPORTED, "request_uri_not_supported" => "The provider does not support use of the request_uri parameter." } oauth2_error! { RegistrationNotSupported, + REGISTRATION_NOT_SUPPORTED, "registration_not_supported" => "The provider does not support use of the registration parameter." }