WIP error management in authorization request

This commit is contained in:
Quentin Gliech
2021-09-10 22:53:21 +02:00
parent 108a974880
commit c39e223032
7 changed files with 194 additions and 51 deletions

View File

@@ -4,3 +4,5 @@ members = [
"oauth2-types",
"matrix-authentication-service",
]
resolver = "2"

View File

@@ -289,7 +289,7 @@ pub struct OAuth2ClientConfig {
pub client_secret: Option<String>,
#[serde(default)]
pub redirect_uris: Option<Vec<Url>>,
pub redirect_uris: Vec<Url>,
}
#[derive(Debug, Error)]
@@ -301,19 +301,20 @@ impl OAuth2ClientConfig {
&'a self,
suggested_uri: &'a Option<Url>,
) -> Result<&'a Url, InvalidRedirectUriError> {
match (suggested_uri, &self.redirect_uris) {
(None, None) => Err(InvalidRedirectUriError),
(None, Some(redirect_uris)) => {
redirect_uris.iter().next().ok_or(InvalidRedirectUriError)
}
(Some(suggested_uri), None) => Ok(suggested_uri),
(Some(suggested_uri), Some(redirect_uris)) => {
if redirect_uris.contains(suggested_uri) {
Ok(suggested_uri)
} else {
Err(InvalidRedirectUriError)
}
}
suggested_uri.as_ref().map_or_else(
|| self.redirect_uris.get(0).ok_or(InvalidRedirectUriError),
|suggested_uri| self.check_redirect_uri(suggested_uri),
)
}
pub fn check_redirect_uri<'a>(
&self,
redirect_uri: &'a Url,
) -> Result<&'a Url, InvalidRedirectUriError> {
if self.redirect_uris.contains(redirect_uri) {
Ok(redirect_uri)
} else {
Err(InvalidRedirectUriError)
}
}
}
@@ -464,11 +465,11 @@ mod tests {
assert_eq!(config.clients[0].client_id, "hello");
assert_eq!(
config.clients[0].redirect_uris,
Some(vec!["https://exemple.fr/callback".parse().unwrap()])
vec!["https://exemple.fr/callback".parse().unwrap()]
);
assert_eq!(config.clients[1].client_id, "world");
assert_eq!(config.clients[1].redirect_uris, None);
assert_eq!(config.clients[1].redirect_uris, Vec::new());
Ok(())
});

View File

