WIP error management in authorization request
This commit is contained in:
@@ -4,3 +4,5 @@ members = [
|
||||
"oauth2-types",
|
||||
"matrix-authentication-service",
|
||||
]
|
||||
|
||||
resolver = "2"
|
||||
|
||||
@@ -289,7 +289,7 @@ pub struct OAuth2ClientConfig {
|
||||
pub client_secret: Option<String>,
|
||||
|
||||
#[serde(default)]
|
||||
pub redirect_uris: Option<Vec<Url>>,
|
||||
pub redirect_uris: Vec<Url>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
@@ -301,19 +301,20 @@ impl OAuth2ClientConfig {
|
||||
&'a self,
|
||||
suggested_uri: &'a Option<Url>,
|
||||
) -> Result<&'a Url, InvalidRedirectUriError> {
|
||||
match (suggested_uri, &self.redirect_uris) {
|
||||
(None, None) => Err(InvalidRedirectUriError),
|
||||
(None, Some(redirect_uris)) => {
|
||||
redirect_uris.iter().next().ok_or(InvalidRedirectUriError)
|
||||
}
|
||||
(Some(suggested_uri), None) => Ok(suggested_uri),
|
||||
(Some(suggested_uri), Some(redirect_uris)) => {
|
||||
if redirect_uris.contains(suggested_uri) {
|
||||
Ok(suggested_uri)
|
||||
} else {
|
||||
Err(InvalidRedirectUriError)
|
||||
}
|
||||
}
|
||||
suggested_uri.as_ref().map_or_else(
|
||||
|| self.redirect_uris.get(0).ok_or(InvalidRedirectUriError),
|
||||
|suggested_uri| self.check_redirect_uri(suggested_uri),
|
||||
)
|
||||
}
|
||||
|
||||
pub fn check_redirect_uri<'a>(
|
||||
&self,
|
||||
redirect_uri: &'a Url,
|
||||
) -> Result<&'a Url, InvalidRedirectUriError> {
|
||||
if self.redirect_uris.contains(redirect_uri) {
|
||||
Ok(redirect_uri)
|
||||
} else {
|
||||
Err(InvalidRedirectUriError)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -464,11 +465,11 @@ mod tests {
|
||||
assert_eq!(config.clients[0].client_id, "hello");
|
||||
assert_eq!(
|
||||
config.clients[0].redirect_uris,
|
||||
Some(vec!["https://exemple.fr/callback".parse().unwrap()])
|
||||
vec!["https://exemple.fr/callback".parse().unwrap()]
|
||||
);
|
||||
|
||||
assert_eq!(config.clients[1].client_id, "world");
|
||||
assert_eq!(config.clients[1].redirect_uris, None);
|
||||
assert_eq!(config.clients[1].redirect_uris, Vec::new());
|
||||
|
||||
Ok(())
|
||||
});
|
||||
|
||||
@@ -18,7 +18,6 @@ use std::{
|
||||
};
|
||||
|
||||
use chrono::Duration;
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use hyper::{
|
||||
header::LOCATION,
|
||||
http::uri::{Parts, PathAndQuery, Uri},
|
||||
@@ -26,18 +25,21 @@ use hyper::{
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use oauth2_types::{
|
||||
errors::{ErrorResponse, InvalidRequest, OAuth2Error},
|
||||
pkce,
|
||||
requests::{
|
||||
AccessTokenResponse, AuthorizationRequest, AuthorizationResponse, ResponseMode,
|
||||
ResponseType,
|
||||
},
|
||||
};
|
||||
use rand::thread_rng;
|
||||
use rand::{distributions::Alphanumeric, thread_rng, Rng};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::Value;
|
||||
use sqlx::{PgPool, Postgres, Transaction};
|
||||
use url::Url;
|
||||
use warp::{
|
||||
redirect::see_other,
|
||||
reject::InvalidQuery,
|
||||
reply::{html, with_header},
|
||||
Filter, Rejection, Reply,
|
||||
};
|
||||
@@ -63,6 +65,26 @@ use crate::{
|
||||
tokens,
|
||||
};
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct PartialParams {
|
||||
client_id: Option<String>,
|
||||
redirect_uri: Option<String>,
|
||||
/*
|
||||
response_type: Option<String>,
|
||||
response_mode: Option<String>,
|
||||
*/
|
||||
}
|
||||
|
||||
enum ReplyOrBackToClient {
|
||||
Reply(Box<dyn Reply>),
|
||||
BackToClient {
|
||||
params: Value,
|
||||
redirect_uri: Url,
|
||||
response_mode: ResponseMode,
|
||||
},
|
||||
Error(Box<dyn OAuth2Error>),
|
||||
}
|
||||
|
||||
fn back_to_client<T>(
|
||||
mut redirect_uri: Url,
|
||||
response_mode: ResponseMode,
|
||||
@@ -173,7 +195,6 @@ pub fn filter(
|
||||
.and(warp::query())
|
||||
.and(with_optional_session(pool, cookies_config))
|
||||
.and(with_transaction(pool))
|
||||
.and(with_templates(templates))
|
||||
.and_then(get);
|
||||
|
||||
let step = warp::path!("oauth2" / "authorize" / "step")
|
||||
@@ -181,10 +202,81 @@ pub fn filter(
|
||||
.and(warp::query().map(|s: StepRequest| s.id))
|
||||
.and(with_session(pool, cookies_config))
|
||||
.and(with_transaction(pool))
|
||||
.and(with_templates(templates))
|
||||
.and_then(step);
|
||||
|
||||
authorize.or(step)
|
||||
let clients = oauth2_config.clients.clone();
|
||||
authorize
|
||||
.or(step)
|
||||
.unify()
|
||||
.recover(recover)
|
||||
.unify()
|
||||
.and(warp::query())
|
||||
.and(warp::any().map(move || clients.clone()))
|
||||
.and(with_templates(templates))
|
||||
.and_then(actually_reply)
|
||||
}
|
||||
|
||||
async fn recover(rejection: Rejection) -> Result<ReplyOrBackToClient, Rejection> {
|
||||
if rejection.find::<InvalidQuery>().is_some() {
|
||||
Ok(ReplyOrBackToClient::Error(Box::new(InvalidRequest)))
|
||||
} else {
|
||||
Err(rejection)
|
||||
}
|
||||
}
|
||||
|
||||
async fn actually_reply(
|
||||
rep: ReplyOrBackToClient,
|
||||
q: PartialParams,
|
||||
clients: Vec<OAuth2ClientConfig>,
|
||||
templates: Templates,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let (redirect_uri, response_mode, params) = match rep {
|
||||
ReplyOrBackToClient::Reply(r) => return Ok(r),
|
||||
ReplyOrBackToClient::BackToClient {
|
||||
redirect_uri,
|
||||
response_mode,
|
||||
params,
|
||||
} => (redirect_uri, response_mode, params),
|
||||
ReplyOrBackToClient::Error(error) => {
|
||||
let PartialParams {
|
||||
client_id,
|
||||
redirect_uri,
|
||||
..
|
||||
} = q;
|
||||
|
||||
// First, disover the client
|
||||
let client = client_id.and_then(|client_id| {
|
||||
clients
|
||||
.into_iter()
|
||||
.find(|client| client.client_id == client_id)
|
||||
});
|
||||
|
||||
let client = match client {
|
||||
Some(client) => client,
|
||||
None => return Ok(Box::new(html(templates.render_error(&error.into())?))),
|
||||
};
|
||||
|
||||
let redirect_uri: Result<Option<Url>, _> = redirect_uri.map(|r| r.parse()).transpose();
|
||||
let redirect_uri = match redirect_uri {
|
||||
Ok(r) => r,
|
||||
Err(_) => return Ok(Box::new(html(templates.render_error(&error.into())?))),
|
||||
};
|
||||
|
||||
let redirect_uri = client.resolve_redirect_uri(&redirect_uri);
|
||||
let redirect_uri = match redirect_uri {
|
||||
Ok(r) => r,
|
||||
Err(_) => return Ok(Box::new(html(templates.render_error(&error.into())?))),
|
||||
};
|
||||
|
||||
let reply: ErrorResponse = error.into();
|
||||
let reply = serde_json::to_value(&reply).wrap_error()?;
|
||||
// TODO: resolve response mode
|
||||
(redirect_uri.clone(), ResponseMode::Query, reply)
|
||||
}
|
||||
};
|
||||
|
||||
// TODO: we should include the state param in errors
|
||||
back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()
|
||||
}
|
||||
|
||||
async fn get(
|
||||
@@ -192,8 +284,7 @@ async fn get(
|
||||
params: Params,
|
||||
maybe_session: Option<SessionInfo>,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
templates: Templates,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
) -> Result<ReplyOrBackToClient, Rejection> {
|
||||
// First, find out what client it is
|
||||
let client = clients
|
||||
.into_iter()
|
||||
@@ -201,11 +292,6 @@ async fn get(
|
||||
.ok_or_else(|| anyhow::anyhow!("could not find client"))
|
||||
.wrap_error()?;
|
||||
|
||||
// Then, figure out the redirect URI
|
||||
let redirect_uri = client
|
||||
.resolve_redirect_uri(¶ms.auth.redirect_uri)
|
||||
.wrap_error()?;
|
||||
|
||||
let maybe_session_id = maybe_session.as_ref().map(SessionInfo::key);
|
||||
|
||||
let scope: String = {
|
||||
@@ -213,6 +299,9 @@ async fn get(
|
||||
Itertools::intersperse(it, " ".to_string()).collect()
|
||||
};
|
||||
|
||||
let redirect_uri = client
|
||||
.resolve_redirect_uri(¶ms.auth.redirect_uri)
|
||||
.wrap_error()?;
|
||||
let response_type = ¶ms.auth.response_type;
|
||||
let response_mode =
|
||||
resolve_response_mode(response_type, params.auth.response_mode).wrap_error()?;
|
||||
@@ -234,9 +323,13 @@ async fn get(
|
||||
|
||||
// Generate the code at this stage, since we have the PKCE params ready
|
||||
if response_type.contains(&ResponseType::Code) {
|
||||
// 192bit random bytes encoded in base64, which gives a 32 character code
|
||||
let code: [u8; 24] = rand::random();
|
||||
let code = BASE64URL_NOPAD.encode(&code);
|
||||
// 32 random alphanumeric characters, about 190bit of entropy
|
||||
let code: String = thread_rng()
|
||||
.sample_iter(&Alphanumeric)
|
||||
.take(32)
|
||||
.map(char::from)
|
||||
.collect();
|
||||
|
||||
oauth2_session
|
||||
.add_code(&mut txn, &code, ¶ms.pkce)
|
||||
.await
|
||||
@@ -247,7 +340,7 @@ async fn get(
|
||||
let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?;
|
||||
|
||||
if let Some(user_session) = user_session {
|
||||
step(oauth2_session.id, user_session, txn, templates).await
|
||||
step(oauth2_session.id, user_session, txn).await
|
||||
} else {
|
||||
// If not, redirect the user to the login page
|
||||
txn.commit().await.wrap_error()?;
|
||||
@@ -258,7 +351,7 @@ async fn get(
|
||||
.to_string();
|
||||
|
||||
let destination = LoginRequest::new(Some(next)).build_uri().wrap_error()?;
|
||||
Ok(Box::new(see_other(destination)))
|
||||
Ok(ReplyOrBackToClient::Reply(Box::new(see_other(destination))))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -288,8 +381,7 @@ async fn step(
|
||||
oauth2_session_id: i64,
|
||||
user_session: SessionInfo,
|
||||
mut txn: Transaction<'_, Postgres>,
|
||||
templates: Templates,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
) -> Result<ReplyOrBackToClient, Rejection> {
|
||||
let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
@@ -350,11 +442,16 @@ async fn step(
|
||||
todo!("id tokens are not implemented yet");
|
||||
}
|
||||
|
||||
back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()?
|
||||
let params = serde_json::to_value(¶ms).unwrap();
|
||||
ReplyOrBackToClient::BackToClient {
|
||||
redirect_uri,
|
||||
response_mode,
|
||||
params,
|
||||
}
|
||||
} else {
|
||||
// Ask for a reauth
|
||||
// TODO: have the OAuth2 session ID in there
|
||||
Box::new(see_other(Uri::from_static("/reauth")))
|
||||
ReplyOrBackToClient::Reply(Box::new(see_other(Uri::from_static("/reauth"))))
|
||||
};
|
||||
|
||||
txn.commit().await.wrap_error()?;
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashSet, convert::TryFrom, str::FromStr};
|
||||
use std::{collections::HashSet, convert::TryFrom, str::FromStr, string::ToString};
|
||||
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
@@ -34,7 +34,7 @@ use super::{
|
||||
pub struct OAuth2Session {
|
||||
pub id: i64,
|
||||
user_session_id: Option<i64>,
|
||||
client_id: String,
|
||||
pub client_id: String,
|
||||
redirect_uri: String,
|
||||
scope: String,
|
||||
pub state: Option<String>,
|
||||
@@ -143,6 +143,7 @@ pub async fn start_session(
|
||||
// Checked convertion of duration to i32, maxing at i32::MAX
|
||||
let max_age = max_age.map(|d| i32::try_from(d.num_seconds()).unwrap_or(i32::MAX));
|
||||
let response_mode = response_mode.to_string();
|
||||
let redirect_uri = redirect_uri.to_string();
|
||||
let response_type: String = {
|
||||
let it = response_type.iter().map(ToString::to_string);
|
||||
Itertools::intersperse(it, " ".to_string()).collect()
|
||||
@@ -162,7 +163,7 @@ pub async fn start_session(
|
||||
"#,
|
||||
optional_session_id,
|
||||
client_id,
|
||||
redirect_uri.as_str(),
|
||||
redirect_uri,
|
||||
scope,
|
||||
state,
|
||||
nonce,
|
||||
|
||||
@@ -14,6 +14,7 @@
|
||||
|
||||
use std::{collections::HashSet, string::ToString, sync::Arc};
|
||||
|
||||
use oauth2_types::errors::OAuth2Error;
|
||||
use serde::Serialize;
|
||||
use tera::{Context, Error as TeraError, Tera};
|
||||
use thiserror::Error;
|
||||
@@ -146,6 +147,9 @@ register_templates! {
|
||||
|
||||
/// Render the form used by the form_post response mode
|
||||
pub fn render_form_post<T: Serialize>(FormPostContext<T>) { "form_post.html" }
|
||||
|
||||
/// Render the HTML error page
|
||||
pub fn render_error(ErrorContext) { "error.html" }
|
||||
}
|
||||
|
||||
/// Helper trait to construct context wrappers
|
||||
@@ -254,3 +258,42 @@ impl<T> FormPostContext<T> {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default, Serialize)]
|
||||
pub struct ErrorContext {
|
||||
code: Option<&'static str>,
|
||||
description: Option<String>,
|
||||
details: Option<String>,
|
||||
}
|
||||
|
||||
impl ErrorContext {
|
||||
pub fn new() -> Self {
|
||||
Self::default()
|
||||
}
|
||||
|
||||
pub fn with_code(mut self, code: &'static str) -> Self {
|
||||
self.code = Some(code);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_description(mut self, description: String) -> Self {
|
||||
self.description = Some(description);
|
||||
self
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn with_details(mut self, details: String) -> Self {
|
||||
self.details = Some(details);
|
||||
self
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Box<dyn OAuth2Error>> for ErrorContext {
|
||||
fn from(err: Box<dyn OAuth2Error>) -> Self {
|
||||
let mut ctx = ErrorContext::new().with_code(err.error());
|
||||
if let Some(desc) = err.description() {
|
||||
ctx = ctx.with_description(desc);
|
||||
}
|
||||
ctx
|
||||
}
|
||||
}
|
||||
|
||||
@@ -91,7 +91,6 @@ pub enum TokenFormatError {
|
||||
InvalidCrc { expected: String, got: String },
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn check(token: &str) -> Result<TokenType, TokenFormatError> {
|
||||
let split: Vec<&str> = token.split('_').collect();
|
||||
let [prefix, random_part, crc]: [&str; 3] = split
|
||||
|
||||
@@ -16,7 +16,7 @@ use http::status::StatusCode;
|
||||
use serde::ser::{Serialize, SerializeMap};
|
||||
use url::Url;
|
||||
|
||||
pub trait OAuth2Error: std::fmt::Debug {
|
||||
pub trait OAuth2Error: std::fmt::Debug + Send + Sync {
|
||||
/// A single ASCII error code.
|
||||
///
|
||||
/// Maps to the required "error" field.
|
||||
@@ -41,11 +41,11 @@ pub trait OAuth2Error: std::fmt::Debug {
|
||||
}
|
||||
|
||||
/// Wraps the error with an `ErrorResponse` to help serializing.
|
||||
fn into_response(self) -> ErrorResponse<Self>
|
||||
fn into_response(self) -> ErrorResponse
|
||||
where
|
||||
Self: Sized,
|
||||
Self: Sized + 'static,
|
||||
{
|
||||
ErrorResponse(self)
|
||||
ErrorResponse(Box::new(self))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -55,15 +55,15 @@ trait OAuth2ErrorCode: OAuth2Error {
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
pub struct ErrorResponse<T: OAuth2Error>(T);
|
||||
pub struct ErrorResponse(Box<dyn OAuth2Error>);
|
||||
|
||||
impl<T: OAuth2ErrorCode> OAuth2ErrorCode for ErrorResponse<T> {
|
||||
fn status(&self) -> StatusCode {
|
||||
self.0.status()
|
||||
impl From<Box<dyn OAuth2Error>> for ErrorResponse {
|
||||
fn from(b: Box<dyn OAuth2Error>) -> Self {
|
||||
Self(b)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: OAuth2Error> OAuth2Error for ErrorResponse<T> {
|
||||
impl OAuth2Error for ErrorResponse {
|
||||
fn error(&self) -> &'static str {
|
||||
self.0.error()
|
||||
}
|
||||
@@ -77,7 +77,7 @@ impl<T: OAuth2Error> OAuth2Error for ErrorResponse<T> {
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: OAuth2Error> Serialize for ErrorResponse<T> {
|
||||
impl Serialize for ErrorResponse {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
|
||||
Reference in New Issue
Block a user