diff --git a/crates/data-model/src/lib.rs b/crates/data-model/src/lib.rs index fd5c0e633..05b2466b9 100644 --- a/crates/data-model/src/lib.rs +++ b/crates/data-model/src/lib.rs @@ -56,8 +56,8 @@ pub use self::{ }, user_agent::{DeviceType, UserAgent}, users::{ - Authentication, AuthenticationMethod, BrowserSession, Password, User, UserEmail, - UserEmailAuthentication, UserEmailAuthenticationCode, UserRecoverySession, + Authentication, AuthenticationMethod, BrowserSession, MatrixUser, Password, User, + UserEmail, UserEmailAuthentication, UserEmailAuthenticationCode, UserRecoverySession, UserRecoveryTicket, UserRegistration, UserRegistrationPassword, UserRegistrationToken, }, utils::{BoxClock, BoxRng}, diff --git a/crates/data-model/src/users.rs b/crates/data-model/src/users.rs index 541eb26d2..78c483e12 100644 --- a/crates/data-model/src/users.rs +++ b/crates/data-model/src/users.rs @@ -12,6 +12,12 @@ use serde::Serialize; use ulid::Ulid; use url::Url; +#[derive(Debug, Clone, PartialEq, Eq, Serialize)] +pub struct MatrixUser { + pub mxid: String, + pub display_name: Option, +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize)] pub struct User { pub id: Ulid, diff --git a/crates/handlers/src/oauth2/authorization/consent.rs b/crates/handlers/src/oauth2/authorization/consent.rs index 2587828b5..ab51bef1c 100644 --- a/crates/handlers/src/oauth2/authorization/consent.rs +++ b/crates/handlers/src/oauth2/authorization/consent.rs @@ -4,6 +4,8 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. +use std::{sync::Arc, time::Duration}; + use axum::{ extract::{Form, Path, State}, response::{Html, IntoResponse, Response}, @@ -15,8 +17,9 @@ use mas_axum_utils::{ cookies::CookieJar, csrf::{CsrfExt, ProtectedForm}, }; -use mas_data_model::{AuthorizationGrantStage, BoxClock, BoxRng}; +use mas_data_model::{AuthorizationGrantStage, BoxClock, BoxRng, MatrixUser}; use mas_keystore::Keystore; +use mas_matrix::HomeserverConnection; use mas_policy::Policy; use mas_router::{PostAuthAction, UrlBuilder}; use mas_storage::{ @@ -87,6 +90,7 @@ pub(crate) async fn get( PreferredLanguage(locale): PreferredLanguage, State(templates): State, State(url_builder): State, + State(homeserver): State>, mut policy: Policy, mut repo: BoxRepository, activity_tracker: BoundActivityTracker, @@ -138,6 +142,9 @@ pub(crate) async fn get( let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?; + // We can close the repository early, we don't need it at this point + repo.save().await?; + let res = policy .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { user: Some(&session.user), @@ -162,7 +169,37 @@ pub(crate) async fn get( return Ok((cookie_jar, Html(content)).into_response()); } - let ctx = ConsentContext::new(grant, client) + // Fetch informations about the user. This is purely cosmetic, so we let it + // fail and put a 1s timeout to it in case we fail to query it + // XXX: we're likely to need this in other places + let localpart = &session.user.username; + let display_name = match tokio::time::timeout( + Duration::from_secs(1), + homeserver.query_user(localpart), + ) + .await + { + Ok(Ok(user)) => user.displayname, + Ok(Err(err)) => { + tracing::warn!( + error = &*err as &dyn std::error::Error, + localpart, + "Failed to query user" + ); + None + } + Err(_) => { + tracing::warn!(localpart, "Timed out while querying user"); + None + } + }; + + let matrix_user = MatrixUser { + mxid: homeserver.mxid(localpart), + display_name, + }; + + let ctx = ConsentContext::new(grant, client, matrix_user) .with_session(session) .with_csrf(csrf_token.form_value()) .with_language(locale); diff --git a/crates/handlers/src/oauth2/device/consent.rs b/crates/handlers/src/oauth2/device/consent.rs index e1d32870f..3912d2dc1 100644 --- a/crates/handlers/src/oauth2/device/consent.rs +++ b/crates/handlers/src/oauth2/device/consent.rs @@ -4,6 +4,8 @@ // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // Please see LICENSE files in the repository root for full details. +use std::{sync::Arc, time::Duration}; + use anyhow::Context; use axum::{ Form, @@ -16,7 +18,8 @@ use mas_axum_utils::{ cookies::CookieJar, csrf::{CsrfExt, ProtectedForm}, }; -use mas_data_model::{BoxClock, BoxRng}; +use mas_data_model::{BoxClock, BoxRng, MatrixUser}; +use mas_matrix::HomeserverConnection; use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::BoxRepository; @@ -49,6 +52,7 @@ pub(crate) async fn get( PreferredLanguage(locale): PreferredLanguage, State(templates): State, State(url_builder): State, + State(homeserver): State>, mut repo: BoxRepository, mut policy: Policy, activity_tracker: BoundActivityTracker, @@ -105,6 +109,9 @@ pub(crate) async fn get( let session_counts = count_user_sessions_for_limiting(&mut repo, &session.user).await?; + // We can close the repository early, we don't need it at this point + repo.save().await?; + // Evaluate the policy let res = policy .evaluate_authorization_grant(mas_policy::AuthorizationGrantInput { @@ -133,7 +140,37 @@ pub(crate) async fn get( return Ok((cookie_jar, Html(content)).into_response()); } - let ctx = DeviceConsentContext::new(grant, client) + // Fetch informations about the user. This is purely cosmetic, so we let it + // fail and put a 1s timeout to it in case we fail to query it + // XXX: we're likely to need this in other places + let localpart = &session.user.username; + let display_name = match tokio::time::timeout( + Duration::from_secs(1), + homeserver.query_user(localpart), + ) + .await + { + Ok(Ok(user)) => user.displayname, + Ok(Err(err)) => { + tracing::warn!( + error = &*err as &dyn std::error::Error, + localpart, + "Failed to query user" + ); + None + } + Err(_) => { + tracing::warn!(localpart, "Timed out while querying user"); + None + } + }; + + let matrix_user = MatrixUser { + mxid: homeserver.mxid(localpart), + display_name, + }; + + let ctx = DeviceConsentContext::new(grant, client, matrix_user) .with_session(session) .with_csrf(csrf_token.form_value()) .with_language(locale); @@ -153,6 +190,7 @@ pub(crate) async fn post( PreferredLanguage(locale): PreferredLanguage, State(templates): State, State(url_builder): State, + State(homeserver): State>, mut repo: BoxRepository, mut policy: Policy, activity_tracker: BoundActivityTracker, @@ -265,7 +303,37 @@ pub(crate) async fn post( repo.save().await?; - let ctx = DeviceConsentContext::new(grant, client) + // Fetch informations about the user. This is purely cosmetic, so we let it + // fail and put a 1s timeout to it in case we fail to query it + // XXX: we're likely to need this in other places + let localpart = &session.user.username; + let display_name = match tokio::time::timeout( + Duration::from_secs(1), + homeserver.query_user(localpart), + ) + .await + { + Ok(Ok(user)) => user.displayname, + Ok(Err(err)) => { + tracing::warn!( + error = &*err as &dyn std::error::Error, + localpart, + "Failed to query user" + ); + None + } + Err(_) => { + tracing::warn!(localpart, "Timed out while querying user"); + None + } + }; + + let matrix_user = MatrixUser { + mxid: homeserver.mxid(localpart), + display_name, + }; + + let ctx = DeviceConsentContext::new(grant, client, matrix_user) .with_session(session) .with_csrf(csrf_token.form_value()) .with_language(locale); diff --git a/crates/templates/src/context.rs b/crates/templates/src/context.rs index f836d7c4b..73a38972f 100644 --- a/crates/templates/src/context.rs +++ b/crates/templates/src/context.rs @@ -21,10 +21,11 @@ use chrono::{DateTime, Duration, Utc}; use http::{Method, Uri, Version}; use mas_data_model::{ AuthorizationGrant, BrowserSession, Client, CompatSsoLogin, CompatSsoLoginState, - DeviceCodeGrant, UpstreamOAuthLink, UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports, - UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderOnBackchannelLogout, - UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderTokenAuthMethod, User, - UserEmailAuthentication, UserEmailAuthenticationCode, UserRecoverySession, UserRegistration, + DeviceCodeGrant, MatrixUser, UpstreamOAuthLink, UpstreamOAuthProvider, + UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode, + UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode, + UpstreamOAuthProviderTokenAuthMethod, User, UserEmailAuthentication, + UserEmailAuthenticationCode, UserRecoverySession, UserRegistration, }; use mas_i18n::DataLocale; use mas_iana::jose::JsonWebSignatureAlg; @@ -732,6 +733,7 @@ pub struct ConsentContext { grant: AuthorizationGrant, client: Client, action: PostAuthAction, + matrix_user: MatrixUser, } impl TemplateContext for ConsentContext { @@ -755,6 +757,10 @@ impl TemplateContext for ConsentContext { grant, client, action, + matrix_user: MatrixUser { + mxid: "@alice:example.com".to_owned(), + display_name: Some("Alice".to_owned()), + }, } }) .collect(), @@ -765,12 +771,13 @@ impl TemplateContext for ConsentContext { impl ConsentContext { /// Constructs a context for the client consent page #[must_use] - pub fn new(grant: AuthorizationGrant, client: Client) -> Self { + pub fn new(grant: AuthorizationGrant, client: Client, matrix_user: MatrixUser) -> Self { let action = PostAuthAction::continue_grant(grant.id); Self { grant, client, action, + matrix_user, } } } @@ -1748,13 +1755,18 @@ impl TemplateContext for DeviceLinkContext { pub struct DeviceConsentContext { grant: DeviceCodeGrant, client: Client, + matrix_user: MatrixUser, } impl DeviceConsentContext { /// Constructs a new context with an existing linked user #[must_use] - pub fn new(grant: DeviceCodeGrant, client: Client) -> Self { - Self { grant, client } + pub fn new(grant: DeviceCodeGrant, client: Client, matrix_user: MatrixUser) -> Self { + Self { + grant, + client, + matrix_user, + } } } @@ -1782,7 +1794,14 @@ impl TemplateContext for DeviceConsentContext { ip_address: Some(IpAddr::V4(Ipv4Addr::LOCALHOST)), user_agent: Some("Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/93.0.0.0 Safari/537.36".to_owned()), }; - Self { grant, client } + Self { + grant, + client, + matrix_user: MatrixUser { + mxid: "@alice:example.com".to_owned(), + display_name: Some("Alice".to_owned()), + } + } }) .collect()) }