Merge the GraphQL requester and requester fingerprint into a single struct

This commit is contained in:
Quentin Gliech
2025-02-17 10:35:35 +01:00
parent b1b7bf5725
commit e74ddb832a
5 changed files with 68 additions and 46 deletions

View File

@@ -6,7 +6,7 @@
#![allow(clippy::module_name_repetitions)]
use std::sync::Arc;
use std::{net::IpAddr, ops::Deref, sync::Arc};
use async_graphql::{
extensions::Tracing,
@@ -240,7 +240,7 @@ async fn get_requester(
session_info: SessionInfo,
token: Option<&str>,
) -> Result<Requester, RouteError> {
let requester = if let Some(token) = token {
let entity = if let Some(token) = token {
// If we haven't enabled undocumented_oauth2_access on the listener, we bail out
if !undocumented_oauth2_access {
return Err(RouteError::InvalidToken);
@@ -285,7 +285,7 @@ async fn get_requester(
return Err(RouteError::MissingScope);
}
Requester::OAuth2Session(Box::new((session, user)))
RequestingEntity::OAuth2Session(Box::new((session, user)))
} else {
let maybe_session = session_info.load_session(&mut repo).await?;
@@ -295,8 +295,14 @@ async fn get_requester(
.await;
}
Requester::from(maybe_session)
RequestingEntity::from(maybe_session)
};
let requester = Requester {
entity,
ip_address: activity_tracker.ip(),
};
repo.cancel().await?;
Ok(requester)
}
@@ -312,7 +318,6 @@ pub async fn post(
cookie_jar: CookieJar,
content_type: Option<TypedHeader<ContentType>>,
authorization: Option<TypedHeader<Authorization<Bearer>>>,
requester_fingerprint: RequesterFingerprint,
body: Body,
) -> Result<impl IntoResponse, RouteError> {
let body = body.into_data_stream();
@@ -339,7 +344,6 @@ pub async fn post(
MultipartOptions::default(),
)
.await?
.data(requester_fingerprint)
.data(requester); // XXX: this should probably return another error response?
let span = span_for_graphql_request(&request);
@@ -366,7 +370,6 @@ pub async fn get(
activity_tracker: BoundActivityTracker,
cookie_jar: CookieJar,
authorization: Option<TypedHeader<Authorization<Bearer>>>,
requester_fingerprint: RequesterFingerprint,
RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> {
let token = authorization
@@ -383,9 +386,8 @@ pub async fn get(
)
.await?;
let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?
.data(requester)
.data(requester_fingerprint);
let request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
@@ -417,9 +419,32 @@ pub fn schema_builder() -> SchemaBuilder {
.register_output_type::<CreationEvent>()
}
pub struct Requester {
entity: RequestingEntity,
ip_address: Option<IpAddr>,
}
impl Requester {
pub fn fingerprint(&self) -> RequesterFingerprint {
if let Some(ip) = self.ip_address {
RequesterFingerprint::new(ip)
} else {
RequesterFingerprint::EMPTY
}
}
}
impl Deref for Requester {
type Target = RequestingEntity;
fn deref(&self) -> &Self::Target {
&self.entity
}
}
/// The identity of the requester.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum Requester {
pub enum RequestingEntity {
/// The requester presented no authentication information.
#[default]
Anonymous,
@@ -480,7 +505,7 @@ impl OwnerId for UserId {
}
}
impl Requester {
impl RequestingEntity {
fn browser_session(&self) -> Option<&BrowserSession> {
match self {
Self::BrowserSession(session) => Some(session),
@@ -532,17 +557,21 @@ impl Requester {
Self::BrowserSession(_) | Self::Anonymous => false,
}
}
fn is_unauthenticated(&self) -> bool {
matches!(self, Self::Anonymous)
}
}
impl From<BrowserSession> for Requester {
impl From<BrowserSession> for RequestingEntity {
fn from(session: BrowserSession) -> Self {
Self::BrowserSession(Box::new(session))
}
}
impl<T> From<Option<T>> for Requester
impl<T> From<Option<T>> for RequestingEntity
where
T: Into<Requester>,
T: Into<RequestingEntity>,
{
fn from(session: Option<T>) -> Self {
session.map(Into::into).unwrap_or_default()

View File

@@ -21,7 +21,7 @@ use zeroize::Zeroizing;
use crate::graphql::{
model::{NodeType, User},
state::ContextExt,
Requester, UserId,
UserId,
};
#[derive(Default)]
@@ -728,7 +728,7 @@ impl UserMutations {
let state = ctx.state();
let requester = ctx.requester();
let clock = state.clock();
if !matches!(requester, Requester::Anonymous) {
if !requester.is_unauthenticated() {
return Err(async_graphql::Error::new(
"Account recovery is only for anonymous users.",
));
@@ -830,7 +830,7 @@ impl UserMutations {
input: ResendRecoveryEmailInput,
) -> Result<ResendRecoveryEmailPayload, async_graphql::Error> {
let state = ctx.state();
let requester_fingerprint = ctx.requester_fingerprint();
let requester = ctx.requester();
let clock = state.clock();
let mut rng = state.rng();
let limiter = state.limiter();
@@ -847,7 +847,7 @@ impl UserMutations {
.context("Could not load recovery session")?;
if let Err(e) =
limiter.check_account_recovery(requester_fingerprint, &recovery_session.email)
limiter.check_account_recovery(requester.fingerprint(), &recovery_session.email)
{
tracing::warn!(error = &e as &dyn std::error::Error);
return Ok(ResendRecoveryEmailPayload::RateLimited);

View File

@@ -398,7 +398,6 @@ impl UserEmailMutations {
let state = ctx.state();
let id = NodeType::User.extract_ulid(&input.user_id)?;
let requester = ctx.requester();
let requester_fingerprint = ctx.requester_fingerprint();
let clock = state.clock();
let mut rng = state.rng();
@@ -428,7 +427,7 @@ impl UserEmailMutations {
let res = policy
.evaluate_email(mas_policy::EmailInput {
email: &input.email,
requester: requester_fingerprint.into(),
requester: requester.fingerprint().into(),
})
.await?;
if !res.valid() {
@@ -561,7 +560,6 @@ impl UserEmailMutations {
let mut rng = state.rng();
let clock = state.clock();
let requester = ctx.requester();
let requester_fingerprint = ctx.requester_fingerprint();
let limiter = state.limiter();
// Only allow calling this if the requester is a browser session
@@ -591,7 +589,7 @@ impl UserEmailMutations {
}
if let Err(e) =
limiter.check_email_authentication_email(ctx.requester_fingerprint(), &input.email)
limiter.check_email_authentication_email(requester.fingerprint(), &input.email)
{
tracing::warn!(error = &e as &dyn std::error::Error);
return Ok(StartEmailAuthenticationPayload::RateLimited);
@@ -620,7 +618,7 @@ impl UserEmailMutations {
let res = policy
.evaluate_email(mas_policy::EmailInput {
email: &input.email,
requester: requester_fingerprint.into(),
requester: requester.fingerprint().into(),
})
.await?;
if !res.valid() {
@@ -660,9 +658,10 @@ impl UserEmailMutations {
let mut rng = state.rng();
let clock = state.clock();
let limiter = state.limiter();
let requester = ctx.requester();
let id = NodeType::UserEmailAuthentication.extract_ulid(&input.id)?;
let Some(browser_session) = ctx.requester().browser_session() else {
let Some(browser_session) = requester.browser_session() else {
return Err(async_graphql::Error::new("Unauthorized"));
};
@@ -692,8 +691,8 @@ impl UserEmailMutations {
return Ok(ResendEmailAuthenticationCodePayload::Completed);
}
if let Err(e) = limiter
.check_email_authentication_send_code(ctx.requester_fingerprint(), &authentication)
if let Err(e) =
limiter.check_email_authentication_send_code(requester.fingerprint(), &authentication)
{
tracing::warn!(error = &e as &dyn std::error::Error);
return Ok(ResendEmailAuthenticationCodePayload::RateLimited);

View File

@@ -9,7 +9,6 @@ use async_graphql::{Context, Object};
use crate::graphql::{
model::{Viewer, ViewerSession},
state::ContextExt,
Requester,
};
#[derive(Default)]
@@ -21,24 +20,25 @@ impl ViewerQuery {
async fn viewer(&self, ctx: &Context<'_>) -> Viewer {
let requester = ctx.requester();
match requester {
Requester::BrowserSession(session) => Viewer::user(session.user.clone()),
Requester::OAuth2Session(tuple) => match &tuple.1 {
Some(user) => Viewer::user(user.clone()),
None => Viewer::anonymous(),
},
Requester::Anonymous => Viewer::anonymous(),
if let Some(user) = requester.user() {
return Viewer::user(user.clone());
}
Viewer::anonymous()
}
/// Get the viewer's session
async fn viewer_session(&self, ctx: &Context<'_>) -> ViewerSession {
let requester = ctx.requester();
match requester {
Requester::BrowserSession(session) => ViewerSession::browser_session(*session.clone()),
Requester::OAuth2Session(tuple) => ViewerSession::oauth2_session(tuple.0.clone()),
Requester::Anonymous => ViewerSession::anonymous(),
if let Some(session) = requester.browser_session() {
return ViewerSession::browser_session(session.clone());
}
if let Some(session) = requester.oauth2_session() {
return ViewerSession::oauth2_session(session.clone());
}
ViewerSession::anonymous()
}
}

View File

@@ -10,7 +10,7 @@ use mas_policy::Policy;
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};
use crate::{graphql::Requester, passwords::PasswordManager, Limiter, RequesterFingerprint};
use crate::{graphql::Requester, passwords::PasswordManager, Limiter};
#[async_trait::async_trait]
pub trait State {
@@ -31,8 +31,6 @@ pub trait ContextExt {
fn state(&self) -> &BoxState;
fn requester(&self) -> &Requester;
fn requester_fingerprint(&self) -> RequesterFingerprint;
}
impl ContextExt for async_graphql::Context<'_> {
@@ -43,8 +41,4 @@ impl ContextExt for async_graphql::Context<'_> {
fn requester(&self) -> &Requester {
self.data_unchecked()
}
fn requester_fingerprint(&self) -> RequesterFingerprint {
*self.data_unchecked()
}
}