diff --git a/crates/handlers/src/compat/login_sso_complete.rs b/crates/handlers/src/compat/login_sso_complete.rs index 84c8f8525..df059cd36 100644 --- a/crates/handlers/src/compat/login_sso_complete.rs +++ b/crates/handlers/src/compat/login_sso_complete.rs @@ -83,7 +83,7 @@ pub async fn get( Some(CompatLoginSsoAction::Register) => { url_builder.redirect(&mas_router::Register::and_continue_compat_sso_login(id)) } - Some(CompatLoginSsoAction::Login) | None => { + Some(CompatLoginSsoAction::Login | CompatLoginSsoAction::Unknown) | None => { url_builder.redirect(&mas_router::Login::and_continue_compat_sso_login(id)) } }; @@ -224,7 +224,7 @@ pub async fn post( Some(CompatLoginSsoAction::Register) => { url_builder.redirect(&mas_router::Register::and_continue_compat_sso_login(id)) } - Some(CompatLoginSsoAction::Login) | None => { + Some(CompatLoginSsoAction::Login | CompatLoginSsoAction::Unknown) | None => { url_builder.redirect(&mas_router::Login::and_continue_compat_sso_login(id)) } }; diff --git a/crates/handlers/src/compat/login_sso_redirect.rs b/crates/handlers/src/compat/login_sso_redirect.rs index 1ad47c55e..8edb868fd 100644 --- a/crates/handlers/src/compat/login_sso_redirect.rs +++ b/crates/handlers/src/compat/login_sso_redirect.rs @@ -13,7 +13,6 @@ use mas_router::{CompatLoginSsoAction, CompatLoginSsoComplete, UrlBuilder}; use mas_storage::{BoxRepository, compat::CompatSsoLoginRepository}; use rand::distributions::{Alphanumeric, DistString}; use serde::Deserialize; -use serde_with::serde; use thiserror::Error; use url::Url; @@ -23,12 +22,21 @@ use crate::impl_from_error_for_route; pub struct Params { #[serde(rename = "redirectUrl")] redirect_url: Option, + action: Option, #[serde(rename = "org.matrix.msc3824.action")] unstable_action: Option, } +impl Params { + fn action(&self) -> Option { + self.action + .filter(CompatLoginSsoAction::is_known) + .or(self.unstable_action.filter(CompatLoginSsoAction::is_known)) + } +} + #[derive(Debug, Error)] pub enum RouteError { #[error(transparent)] @@ -62,6 +70,8 @@ pub async fn get( State(url_builder): State, Query(params): Query, ) -> Result { + let action = params.action(); + // Check the redirectUrl parameter let redirect_url = params.redirect_url.ok_or(RouteError::MissingRedirectUrl)?; let redirect_url = Url::parse(&redirect_url).map_err(|_| RouteError::InvalidRedirectUrl)?; @@ -84,10 +94,7 @@ pub async fn get( repo.save().await?; - Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new( - login.id, - params.action.or(params.unstable_action), - ))) + Ok(url_builder.absolute_redirect(&CompatLoginSsoComplete::new(login.id, action))) } #[cfg(test)] @@ -121,4 +128,29 @@ mod tests { assert!(location.contains("org.matrix.msc3824.action=register")); assert!(location.contains("action=register")); } + + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_unknown_action(pool: PgPool) { + let state: TestState = TestState::from_pool(pool).await.unwrap(); + + let request = Request::get( + "/_matrix/client/v3/login/sso/redirect?\ + redirectUrl=http://example.com/\ + &org.matrix.msc3824.action=undefinedaction", + ) + .empty(); + + let response = state.request(request).await; + + response.assert_status(StatusCode::SEE_OTHER); + + let location = response + .headers() + .get("Location") + .unwrap() + .to_str() + .unwrap(); + assert!(!location.contains("org.matrix.msc3824.action")); + assert!(!location.contains("action")); + } } diff --git a/crates/router/src/endpoints.rs b/crates/router/src/endpoints.rs index 37c200aac..6aa18f13d 100644 --- a/crates/router/src/endpoints.rs +++ b/crates/router/src/endpoints.rs @@ -628,6 +628,16 @@ impl SimpleRoute for CompatLoginSsoRedirectIdp { pub enum CompatLoginSsoAction { Login, Register, + #[serde(other)] + Unknown, +} + +impl CompatLoginSsoAction { + /// Returns true if the action is a known action. + #[must_use] + pub fn is_known(&self) -> bool { + !matches!(self, Self::Unknown) + } } #[derive(Debug, Serialize, Deserialize, Clone, Copy)]