@@ -18,7 +18,6 @@ use std::{
};
use chrono::Duration;
use data_encoding::BASE64URL_NOPAD;
use hyper::{
header::LOCATION,
http::uri::{Parts, PathAndQuery, Uri},
@@ -26,18 +25,21 @@ use hyper::{
};
use itertools::Itertools;
use oauth2_types::{
errors::{ErrorResponse, InvalidRequest, OAuth2Error},
pkce,
requests::{
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, ResponseMode,
ResponseType,
},
};
use rand::thread_rng;
use rand::{distributions::Alphanumeric, thread_rng, Rng};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use sqlx::{PgPool, Postgres, Transaction};
use url::Url;
use warp::{
redirect::see_other,
reject::InvalidQuery,
reply::{html, with_header},
Filter, Rejection, Reply,
};
@@ -63,6 +65,26 @@ use crate::{
tokens,
};
#[derive(Deserialize)]
struct PartialParams {
client_id: Option<String>,
redirect_uri: Option<String>,
/*
response_type: Option<String>,
response_mode: Option<String>,
*/
}
enum ReplyOrBackToClient {
Reply(Box<dyn Reply>),
BackToClient {
params: Value,
redirect_uri: Url,
response_mode: ResponseMode,
},
Error(Box<dyn OAuth2Error>),
}
fn back_to_client<T>(
mut redirect_uri: Url,
response_mode: ResponseMode,
@@ -173,7 +195,6 @@ pub fn filter(
.and(warp::query())
.and(with_optional_session(pool, cookies_config))
.and(with_transaction(pool))
.and(with_templates(templates))
.and_then(get);
let step = warp::path!("oauth2" / "authorize" / "step")
@@ -181,10 +202,81 @@ pub fn filter(
.and(warp::query().map(|s: StepRequest| s.id))
.and(with_session(pool, cookies_config))
.and(with_transaction(pool))
.and(with_templates(templates))
.and_then(step);
authorize.or(step)
let clients = oauth2_config.clients.clone();
authorize
.or(step)
.unify()
.recover(recover)
.unify()
.and(warp::query())
.and(warp::any().map(move || clients.clone()))
.and(with_templates(templates))
.and_then(actually_reply)
}
async fn recover(rejection: Rejection) -> Result<ReplyOrBackToClient, Rejection> {
if rejection.find::<InvalidQuery>().is_some() {
Ok(ReplyOrBackToClient::Error(Box::new(InvalidRequest)))
} else {
Err(rejection)
}
}
async fn actually_reply(
rep: ReplyOrBackToClient,
q: PartialParams,
clients: Vec<OAuth2ClientConfig>,
templates: Templates,
) -> Result<impl Reply, Rejection> {
let (redirect_uri, response_mode, params) = match rep {
ReplyOrBackToClient::Reply(r) => return Ok(r),
ReplyOrBackToClient::BackToClient {
redirect_uri,
response_mode,
params,
} => (redirect_uri, response_mode, params),
ReplyOrBackToClient::Error(error) => {
let PartialParams {
client_id,
redirect_uri,
..
} = q;
// First, disover the client
let client = client_id.and_then(|client_id| {
clients
.into_iter()
.find(|client| client.client_id == client_id)
});
let client = match client {
Some(client) => client,
None => return Ok(Box::new(html(templates.render_error(&error.into())?))),
};
let redirect_uri: Result<Option<Url>, _> = redirect_uri.map(|r| r.parse()).transpose();
let redirect_uri = match redirect_uri {
Ok(r) => r,
Err(_) => return Ok(Box::new(html(templates.render_error(&error.into())?))),
};
let redirect_uri = client.resolve_redirect_uri(&redirect_uri);
let redirect_uri = match redirect_uri {
Ok(r) => r,
Err(_) => return Ok(Box::new(html(templates.render_error(&error.into())?))),
};
let reply: ErrorResponse = error.into();
let reply = serde_json::to_value(&reply).wrap_error()?;
// TODO: resolve response mode
(redirect_uri.clone(), ResponseMode::Query, reply)
}
};
// TODO: we should include the state param in errors
back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()
}
async fn get(
@@ -192,8 +284,7 @@ async fn get(
params: Params,
maybe_session: Option<SessionInfo>,
mut txn: Transaction<'_, Postgres>,
templates: Templates,
) -> Result<Box<dyn Reply>, Rejection> {
) -> Result<ReplyOrBackToClient, Rejection> {
// First, find out what client it is
let client = clients
.into_iter()
@@ -201,11 +292,6 @@ async fn get(
.ok_or_else(|| anyhow::anyhow!("could not find client"))
.wrap_error()?;
// Then, figure out the redirect URI
let redirect_uri = client
.resolve_redirect_uri(&params.auth.redirect_uri)
.wrap_error()?;
let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key);
let scope: String = {
@@ -213,6 +299,9 @@ async fn get(
Itertools::intersperse(it, " ".to_string()).collect()
};
let redirect_uri = client
.resolve_redirect_uri(&params.auth.redirect_uri)
.wrap_error()?;
let response_type = &params.auth.response_type;
let response_mode =
resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?;
@@ -234,9 +323,13 @@ async fn get(
// Generate the code at this stage, since we have the PKCE params ready
if response_type.contains(&ResponseType::Code) {
// 192bit random bytes encoded in base64, which gives a 32 character code
let code: [u8; 24] = rand::random();
let code = BASE64URL_NOPAD.encode(&code);
// 32 random alphanumeric characters, about 190bit of entropy
let code: String = thread_rng()
.sample_iter(&Alphanumeric)
.take(32)
.map(char::from)
.collect();
oauth2_session
.add_code(&mut txn, &code, &params.pkce)
.await
@@ -247,7 +340,7 @@ async fn get(
let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?;
if let Some(user_session) = user_session {
step(oauth2_session.id, user_session, txn, templates).await
step(oauth2_session.id, user_session, txn).await
} else {
// If not, redirect the user to the login page
txn.commit().await.wrap_error()?;
@@ -258,7 +351,7 @@ async fn get(
.to_string();
let destination = LoginRequest::new(Some(next)).build_uri().wrap_error()?;
Ok(Box::new(see_other(destination)))
Ok(ReplyOrBackToClient::Reply(Box::new(see_other(destination))))
}
}
@@ -288,8 +381,7 @@ async fn step(
oauth2_session_id: i64,
user_session: SessionInfo,
mut txn: Transaction<'_, Postgres>,
templates: Templates,
) -> Result<Box<dyn Reply>, Rejection> {
) -> Result<ReplyOrBackToClient, Rejection> {
let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id)
.await
.wrap_error()?;
@@ -350,11 +442,16 @@ async fn step(
todo!("id tokens are not implemented yet");
}
back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()?
let params = serde_json::to_value(&params).unwrap();
ReplyOrBackToClient::BackToClient {
redirect_uri,
response_mode,
params,
}
} else {
// Ask for a reauth
// TODO: have the OAuth2 session ID in there
Box::new(see_other(Uri::from_static("/reauth")))
ReplyOrBackToClient::Reply(Box::new(see_other(Uri::from_static("/reauth"))))
};
txn.commit().await.wrap_error()?;

View File

@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{collections::HashSet, convert::TryFrom, str::FromStr};
use std::{collections::HashSet, convert::TryFrom, str::FromStr, string::ToString};
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
@@ -34,7 +34,7 @@ use super::{
pub struct OAuth2Session {
pub id: i64,
user_session_id: Option<i64>,
client_id: String,
pub client_id: String,
redirect_uri: String,
scope: String,
pub state: Option<String>,
@@ -143,6 +143,7 @@ pub async fn start_session(
// Checked convertion of duration to i32, maxing at i32::MAX
let max_age = max_age.map(|d| i32::try_from(d.num_seconds()).unwrap_or(i32::MAX));
let response_mode = response_mode.to_string();
let redirect_uri = redirect_uri.to_string();
let response_type: String = {
let it = response_type.iter().map(ToString::to_string);
Itertools::intersperse(it, " ".to_string()).collect()
@@ -162,7 +163,7 @@ pub async fn start_session(
"#,
optional_session_id,
client_id,
redirect_uri.as_str(),
redirect_uri,
scope,
state,
nonce,

View File

@@ -14,6 +14,7 @@
use std::{collections::HashSet, string::ToString, sync::Arc};
use oauth2_types::errors::OAuth2Error;
use serde::Serialize;
use tera::{Context, Error as TeraError, Tera};
use thiserror::Error;
@@ -146,6 +147,9 @@ register_templates! {
/// Render the form used by the form_post response mode
pub fn render_form_post<T: Serialize>(FormPostContext<T>) { "form_post.html" }
/// Render the HTML error page
pub fn render_error(ErrorContext) { "error.html" }
}
/// Helper trait to construct context wrappers
@@ -254,3 +258,42 @@ impl<T> FormPostContext<T> {
}
}
}
#[derive(Default, Serialize)]
pub struct ErrorContext {
code: Option<&'static str>,
description: Option<String>,
details: Option<String>,
}
impl ErrorContext {
pub fn new() -> Self {
Self::default()
}
pub fn with_code(mut self, code: &'static str) -> Self {
self.code = Some(code);
self
}
pub fn with_description(mut self, description: String) -> Self {
self.description = Some(description);
self
}
#[allow(dead_code)]
pub fn with_details(mut self, details: String) -> Self {
self.details = Some(details);
self
}
}
impl From<Box<dyn OAuth2Error>> for ErrorContext {
fn from(err: Box<dyn OAuth2Error>) -> Self {
let mut ctx = ErrorContext::new().with_code(err.error());
if let Some(desc) = err.description() {
ctx = ctx.with_description(desc);
}
ctx
}
}

View File

@@ -91,7 +91,6 @@ pub enum TokenFormatError {
InvalidCrc { expected: String, got: String },
}
#[allow(dead_code)]
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
let split: Vec<&str> = token.split('_').collect();
let [prefix, random_part, crc]: [&str; 3] = split

View File

@@ -16,7 +16,7 @@ use http::status::StatusCode;
use serde::ser::{Serialize, SerializeMap};
use url::Url;
pub trait OAuth2Error: std::fmt::Debug {
pub trait OAuth2Error: std::fmt::Debug + Send + Sync {
/// A single ASCII error code.
///
/// Maps to the required "error" field.
@@ -41,11 +41,11 @@ pub trait OAuth2Error: std::fmt::Debug {
}
/// Wraps the error with an `ErrorResponse` to help serializing.
fn into_response(self) -> ErrorResponse<Self>
fn into_response(self) -> ErrorResponse
where
Self: Sized,
Self: Sized + 'static,
{
ErrorResponse(self)
ErrorResponse(Box::new(self))
}
}
@@ -55,15 +55,15 @@ trait OAuth2ErrorCode: OAuth2Error {
}
#[derive(Debug)]
pub struct ErrorResponse<T: OAuth2Error>(T);
pub struct ErrorResponse(Box<dyn OAuth2Error>);
impl<T: OAuth2ErrorCode> OAuth2ErrorCode for ErrorResponse<T> {
fn status(&self) -> StatusCode {
self.0.status()
impl From<Box<dyn OAuth2Error>> for ErrorResponse {
fn from(b: Box<dyn OAuth2Error>) -> Self {
Self(b)
}
}
impl<T: OAuth2Error> OAuth2Error for ErrorResponse<T> {
impl OAuth2Error for ErrorResponse {
fn error(&self) -> &'static str {
self.0.error()
}
@@ -77,7 +77,7 @@ impl<T: OAuth2Error> OAuth2Error for ErrorResponse<T> {
}
}
impl<T: OAuth2Error> Serialize for ErrorResponse<T> {
impl Serialize for ErrorResponse {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,