From e74ddb832ac7361b8c5cdbe60aa66070f5aba9a0 Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Mon, 17 Feb 2025 10:35:35 +0100 Subject: [PATCH] Merge the GraphQL requester and requester fingerprint into a single struct --- crates/handlers/src/graphql/mod.rs | 59 ++++++++++++++----- crates/handlers/src/graphql/mutations/user.rs | 8 +-- .../src/graphql/mutations/user_email.rs | 15 +++-- crates/handlers/src/graphql/query/viewer.rs | 24 ++++---- crates/handlers/src/graphql/state.rs | 8 +-- 5 files changed, 68 insertions(+), 46 deletions(-) diff --git a/crates/handlers/src/graphql/mod.rs b/crates/handlers/src/graphql/mod.rs index 3fbf30166..11c461bcc 100644 --- a/crates/handlers/src/graphql/mod.rs +++ b/crates/handlers/src/graphql/mod.rs @@ -6,7 +6,7 @@ #![allow(clippy::module_name_repetitions)] -use std::sync::Arc; +use std::{net::IpAddr, ops::Deref, sync::Arc}; use async_graphql::{ extensions::Tracing, @@ -240,7 +240,7 @@ async fn get_requester( session_info: SessionInfo, token: Option<&str>, ) -> Result { - let requester = if let Some(token) = token { + let entity = if let Some(token) = token { // If we haven't enabled undocumented_oauth2_access on the listener, we bail out if !undocumented_oauth2_access { return Err(RouteError::InvalidToken); @@ -285,7 +285,7 @@ async fn get_requester( return Err(RouteError::MissingScope); } - Requester::OAuth2Session(Box::new((session, user))) + RequestingEntity::OAuth2Session(Box::new((session, user))) } else { let maybe_session = session_info.load_session(&mut repo).await?; @@ -295,8 +295,14 @@ async fn get_requester( .await; } - Requester::from(maybe_session) + RequestingEntity::from(maybe_session) }; + + let requester = Requester { + entity, + ip_address: activity_tracker.ip(), + }; + repo.cancel().await?; Ok(requester) } @@ -312,7 +318,6 @@ pub async fn post( cookie_jar: CookieJar, content_type: Option>, authorization: Option>>, - requester_fingerprint: RequesterFingerprint, body: Body, ) -> Result { let body = body.into_data_stream(); @@ -339,7 +344,6 @@ pub async fn post( MultipartOptions::default(), ) .await? - .data(requester_fingerprint) .data(requester); // XXX: this should probably return another error response? let span = span_for_graphql_request(&request); @@ -366,7 +370,6 @@ pub async fn get( activity_tracker: BoundActivityTracker, cookie_jar: CookieJar, authorization: Option>>, - requester_fingerprint: RequesterFingerprint, RawQuery(query): RawQuery, ) -> Result { let token = authorization @@ -383,9 +386,8 @@ pub async fn get( ) .await?; - let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())? - .data(requester) - .data(requester_fingerprint); + let request = + async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester); let span = span_for_graphql_request(&request); let response = schema.execute(request).instrument(span).await; @@ -417,9 +419,32 @@ pub fn schema_builder() -> SchemaBuilder { .register_output_type::() } +pub struct Requester { + entity: RequestingEntity, + ip_address: Option, +} + +impl Requester { + pub fn fingerprint(&self) -> RequesterFingerprint { + if let Some(ip) = self.ip_address { + RequesterFingerprint::new(ip) + } else { + RequesterFingerprint::EMPTY + } + } +} + +impl Deref for Requester { + type Target = RequestingEntity; + + fn deref(&self) -> &Self::Target { + &self.entity + } +} + /// The identity of the requester. #[derive(Debug, Clone, Default, PartialEq, Eq)] -pub enum Requester { +pub enum RequestingEntity { /// The requester presented no authentication information. #[default] Anonymous, @@ -480,7 +505,7 @@ impl OwnerId for UserId { } } -impl Requester { +impl RequestingEntity { fn browser_session(&self) -> Option<&BrowserSession> { match self { Self::BrowserSession(session) => Some(session), @@ -532,17 +557,21 @@ impl Requester { Self::BrowserSession(_) | Self::Anonymous => false, } } + + fn is_unauthenticated(&self) -> bool { + matches!(self, Self::Anonymous) + } } -impl From for Requester { +impl From for RequestingEntity { fn from(session: BrowserSession) -> Self { Self::BrowserSession(Box::new(session)) } } -impl From> for Requester +impl From> for RequestingEntity where - T: Into, + T: Into, { fn from(session: Option) -> Self { session.map(Into::into).unwrap_or_default() diff --git a/crates/handlers/src/graphql/mutations/user.rs b/crates/handlers/src/graphql/mutations/user.rs index 52c661b05..311a1dcb7 100644 --- a/crates/handlers/src/graphql/mutations/user.rs +++ b/crates/handlers/src/graphql/mutations/user.rs @@ -21,7 +21,7 @@ use zeroize::Zeroizing; use crate::graphql::{ model::{NodeType, User}, state::ContextExt, - Requester, UserId, + UserId, }; #[derive(Default)] @@ -728,7 +728,7 @@ impl UserMutations { let state = ctx.state(); let requester = ctx.requester(); let clock = state.clock(); - if !matches!(requester, Requester::Anonymous) { + if !requester.is_unauthenticated() { return Err(async_graphql::Error::new( "Account recovery is only for anonymous users.", )); @@ -830,7 +830,7 @@ impl UserMutations { input: ResendRecoveryEmailInput, ) -> Result { let state = ctx.state(); - let requester_fingerprint = ctx.requester_fingerprint(); + let requester = ctx.requester(); let clock = state.clock(); let mut rng = state.rng(); let limiter = state.limiter(); @@ -847,7 +847,7 @@ impl UserMutations { .context("Could not load recovery session")?; if let Err(e) = - limiter.check_account_recovery(requester_fingerprint, &recovery_session.email) + limiter.check_account_recovery(requester.fingerprint(), &recovery_session.email) { tracing::warn!(error = &e as &dyn std::error::Error); return Ok(ResendRecoveryEmailPayload::RateLimited); diff --git a/crates/handlers/src/graphql/mutations/user_email.rs b/crates/handlers/src/graphql/mutations/user_email.rs index 19ff39804..371864b10 100644 --- a/crates/handlers/src/graphql/mutations/user_email.rs +++ b/crates/handlers/src/graphql/mutations/user_email.rs @@ -398,7 +398,6 @@ impl UserEmailMutations { let state = ctx.state(); let id = NodeType::User.extract_ulid(&input.user_id)?; let requester = ctx.requester(); - let requester_fingerprint = ctx.requester_fingerprint(); let clock = state.clock(); let mut rng = state.rng(); @@ -428,7 +427,7 @@ impl UserEmailMutations { let res = policy .evaluate_email(mas_policy::EmailInput { email: &input.email, - requester: requester_fingerprint.into(), + requester: requester.fingerprint().into(), }) .await?; if !res.valid() { @@ -561,7 +560,6 @@ impl UserEmailMutations { let mut rng = state.rng(); let clock = state.clock(); let requester = ctx.requester(); - let requester_fingerprint = ctx.requester_fingerprint(); let limiter = state.limiter(); // Only allow calling this if the requester is a browser session @@ -591,7 +589,7 @@ impl UserEmailMutations { } if let Err(e) = - limiter.check_email_authentication_email(ctx.requester_fingerprint(), &input.email) + limiter.check_email_authentication_email(requester.fingerprint(), &input.email) { tracing::warn!(error = &e as &dyn std::error::Error); return Ok(StartEmailAuthenticationPayload::RateLimited); @@ -620,7 +618,7 @@ impl UserEmailMutations { let res = policy .evaluate_email(mas_policy::EmailInput { email: &input.email, - requester: requester_fingerprint.into(), + requester: requester.fingerprint().into(), }) .await?; if !res.valid() { @@ -660,9 +658,10 @@ impl UserEmailMutations { let mut rng = state.rng(); let clock = state.clock(); let limiter = state.limiter(); + let requester = ctx.requester(); let id = NodeType::UserEmailAuthentication.extract_ulid(&input.id)?; - let Some(browser_session) = ctx.requester().browser_session() else { + let Some(browser_session) = requester.browser_session() else { return Err(async_graphql::Error::new("Unauthorized")); }; @@ -692,8 +691,8 @@ impl UserEmailMutations { return Ok(ResendEmailAuthenticationCodePayload::Completed); } - if let Err(e) = limiter - .check_email_authentication_send_code(ctx.requester_fingerprint(), &authentication) + if let Err(e) = + limiter.check_email_authentication_send_code(requester.fingerprint(), &authentication) { tracing::warn!(error = &e as &dyn std::error::Error); return Ok(ResendEmailAuthenticationCodePayload::RateLimited); diff --git a/crates/handlers/src/graphql/query/viewer.rs b/crates/handlers/src/graphql/query/viewer.rs index 60f884357..6985dfd2e 100644 --- a/crates/handlers/src/graphql/query/viewer.rs +++ b/crates/handlers/src/graphql/query/viewer.rs @@ -9,7 +9,6 @@ use async_graphql::{Context, Object}; use crate::graphql::{ model::{Viewer, ViewerSession}, state::ContextExt, - Requester, }; #[derive(Default)] @@ -21,24 +20,25 @@ impl ViewerQuery { async fn viewer(&self, ctx: &Context<'_>) -> Viewer { let requester = ctx.requester(); - match requester { - Requester::BrowserSession(session) => Viewer::user(session.user.clone()), - Requester::OAuth2Session(tuple) => match &tuple.1 { - Some(user) => Viewer::user(user.clone()), - None => Viewer::anonymous(), - }, - Requester::Anonymous => Viewer::anonymous(), + if let Some(user) = requester.user() { + return Viewer::user(user.clone()); } + + Viewer::anonymous() } /// Get the viewer's session async fn viewer_session(&self, ctx: &Context<'_>) -> ViewerSession { let requester = ctx.requester(); - match requester { - Requester::BrowserSession(session) => ViewerSession::browser_session(*session.clone()), - Requester::OAuth2Session(tuple) => ViewerSession::oauth2_session(tuple.0.clone()), - Requester::Anonymous => ViewerSession::anonymous(), + if let Some(session) = requester.browser_session() { + return ViewerSession::browser_session(session.clone()); } + + if let Some(session) = requester.oauth2_session() { + return ViewerSession::oauth2_session(session.clone()); + } + + ViewerSession::anonymous() } } diff --git a/crates/handlers/src/graphql/state.rs b/crates/handlers/src/graphql/state.rs index 874f6f7aa..95752c4fd 100644 --- a/crates/handlers/src/graphql/state.rs +++ b/crates/handlers/src/graphql/state.rs @@ -10,7 +10,7 @@ use mas_policy::Policy; use mas_router::UrlBuilder; use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError}; -use crate::{graphql::Requester, passwords::PasswordManager, Limiter, RequesterFingerprint}; +use crate::{graphql::Requester, passwords::PasswordManager, Limiter}; #[async_trait::async_trait] pub trait State { @@ -31,8 +31,6 @@ pub trait ContextExt { fn state(&self) -> &BoxState; fn requester(&self) -> &Requester; - - fn requester_fingerprint(&self) -> RequesterFingerprint; } impl ContextExt for async_graphql::Context<'_> { @@ -43,8 +41,4 @@ impl ContextExt for async_graphql::Context<'_> { fn requester(&self) -> &Requester { self.data_unchecked() } - - fn requester_fingerprint(&self) -> RequesterFingerprint { - *self.data_unchecked() - } }