From c30bb3ffa425efb8262ca603ef01be98e6494dce Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 6 Aug 2021 16:57:49 +0200 Subject: [PATCH] allow completing an oauth2 session after login --- .../20210731130515_oauth2_sessions.up.sql | 1 + matrix-authentication-service/sqlx-data.json | 300 +++++++++++++----- .../src/handlers/oauth2/authorization.rs | 138 ++++++-- .../src/handlers/views/login.rs | 52 ++- .../src/handlers/views/mod.rs | 1 + .../src/storage/oauth2.rs | 106 ++++++- 6 files changed, 469 insertions(+), 129 deletions(-) diff --git a/matrix-authentication-service/migrations/20210731130515_oauth2_sessions.up.sql b/matrix-authentication-service/migrations/20210731130515_oauth2_sessions.up.sql index a258b2a62..70e197b0c 100644 --- a/matrix-authentication-service/migrations/20210731130515_oauth2_sessions.up.sql +++ b/matrix-authentication-service/migrations/20210731130515_oauth2_sessions.up.sql @@ -16,6 +16,7 @@ CREATE TABLE oauth2_sessions ( "id" BIGSERIAL PRIMARY KEY, "user_session_id" BIGINT REFERENCES user_sessions (id) ON DELETE CASCADE, "client_id" TEXT NOT NULL, + "redirect_uri" TEXT NOT NULL, "scope" TEXT NOT NULL, "state" TEXT, "nonce" TEXT, diff --git a/matrix-authentication-service/sqlx-data.json b/matrix-authentication-service/sqlx-data.json index e993aa2e7..762923afe 100644 --- a/matrix-authentication-service/sqlx-data.json +++ b/matrix-authentication-service/sqlx-data.json @@ -73,6 +73,26 @@ ] } }, + "17729fd0354a84e04bfcd525db6575ed2ba75dd730bea3f2be964f4b347dd484": { + "query": "\n SELECT code\n FROM oauth2_codes\n WHERE oauth2_session_id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "code", + "type_info": "Text" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false + ] + } + }, "35bedaa6fdf7ac91d54b458b4637f2182c2f82be3e2f80cd2db934ee279a7f2a": { "query": "\n SELECT id, username\n FROM users\n WHERE id = $1\n ", "describe": { @@ -99,93 +119,6 @@ ] } }, - "3f8aaca9f29ded15acf4f0056a789af643ab62816a2fb598dcc9af4fff2841f0": { - "query": "\n INSERT INTO oauth2_sessions \n (user_session_id, client_id, scope, state, nonce, max_age, response_type, response_mode)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8)\n RETURNING\n id, user_session_id, client_id, scope, state, nonce, max_age, \n response_type, response_mode, created_at, updated_at\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "id", - "type_info": "Int8" - }, - { - "ordinal": 1, - "name": "user_session_id", - "type_info": "Int8" - }, - { - "ordinal": 2, - "name": "client_id", - "type_info": "Text" - }, - { - "ordinal": 3, - "name": "scope", - "type_info": "Text" - }, - { - "ordinal": 4, - "name": "state", - "type_info": "Text" - }, - { - "ordinal": 5, - "name": "nonce", - "type_info": "Text" - }, - { - "ordinal": 6, - "name": "max_age", - "type_info": "Int4" - }, - { - "ordinal": 7, - "name": "response_type", - "type_info": "Text" - }, - { - "ordinal": 8, - "name": "response_mode", - "type_info": "Text" - }, - { - "ordinal": 9, - "name": "created_at", - "type_info": "Timestamptz" - }, - { - "ordinal": 10, - "name": "updated_at", - "type_info": "Timestamptz" - } - ], - "parameters": { - "Left": [ - "Int8", - "Text", - "Text", - "Text", - "Text", - "Int4", - "Text", - "Text" - ] - }, - "nullable": [ - false, - true, - false, - false, - true, - true, - true, - false, - false, - false, - false - ] - } - }, "4f925a277d73df779360f81e0cf5d7983b50ebe744f461559dd561b7e36c20d4": { "query": "\n SELECT\n s.id,\n u.id as user_id,\n u.username,\n s.active,\n s.created_at,\n a.created_at as \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u \n ON s.user_id = u.id\n LEFT JOIN user_session_authentications a\n ON a.session_id = s.id\n WHERE s.id = $1 AND s.active\n ORDER BY a.created_at DESC\n LIMIT 1\n ", "describe": { @@ -306,6 +239,100 @@ ] } }, + "a051f542df7d3f80f5dc6dd6f04d49a462c64e4ce9146d90069d16ec9b61084b": { + "query": "\n INSERT INTO oauth2_sessions \n (user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,\n response_type, response_mode)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n RETURNING\n id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age, \n response_type, response_mode, created_at, updated_at\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "user_session_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "client_id", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "redirect_uri", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "scope", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "state", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "nonce", + "type_info": "Text" + }, + { + "ordinal": 7, + "name": "max_age", + "type_info": "Int4" + }, + { + "ordinal": 8, + "name": "response_type", + "type_info": "Text" + }, + { + "ordinal": 9, + "name": "response_mode", + "type_info": "Text" + }, + { + "ordinal": 10, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 11, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8", + "Text", + "Text", + "Text", + "Text", + "Text", + "Int4", + "Text", + "Text" + ] + }, + "nullable": [ + false, + true, + false, + false, + false, + true, + true, + true, + false, + false, + false, + false + ] + } + }, "a09dfe1019110f2ec6eba0d35bafa467ab4b7980dd8b556826f03863f8edb0ab": { "query": "UPDATE user_sessions SET active = FALSE WHERE id = $1", "describe": { @@ -338,6 +365,19 @@ ] } }, + "a6eb935107d060dd01bf9824ceff87b9ff5492b58cefef002a49f444d3a3daa1": { + "query": "UPDATE oauth2_sessions SET user_session_id = $1 WHERE id = $2", + "describe": { + "columns": [], + "parameters": { + "Left": [ + "Int8", + "Int8" + ] + }, + "nullable": [] + } + }, "f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": { "query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ", "describe": { @@ -358,5 +398,91 @@ false ] } + }, + "ff515ebb80ba4af1948472f5c7120a03e25b1ebe42151b8a2036bfbb042f17f6": { + "query": "\n SELECT\n id, user_session_id, client_id, redirect_uri, scope, state, nonce,\n max_age, response_type, response_mode, created_at, updated_at\n FROM oauth2_sessions\n WHERE id = $1\n ", + "describe": { + "columns": [ + { + "ordinal": 0, + "name": "id", + "type_info": "Int8" + }, + { + "ordinal": 1, + "name": "user_session_id", + "type_info": "Int8" + }, + { + "ordinal": 2, + "name": "client_id", + "type_info": "Text" + }, + { + "ordinal": 3, + "name": "redirect_uri", + "type_info": "Text" + }, + { + "ordinal": 4, + "name": "scope", + "type_info": "Text" + }, + { + "ordinal": 5, + "name": "state", + "type_info": "Text" + }, + { + "ordinal": 6, + "name": "nonce", + "type_info": "Text" + }, + { + "ordinal": 7, + "name": "max_age", + "type_info": "Int4" + }, + { + "ordinal": 8, + "name": "response_type", + "type_info": "Text" + }, + { + "ordinal": 9, + "name": "response_mode", + "type_info": "Text" + }, + { + "ordinal": 10, + "name": "created_at", + "type_info": "Timestamptz" + }, + { + "ordinal": 11, + "name": "updated_at", + "type_info": "Timestamptz" + } + ], + "parameters": { + "Left": [ + "Int8" + ] + }, + "nullable": [ + false, + true, + false, + false, + false, + true, + true, + true, + false, + false, + false, + false + ] + } } } \ No newline at end of file diff --git a/matrix-authentication-service/src/handlers/oauth2/authorization.rs b/matrix-authentication-service/src/handlers/oauth2/authorization.rs index e82916d31..2592a0be5 100644 --- a/matrix-authentication-service/src/handlers/oauth2/authorization.rs +++ b/matrix-authentication-service/src/handlers/oauth2/authorization.rs @@ -12,10 +12,17 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::{HashMap, HashSet}; +use std::{ + collections::{HashMap, HashSet}, + convert::TryFrom, +}; use data_encoding::BASE64URL_NOPAD; -use hyper::{header::LOCATION, StatusCode}; +use hyper::{ + header::LOCATION, + http::uri::{Parts, PathAndQuery, Uri}, + StatusCode, +}; use itertools::Itertools; use oauth2_types::{ pkce, @@ -28,6 +35,7 @@ use serde::{Deserialize, Serialize}; use sqlx::PgPool; use url::Url; use warp::{ + redirect::see_other, reply::{html, with_header}, Filter, Rejection, Reply, }; @@ -35,8 +43,15 @@ use warp::{ use crate::{ config::{CookiesConfig, OAuth2ClientConfig, OAuth2Config}, errors::WrapError, - filters::{session::with_optional_session, with_pool, with_templates}, - storage::{oauth2::start_session, SessionInfo}, + filters::{ + session::{with_optional_session, with_session}, + with_pool, with_templates, + }, + handlers::views::LoginRequest, + storage::{ + oauth2::{get_session_by_id, start_session}, + SessionInfo, + }, templates::{FormPostContext, Templates}, }; @@ -144,14 +159,24 @@ pub fn filter( cookies_config: &CookiesConfig, ) -> impl Filter + Clone + Send + Sync + 'static { let clients = oauth2_config.clients.clone(); - warp::get() + let authorize = warp::get() .and(warp::path!("oauth2" / "authorize")) .map(move || clients.clone()) .and(warp::query()) .and(with_optional_session(pool, cookies_config)) .and(with_pool(pool)) .and(with_templates(templates)) - .and_then(get) + .and_then(get); + + let step = warp::get() + .and(warp::path!("oauth2" / "authorize" / "step")) + .and(warp::query().map(|s: StepRequest| s.id)) + .and(with_session(pool, cookies_config)) + .and(with_pool(pool)) + .and(with_templates(templates)) + .and_then(step); + + authorize.or(step) } async fn get( @@ -190,6 +215,7 @@ async fn get( &mut txn, maybe_session_id, &client.client_id, + redirect_uri, &scope, params.auth.state.as_deref(), params.auth.nonce.as_deref(), @@ -200,24 +226,81 @@ async fn get( .await .wrap_error()?; - let code = if response_type.contains(&ResponseType::Code) { + // Generate the code at this stage, since we have the PKCE params ready + if response_type.contains(&ResponseType::Code) { // 192bit random bytes encoded in base64, which gives a 32 character code let code: [u8; 24] = rand::random(); let code = BASE64URL_NOPAD.encode(&code); - Some( - oauth2_session - .add_code(&mut txn, &code, ¶ms.pkce) - .await - .wrap_error()?, - ) - } else { - None + oauth2_session + .add_code(&mut txn, &code, ¶ms.pkce) + .await + .wrap_error()?; }; // Do we have a user in this session, with a last authentication time that // matches the requirement? let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?; + txn.commit().await.wrap_error()?; + if let Some(user_session) = user_session { + step(oauth2_session.id, user_session, pool, templates).await + } else { + let next = StepRequest::new(oauth2_session.id) + .build_uri() + .wrap_error()? + .to_string(); + + let destination = LoginRequest::new(Some(next)).build_uri().wrap_error()?; + Ok(Box::new(see_other(destination))) + } +} + +#[derive(Deserialize, Serialize)] +struct StepRequest { + id: i64, +} + +impl StepRequest { + fn new(id: i64) -> Self { + Self { id } + } + + fn build_uri(&self) -> anyhow::Result { + let qs = serde_urlencoded::to_string(self)?; + let path_and_query = PathAndQuery::try_from(format!("/oauth2/authorize/step?{}", qs))?; + let uri = Uri::from_parts({ + let mut parts = Parts::default(); + parts.path_and_query = Some(path_and_query); + parts + })?; + Ok(uri) + } +} + +async fn step( + oauth2_session_id: i64, + user_session: SessionInfo, + pool: PgPool, + templates: Templates, +) -> Result, Rejection> { + // Start a DB transaction + let mut txn = pool.begin().await.wrap_error()?; + + let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id) + .await + .wrap_error()?; + + let user_session = oauth2_session + .match_or_set_session(&mut txn, user_session) + .await + .wrap_error()?; + + let response_mode = oauth2_session.response_mode().wrap_error()?; + let response_type = oauth2_session.response_type().wrap_error()?; + let redirect_uri = oauth2_session.redirect_uri().wrap_error()?; + + let reply = + // Check if the active session is valid if user_session.active && user_session.last_authd_at >= oauth2_session.max_auth_time() { // Yep! Let's complete the auth now let mut params = AuthorizationResponse { @@ -226,8 +309,8 @@ async fn get( }; // Did they request an auth code? - if let Some(ref code) = code { - params.code = Some(code.code.clone()); + if response_type.contains(&ResponseType::Code) { + params.code = Some(oauth2_session.fetch_code(&mut txn).await.wrap_error()?); } // Did they request an access token? @@ -243,20 +326,13 @@ async fn get( todo!("id tokens are not implemented yet"); } - txn.commit().await.wrap_error()?; - let reply = back_to_client(redirect_uri.clone(), response_mode, params, &templates) - .wrap_error()?; - return Ok(reply); - } - // TODO: show reauth form - } - - // TODO: show login form + back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()? + } else { + // Ask for a reauth + // TODO: have the OAuth2 session ID in there + Box::new(see_other(Uri::from_static("/reauth"))) + }; txn.commit().await.wrap_error()?; - Ok(Box::new(warp::reply::json(&serde_json::json!({ - "session": oauth2_session, - "code": code, - "redirect_uri": redirect_uri, - })))) + Ok(reply) } diff --git a/matrix-authentication-service/src/handlers/views/login.rs b/matrix-authentication-service/src/handlers/views/login.rs index 758276a17..a95cdccd0 100644 --- a/matrix-authentication-service/src/handlers/views/login.rs +++ b/matrix-authentication-service/src/handlers/views/login.rs @@ -12,9 +12,12 @@ // See the License for the specific language governing permissions and // limitations under the License. -use serde::Deserialize; +use std::convert::TryFrom; + +use hyper::http::uri::{Parts, PathAndQuery, Uri}; +use serde::{Deserialize, Serialize}; use sqlx::PgPool; -use warp::{hyper::Uri, reply::html, wrap_fn, Filter, Rejection, Reply}; +use warp::{reply::html, wrap_fn, Filter, Rejection, Reply}; use crate::{ config::{CookiesConfig, CsrfConfig}, @@ -28,6 +31,28 @@ use crate::{ templates::{TemplateContext, Templates}, }; +#[derive(Serialize, Deserialize)] +pub struct LoginRequest { + next: Option, +} + +impl LoginRequest { + pub fn new(next: Option) -> Self { + Self { next } + } + + pub fn build_uri(&self) -> anyhow::Result { + let qs = serde_urlencoded::to_string(self)?; + let path_and_query = PathAndQuery::try_from(format!("/login?{}", qs))?; + let uri = Uri::from_parts({ + let mut parts = Parts::default(); + parts.path_and_query = Some(path_and_query); + parts + })?; + Ok(uri) + } +} + #[derive(Deserialize)] struct LoginForm { username: String, @@ -50,6 +75,7 @@ pub(super) fn filter( let post = warp::post() .and(with_pool(pool)) .and(protected_form(cookies_config)) + .and(warp::query()) .and_then(post) .untuple_one() .with(wrap_fn(save_session(cookies_config))); @@ -68,10 +94,28 @@ async fn get( Ok((csrf_token, html(content))) } -async fn post(db: PgPool, form: LoginForm) -> Result<(SessionInfo, impl Reply), Rejection> { +async fn post( + db: PgPool, + form: LoginForm, + query: LoginRequest, +) -> Result<(SessionInfo, impl Reply), Rejection> { let session_info = login(&db, &form.username, &form.password) .await .wrap_error()?; - Ok((session_info, warp::redirect(Uri::from_static("/")))) + let uri: Uri = Uri::from_parts({ + let mut parts = Parts::default(); + parts.path_and_query = Some( + query + .next + .map(warp::http::uri::PathAndQuery::try_from) + .transpose() + .wrap_error()? + .unwrap_or_else(|| PathAndQuery::from_static("/")), + ); + parts + }) + .wrap_error()?; + + Ok((session_info, warp::redirect(uri))) } diff --git a/matrix-authentication-service/src/handlers/views/mod.rs b/matrix-authentication-service/src/handlers/views/mod.rs index acf52e561..f772825b5 100644 --- a/matrix-authentication-service/src/handlers/views/mod.rs +++ b/matrix-authentication-service/src/handlers/views/mod.rs @@ -25,6 +25,7 @@ mod login; mod logout; mod reauth; +pub use self::login::LoginRequest; use self::{ index::filter as index, login::filter as login, logout::filter as logout, reauth::filter as reauth, diff --git a/matrix-authentication-service/src/storage/oauth2.rs b/matrix-authentication-service/src/storage/oauth2.rs index ab55c0ad5..b1730be82 100644 --- a/matrix-authentication-service/src/storage/oauth2.rs +++ b/matrix-authentication-service/src/storage/oauth2.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{collections::HashSet, convert::TryFrom}; +use std::{collections::HashSet, convert::TryFrom, str::FromStr}; use anyhow::Context; use chrono::{DateTime, Duration, Utc}; @@ -23,14 +23,16 @@ use oauth2_types::{ }; use serde::Serialize; use sqlx::{Executor, FromRow, Postgres}; +use url::Url; use super::{user::lookup_session, SessionInfo}; #[derive(FromRow, Serialize)] pub struct OAuth2Session { - id: i64, + pub id: i64, user_session_id: Option, client_id: String, + redirect_uri: String, scope: String, pub state: Option, nonce: Option, @@ -52,9 +54,9 @@ impl OAuth2Session { add_code(executor, self.id, code, code_challenge).await } - pub async fn fetch_session<'e>( + pub async fn fetch_session( &self, - executor: impl Executor<'e, Database = Postgres>, + executor: impl Executor<'_, Database = Postgres>, ) -> anyhow::Result> { match self.user_session_id { Some(id) => { @@ -65,11 +67,61 @@ impl OAuth2Session { } } + pub async fn fetch_code( + &self, + executor: impl Executor<'_, Database = Postgres>, + ) -> anyhow::Result { + get_code_for_session(executor, self.id).await + } + + pub async fn match_or_set_session( + &mut self, + executor: impl Executor<'_, Database = Postgres>, + session: SessionInfo, + ) -> anyhow::Result { + match self.user_session_id { + Some(id) if id == session.key() => Ok(session), + Some(id) => Err(anyhow::anyhow!( + "session mismatch, expected {}, got {}", + id, + session.key() + )), + None => { + sqlx::query!( + "UPDATE oauth2_sessions SET user_session_id = $1 WHERE id = $2", + session.key(), + self.id, + ) + .execute(executor) + .await + .context("could not update oauth2 session")?; + Ok(session) + } + } + } + pub fn max_auth_time(&self) -> Option> { self.max_age .map(|d| Duration::seconds(i64::from(d))) .map(|d| self.created_at - d) } + + pub fn response_type(&self) -> anyhow::Result> { + self.response_type + .split(' ') + .map(|s| { + ResponseType::from_str(s).with_context(|| format!("invalid response type {}", s)) + }) + .collect() + } + + pub fn response_mode(&self) -> anyhow::Result { + self.response_mode.parse().context("invalid response mode") + } + + pub fn redirect_uri(&self) -> anyhow::Result { + self.redirect_uri.parse().context("invalid redirect uri") + } } #[allow(clippy::too_many_arguments)] @@ -77,6 +129,7 @@ pub async fn start_session( executor: impl Executor<'_, Database = Postgres>, optional_session_id: Option, client_id: &str, + redirect_uri: &Url, scope: &str, state: Option<&str>, nonce: Option<&str>, @@ -96,15 +149,17 @@ pub async fn start_session( OAuth2Session, r#" INSERT INTO oauth2_sessions - (user_session_id, client_id, scope, state, nonce, max_age, response_type, response_mode) + (user_session_id, client_id, redirect_uri, scope, state, nonce, max_age, + response_type, response_mode) VALUES - ($1, $2, $3, $4, $5, $6, $7, $8) + ($1, $2, $3, $4, $5, $6, $7, $8, $9) RETURNING - id, user_session_id, client_id, scope, state, nonce, max_age, + id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age, response_type, response_mode, created_at, updated_at "#, optional_session_id, client_id, + redirect_uri.as_str(), scope, state, nonce, @@ -117,6 +172,43 @@ pub async fn start_session( .context("could not insert oauth2 session") } +pub async fn get_session_by_id( + executor: impl Executor<'_, Database = Postgres>, + oauth2_session_id: i64, +) -> anyhow::Result { + sqlx::query_as!( + OAuth2Session, + r#" + SELECT + id, user_session_id, client_id, redirect_uri, scope, state, nonce, + max_age, response_type, response_mode, created_at, updated_at + FROM oauth2_sessions + WHERE id = $1 + "#, + oauth2_session_id + ) + .fetch_one(executor) + .await + .context("could not fetch oauth2 session") +} + +pub async fn get_code_for_session( + executor: impl Executor<'_, Database = Postgres>, + oauth2_session_id: i64, +) -> anyhow::Result { + sqlx::query_scalar!( + r#" + SELECT code + FROM oauth2_codes + WHERE oauth2_session_id = $1 + "#, + oauth2_session_id + ) + .fetch_one(executor) + .await + .context("could not fetch oauth2 code") +} + #[derive(FromRow, Serialize)] pub struct OAuth2Code { id: i64,