Merge the GraphQL requester and requester fingerprint into a single struct
This commit is contained in:
@@ -6,7 +6,7 @@
|
||||
|
||||
#![allow(clippy::module_name_repetitions)]
|
||||
|
||||
use std::sync::Arc;
|
||||
use std::{net::IpAddr, ops::Deref, sync::Arc};
|
||||
|
||||
use async_graphql::{
|
||||
extensions::Tracing,
|
||||
@@ -240,7 +240,7 @@ async fn get_requester(
|
||||
session_info: SessionInfo,
|
||||
token: Option<&str>,
|
||||
) -> Result<Requester, RouteError> {
|
||||
let requester = if let Some(token) = token {
|
||||
let entity = if let Some(token) = token {
|
||||
// If we haven't enabled undocumented_oauth2_access on the listener, we bail out
|
||||
if !undocumented_oauth2_access {
|
||||
return Err(RouteError::InvalidToken);
|
||||
@@ -285,7 +285,7 @@ async fn get_requester(
|
||||
return Err(RouteError::MissingScope);
|
||||
}
|
||||
|
||||
Requester::OAuth2Session(Box::new((session, user)))
|
||||
RequestingEntity::OAuth2Session(Box::new((session, user)))
|
||||
} else {
|
||||
let maybe_session = session_info.load_session(&mut repo).await?;
|
||||
|
||||
@@ -295,8 +295,14 @@ async fn get_requester(
|
||||
.await;
|
||||
}
|
||||
|
||||
Requester::from(maybe_session)
|
||||
RequestingEntity::from(maybe_session)
|
||||
};
|
||||
|
||||
let requester = Requester {
|
||||
entity,
|
||||
ip_address: activity_tracker.ip(),
|
||||
};
|
||||
|
||||
repo.cancel().await?;
|
||||
Ok(requester)
|
||||
}
|
||||
@@ -312,7 +318,6 @@ pub async fn post(
|
||||
cookie_jar: CookieJar,
|
||||
content_type: Option<TypedHeader<ContentType>>,
|
||||
authorization: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
requester_fingerprint: RequesterFingerprint,
|
||||
body: Body,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let body = body.into_data_stream();
|
||||
@@ -339,7 +344,6 @@ pub async fn post(
|
||||
MultipartOptions::default(),
|
||||
)
|
||||
.await?
|
||||
.data(requester_fingerprint)
|
||||
.data(requester); // XXX: this should probably return another error response?
|
||||
|
||||
let span = span_for_graphql_request(&request);
|
||||
@@ -366,7 +370,6 @@ pub async fn get(
|
||||
activity_tracker: BoundActivityTracker,
|
||||
cookie_jar: CookieJar,
|
||||
authorization: Option<TypedHeader<Authorization<Bearer>>>,
|
||||
requester_fingerprint: RequesterFingerprint,
|
||||
RawQuery(query): RawQuery,
|
||||
) -> Result<impl IntoResponse, FancyError> {
|
||||
let token = authorization
|
||||
@@ -383,9 +386,8 @@ pub async fn get(
|
||||
)
|
||||
.await?;
|
||||
|
||||
let request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?
|
||||
.data(requester)
|
||||
.data(requester_fingerprint);
|
||||
let request =
|
||||
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(requester);
|
||||
|
||||
let span = span_for_graphql_request(&request);
|
||||
let response = schema.execute(request).instrument(span).await;
|
||||
@@ -417,9 +419,32 @@ pub fn schema_builder() -> SchemaBuilder {
|
||||
.register_output_type::<CreationEvent>()
|
||||
}
|
||||
|
||||
pub struct Requester {
|
||||
entity: RequestingEntity,
|
||||
ip_address: Option<IpAddr>,
|
||||
}
|
||||
|
||||
impl Requester {
|
||||
pub fn fingerprint(&self) -> RequesterFingerprint {
|
||||
if let Some(ip) = self.ip_address {
|
||||
RequesterFingerprint::new(ip)
|
||||
} else {
|
||||
RequesterFingerprint::EMPTY
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Deref for Requester {
|
||||
type Target = RequestingEntity;
|
||||
|
||||
fn deref(&self) -> &Self::Target {
|
||||
&self.entity
|
||||
}
|
||||
}
|
||||
|
||||
/// The identity of the requester.
|
||||
#[derive(Debug, Clone, Default, PartialEq, Eq)]
|
||||
pub enum Requester {
|
||||
pub enum RequestingEntity {
|
||||
/// The requester presented no authentication information.
|
||||
#[default]
|
||||
Anonymous,
|
||||
@@ -480,7 +505,7 @@ impl OwnerId for UserId {
|
||||
}
|
||||
}
|
||||
|
||||
impl Requester {
|
||||
impl RequestingEntity {
|
||||
fn browser_session(&self) -> Option<&BrowserSession> {
|
||||
match self {
|
||||
Self::BrowserSession(session) => Some(session),
|
||||
@@ -532,17 +557,21 @@ impl Requester {
|
||||
Self::BrowserSession(_) | Self::Anonymous => false,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_unauthenticated(&self) -> bool {
|
||||
matches!(self, Self::Anonymous)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<BrowserSession> for Requester {
|
||||
impl From<BrowserSession> for RequestingEntity {
|
||||
fn from(session: BrowserSession) -> Self {
|
||||
Self::BrowserSession(Box::new(session))
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<Option<T>> for Requester
|
||||
impl<T> From<Option<T>> for RequestingEntity
|
||||
where
|
||||
T: Into<Requester>,
|
||||
T: Into<RequestingEntity>,
|
||||
{
|
||||
fn from(session: Option<T>) -> Self {
|
||||
session.map(Into::into).unwrap_or_default()
|
||||
|
||||
@@ -21,7 +21,7 @@ use zeroize::Zeroizing;
|
||||
use crate::graphql::{
|
||||
model::{NodeType, User},
|
||||
state::ContextExt,
|
||||
Requester, UserId,
|
||||
UserId,
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -728,7 +728,7 @@ impl UserMutations {
|
||||
let state = ctx.state();
|
||||
let requester = ctx.requester();
|
||||
let clock = state.clock();
|
||||
if !matches!(requester, Requester::Anonymous) {
|
||||
if !requester.is_unauthenticated() {
|
||||
return Err(async_graphql::Error::new(
|
||||
"Account recovery is only for anonymous users.",
|
||||
));
|
||||
@@ -830,7 +830,7 @@ impl UserMutations {
|
||||
input: ResendRecoveryEmailInput,
|
||||
) -> Result<ResendRecoveryEmailPayload, async_graphql::Error> {
|
||||
let state = ctx.state();
|
||||
let requester_fingerprint = ctx.requester_fingerprint();
|
||||
let requester = ctx.requester();
|
||||
let clock = state.clock();
|
||||
let mut rng = state.rng();
|
||||
let limiter = state.limiter();
|
||||
@@ -847,7 +847,7 @@ impl UserMutations {
|
||||
.context("Could not load recovery session")?;
|
||||
|
||||
if let Err(e) =
|
||||
limiter.check_account_recovery(requester_fingerprint, &recovery_session.email)
|
||||
limiter.check_account_recovery(requester.fingerprint(), &recovery_session.email)
|
||||
{
|
||||
tracing::warn!(error = &e as &dyn std::error::Error);
|
||||
return Ok(ResendRecoveryEmailPayload::RateLimited);
|
||||
|
||||
@@ -398,7 +398,6 @@ impl UserEmailMutations {
|
||||
let state = ctx.state();
|
||||
let id = NodeType::User.extract_ulid(&input.user_id)?;
|
||||
let requester = ctx.requester();
|
||||
let requester_fingerprint = ctx.requester_fingerprint();
|
||||
let clock = state.clock();
|
||||
let mut rng = state.rng();
|
||||
|
||||
@@ -428,7 +427,7 @@ impl UserEmailMutations {
|
||||
let res = policy
|
||||
.evaluate_email(mas_policy::EmailInput {
|
||||
email: &input.email,
|
||||
requester: requester_fingerprint.into(),
|
||||
requester: requester.fingerprint().into(),
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
@@ -561,7 +560,6 @@ impl UserEmailMutations {
|
||||
let mut rng = state.rng();
|
||||
let clock = state.clock();
|
||||
let requester = ctx.requester();
|
||||
let requester_fingerprint = ctx.requester_fingerprint();
|
||||
let limiter = state.limiter();
|
||||
|
||||
// Only allow calling this if the requester is a browser session
|
||||
@@ -591,7 +589,7 @@ impl UserEmailMutations {
|
||||
}
|
||||
|
||||
if let Err(e) =
|
||||
limiter.check_email_authentication_email(ctx.requester_fingerprint(), &input.email)
|
||||
limiter.check_email_authentication_email(requester.fingerprint(), &input.email)
|
||||
{
|
||||
tracing::warn!(error = &e as &dyn std::error::Error);
|
||||
return Ok(StartEmailAuthenticationPayload::RateLimited);
|
||||
@@ -620,7 +618,7 @@ impl UserEmailMutations {
|
||||
let res = policy
|
||||
.evaluate_email(mas_policy::EmailInput {
|
||||
email: &input.email,
|
||||
requester: requester_fingerprint.into(),
|
||||
requester: requester.fingerprint().into(),
|
||||
})
|
||||
.await?;
|
||||
if !res.valid() {
|
||||
@@ -660,9 +658,10 @@ impl UserEmailMutations {
|
||||
let mut rng = state.rng();
|
||||
let clock = state.clock();
|
||||
let limiter = state.limiter();
|
||||
let requester = ctx.requester();
|
||||
|
||||
let id = NodeType::UserEmailAuthentication.extract_ulid(&input.id)?;
|
||||
let Some(browser_session) = ctx.requester().browser_session() else {
|
||||
let Some(browser_session) = requester.browser_session() else {
|
||||
return Err(async_graphql::Error::new("Unauthorized"));
|
||||
};
|
||||
|
||||
@@ -692,8 +691,8 @@ impl UserEmailMutations {
|
||||
return Ok(ResendEmailAuthenticationCodePayload::Completed);
|
||||
}
|
||||
|
||||
if let Err(e) = limiter
|
||||
.check_email_authentication_send_code(ctx.requester_fingerprint(), &authentication)
|
||||
if let Err(e) =
|
||||
limiter.check_email_authentication_send_code(requester.fingerprint(), &authentication)
|
||||
{
|
||||
tracing::warn!(error = &e as &dyn std::error::Error);
|
||||
return Ok(ResendEmailAuthenticationCodePayload::RateLimited);
|
||||
|
||||
@@ -9,7 +9,6 @@ use async_graphql::{Context, Object};
|
||||
use crate::graphql::{
|
||||
model::{Viewer, ViewerSession},
|
||||
state::ContextExt,
|
||||
Requester,
|
||||
};
|
||||
|
||||
#[derive(Default)]
|
||||
@@ -21,24 +20,25 @@ impl ViewerQuery {
|
||||
async fn viewer(&self, ctx: &Context<'_>) -> Viewer {
|
||||
let requester = ctx.requester();
|
||||
|
||||
match requester {
|
||||
Requester::BrowserSession(session) => Viewer::user(session.user.clone()),
|
||||
Requester::OAuth2Session(tuple) => match &tuple.1 {
|
||||
Some(user) => Viewer::user(user.clone()),
|
||||
None => Viewer::anonymous(),
|
||||
},
|
||||
Requester::Anonymous => Viewer::anonymous(),
|
||||
if let Some(user) = requester.user() {
|
||||
return Viewer::user(user.clone());
|
||||
}
|
||||
|
||||
Viewer::anonymous()
|
||||
}
|
||||
|
||||
/// Get the viewer's session
|
||||
async fn viewer_session(&self, ctx: &Context<'_>) -> ViewerSession {
|
||||
let requester = ctx.requester();
|
||||
|
||||
match requester {
|
||||
Requester::BrowserSession(session) => ViewerSession::browser_session(*session.clone()),
|
||||
Requester::OAuth2Session(tuple) => ViewerSession::oauth2_session(tuple.0.clone()),
|
||||
Requester::Anonymous => ViewerSession::anonymous(),
|
||||
if let Some(session) = requester.browser_session() {
|
||||
return ViewerSession::browser_session(session.clone());
|
||||
}
|
||||
|
||||
if let Some(session) = requester.oauth2_session() {
|
||||
return ViewerSession::oauth2_session(session.clone());
|
||||
}
|
||||
|
||||
ViewerSession::anonymous()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ use mas_policy::Policy;
|
||||
use mas_router::UrlBuilder;
|
||||
use mas_storage::{BoxClock, BoxRepository, BoxRng, RepositoryError};
|
||||
|
||||
use crate::{graphql::Requester, passwords::PasswordManager, Limiter, RequesterFingerprint};
|
||||
use crate::{graphql::Requester, passwords::PasswordManager, Limiter};
|
||||
|
||||
#[async_trait::async_trait]
|
||||
pub trait State {
|
||||
@@ -31,8 +31,6 @@ pub trait ContextExt {
|
||||
fn state(&self) -> &BoxState;
|
||||
|
||||
fn requester(&self) -> &Requester;
|
||||
|
||||
fn requester_fingerprint(&self) -> RequesterFingerprint;
|
||||
}
|
||||
|
||||
impl ContextExt for async_graphql::Context<'_> {
|
||||
@@ -43,8 +41,4 @@ impl ContextExt for async_graphql::Context<'_> {
|
||||
fn requester(&self) -> &Requester {
|
||||
self.data_unchecked()
|
||||
}
|
||||
|
||||
fn requester_fingerprint(&self) -> RequesterFingerprint {
|
||||
*self.data_unchecked()
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user