Better upstream OAuth links pagination and filtering

This commit is contained in:
Quentin Gliech
2023-07-21 19:06:21 +02:00
parent ec1a87cfda
commit 01bc5802c3
13 changed files with 256 additions and 114 deletions

View File

@@ -20,7 +20,7 @@ use chrono::{DateTime, Utc};
use mas_storage::{
compat::{CompatSessionFilter, CompatSsoLoginFilter, CompatSsoLoginRepository},
oauth2::{OAuth2SessionFilter, OAuth2SessionRepository},
upstream_oauth2::UpstreamOAuthLinkRepository,
upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository},
Pagination, RepositoryAccess,
};
@@ -462,7 +462,8 @@ impl User {
before: Option<String>,
#[graphql(desc = "Returns the first *n* elements from the list.")] first: Option<i32>,
#[graphql(desc = "Returns the last *n* elements from the list.")] last: Option<i32>,
) -> Result<Connection<Cursor, UpstreamOAuth2Link>, async_graphql::Error> {
) -> Result<Connection<Cursor, UpstreamOAuth2Link, PreloadedTotalCount>, async_graphql::Error>
{
let state = ctx.state();
let mut repo = state.repository().await?;
@@ -484,14 +485,24 @@ impl User {
.transpose()?;
let pagination = Pagination::try_new(before_id, after_id, first, last)?;
let page = repo
.upstream_oauth_link()
.list_paginated(&self.0, pagination)
.await?;
let filter = UpstreamOAuthLinkFilter::new().for_user(&self.0);
let page = repo.upstream_oauth_link().list(filter, pagination).await?;
// Preload the total count if requested
let count = if ctx.look_ahead().field("totalCount").exists() {
Some(repo.upstream_oauth_link().count(filter).await?)
} else {
None
};
repo.cancel().await?;
let mut connection = Connection::new(page.has_previous_page, page.has_next_page);
let mut connection = Connection::with_additional_fields(
page.has_previous_page,
page.has_next_page,
PreloadedTotalCount(count),
);
connection.edges.extend(page.edges.into_iter().map(|s| {
Edge::new(
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, s.id)),

View File

@@ -22,7 +22,7 @@ use mas_storage::{
Clock, Page, Pagination,
};
use rand::RngCore;
use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
use sqlx::PgConnection;
use ulid::Ulid;
use url::Url;
@@ -397,7 +397,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> {
}
}))
.generate_pagination(
(CompatSessions::Table, CompatSessions::CompatSessionId).into_column_ref(),
(CompatSessions::Table, CompatSessions::CompatSessionId),
pagination,
)
.build(PostgresQueryBuilder);

View File

@@ -20,7 +20,7 @@ use mas_storage::{
Clock, Page, Pagination,
};
use rand::RngCore;
use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
use sqlx::PgConnection;
use ulid::Ulid;
use url::Url;
@@ -377,7 +377,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> {
}
}))
.generate_pagination(
(CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId).into_column_ref(),
(CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId),
pagination,
)
.build(PostgresQueryBuilder);

View File

@@ -93,3 +93,16 @@ pub enum UpstreamOAuthProviders {
CreatedAt,
ClaimsImports,
}
#[derive(sea_query::Iden)]
#[iden = "upstream_oauth_links"]
pub enum UpstreamOAuthLinks {
Table,
#[iden = "upstream_oauth_link_id"]
UpstreamOAuthLinkId,
#[iden = "upstream_oauth_provider_id"]
UpstreamOAuthProviderId,
UserId,
Subject,
CreatedAt,
}

View File

@@ -21,7 +21,7 @@ use mas_storage::{
};
use oauth2_types::scope::Scope;
use rand::RngCore;
use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
@@ -288,7 +288,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> {
}
}))
.generate_pagination(
(OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId).into_column_ref(),
(OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId),
pagination,
)
.build(PostgresQueryBuilder);

View File

@@ -15,75 +15,29 @@
//! Utilities to manage paginated queries.
use mas_storage::{pagination::PaginationDirection, Pagination};
use sqlx::{Database, QueryBuilder};
use sea_query::IntoColumnRef;
use uuid::Uuid;
/// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination
/// to a query
pub trait QueryBuilderExt {
type Iden;
/// Add cursor-based pagination to a query, as used in paginated GraphQL
/// connections
fn generate_pagination(&mut self, id_field: Self::Iden, pagination: Pagination) -> &mut Self;
}
impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB>
where
DB: Database,
Uuid: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
i64: sqlx::Type<DB> + sqlx::Encode<'a, DB>,
{
type Iden = &'static str;
fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self {
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
// 1. Start from the greedy query: SELECT * FROM table
// 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE`
// clause
if let Some(after) = pagination.after {
self.push(" AND ")
.push(id_field)
.push(" > ")
.push_bind(Uuid::from(after));
}
// 3. If the before argument is provided, add `id < parsed_cursor` to the
// `WHERE` clause
if let Some(before) = pagination.before {
self.push(" AND ")
.push(id_field)
.push(" < ")
.push_bind(Uuid::from(before));
}
match pagination.direction {
// 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the
// query
PaginationDirection::Forward => {
self.push(" ORDER BY ")
.push(id_field)
.push(" ASC LIMIT ")
.push_bind((pagination.count + 1) as i64);
}
// 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the
// query
PaginationDirection::Backward => {
self.push(" ORDER BY ")
.push(id_field)
.push(" DESC LIMIT ")
.push_bind((pagination.count + 1) as i64);
}
};
self
}
fn generate_pagination<C: IntoColumnRef>(
&mut self,
column: C,
pagination: Pagination,
) -> &mut Self;
}
impl QueryBuilderExt for sea_query::SelectStatement {
type Iden = sea_query::ColumnRef;
fn generate_pagination(&mut self, id_field: Self::Iden, pagination: Pagination) -> &mut Self {
fn generate_pagination<C: IntoColumnRef>(
&mut self,
column: C,
pagination: Pagination,
) -> &mut Self {
let id_field = column.into_column_ref();
// ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564
// 1. Start from the greedy query: SELECT * FROM table

View File

@@ -15,13 +15,20 @@
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User};
use mas_storage::{upstream_oauth2::UpstreamOAuthLinkRepository, Clock, Page, Pagination};
use mas_storage::{
upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
Clock, Page, Pagination,
};
use rand::RngCore;
use sqlx::{PgConnection, QueryBuilder};
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError};
use crate::{
iden::UpstreamOAuthLinks, pagination::QueryBuilderExt, sea_query_sqlx::map_values,
tracing::ExecuteExt, DatabaseError,
};
/// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL
/// connection
@@ -38,6 +45,7 @@ impl<'c> PgUpstreamOAuthLinkRepository<'c> {
}
#[derive(sqlx::FromRow)]
#[enum_def]
struct LinkLookup {
upstream_oauth_link_id: Uuid,
upstream_oauth_provider_id: Uuid,
@@ -221,44 +229,118 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> {
}
#[tracing::instrument(
name = "db.upstream_oauth_link.list_paginated",
name = "db.upstream_oauth_link.list",
skip_all,
fields(
db.statement,
%user.id,
%user.username,
),
err
err,
)]
async fn list_paginated(
async fn list(
&mut self,
user: &User,
filter: UpstreamOAuthLinkFilter<'_>,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthLink>, Self::Error> {
let mut query = QueryBuilder::new(
r#"
SELECT
upstream_oauth_link_id,
upstream_oauth_provider_id,
user_id,
subject,
created_at
FROM upstream_oauth_links
"#,
);
) -> Result<Page<UpstreamOAuthLink>, DatabaseError> {
let (sql, values) = Query::select()
.expr_as(
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthLinkId,
)),
LinkLookupIden::UpstreamOauthLinkId,
)
.expr_as(
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthProviderId,
)),
LinkLookupIden::UpstreamOauthProviderId,
)
.expr_as(
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)),
LinkLookupIden::UserId,
)
.expr_as(
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)),
LinkLookupIden::Subject,
)
.expr_as(
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)),
LinkLookupIden::CreatedAt,
)
.from(UpstreamOAuthLinks::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
.eq(Uuid::from(user.id))
}))
.and_where_option(filter.provider().map(|provider| {
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthProviderId,
))
.eq(Uuid::from(provider.id))
}))
.generate_pagination(
(
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthLinkId,
),
pagination,
)
.build(PostgresQueryBuilder);
query
.push(" WHERE user_id = ")
.push_bind(Uuid::from(user.id))
.generate_pagination("upstream_oauth_link_id", pagination);
let arguments = map_values(values);
let edges: Vec<LinkLookup> = query
.build_query_as()
let edges: Vec<LinkLookup> = sqlx::query_as_with(&sql, arguments)
.traced()
.fetch_all(&mut *self.conn)
.await?;
let page = pagination.process(edges).map(UpstreamOAuthLink::from);
Ok(page)
}
#[tracing::instrument(
name = "db.upstream_oauth_link.count",
skip_all,
fields(
db.statement,
),
err,
)]
async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error> {
let (sql, values) = Query::select()
.expr(
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthLinkId,
))
.count(),
)
.from(UpstreamOAuthLinks::Table)
.and_where_option(filter.user().map(|user| {
Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId))
.eq(Uuid::from(user.id))
}))
.and_where_option(filter.provider().map(|provider| {
Expr::col((
UpstreamOAuthLinks::Table,
UpstreamOAuthLinks::UpstreamOAuthProviderId,
))
.eq(Uuid::from(provider.id))
}))
.build(PostgresQueryBuilder);
let arguments = map_values(values);
let count: i64 = sqlx::query_scalar_with(&sql, arguments)
.traced()
.fetch_one(&mut *self.conn)
.await?;
count
.try_into()
.map_err(DatabaseError::to_invalid_operation)
}
}

View File

@@ -31,7 +31,7 @@ mod tests {
use mas_storage::{
clock::MockClock,
upstream_oauth2::{
UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter,
UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository,
},
user::UserRepository,
@@ -177,9 +177,14 @@ mod tests {
.await
.unwrap();
// XXX: we should also try other combinations of the filter
let filter = UpstreamOAuthLinkFilter::new()
.for_user(&user)
.for_provider(&provider);
let links = repo
.upstream_oauth_link()
.list_paginated(&user, Pagination::first(10))
.list(filter, Pagination::first(10))
.await
.unwrap();
assert!(!links.has_previous_page);
@@ -188,6 +193,8 @@ mod tests {
assert_eq!(links.edges[0].id, link.id);
assert_eq!(links.edges[0].user_id, Some(user.id));
assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1);
// Try deleting the provider
repo.upstream_oauth_provider()
.delete(provider)

View File

@@ -22,7 +22,7 @@ use mas_storage::{
};
use oauth2_types::scope::Scope;
use rand::RngCore;
use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
use sqlx::{types::Json, PgConnection};
use tracing::{info_span, Instrument};
use ulid::Ulid;
@@ -439,8 +439,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
(
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::UpstreamOAuthProviderId,
)
.into_column_ref(),
),
pagination,
)
.build(PostgresQueryBuilder);

View File

@@ -20,7 +20,7 @@ use mas_storage::{
Clock, Page, Pagination,
};
use rand::RngCore;
use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query};
use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query};
use sqlx::PgConnection;
use tracing::{info_span, Instrument};
use ulid::Ulid;
@@ -275,10 +275,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> {
Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)).is_null()
}
}))
.generate_pagination(
(UserEmails::Table, UserEmails::UserEmailId).into_column_ref(),
pagination,
)
.generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination)
.build(PostgresQueryBuilder);
let arguments = map_values(values);

View File

@@ -17,7 +17,7 @@ use chrono::{DateTime, Utc};
use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User};
use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination};
use rand::RngCore;
use sea_query::{Expr, IntoColumnRef, PostgresQueryBuilder};
use sea_query::{Expr, PostgresQueryBuilder};
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
@@ -251,7 +251,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> {
}
}))
.generate_pagination(
(UserSessions::Table, UserSessions::UserSessionId).into_column_ref(),
(UserSessions::Table, UserSessions::UserSessionId),
pagination,
)
.build(PostgresQueryBuilder);

View File

@@ -19,6 +19,52 @@ use ulid::Ulid;
use crate::{pagination::Page, repository_impl, Clock, Pagination};
/// Filter parameters for listing upstream OAuth links
#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)]
pub struct UpstreamOAuthLinkFilter<'a> {
// XXX: we might also want to filter for links without a user linked to them
user: Option<&'a User>,
provider: Option<&'a UpstreamOAuthProvider>,
}
impl<'a> UpstreamOAuthLinkFilter<'a> {
/// Create a new [`UpstreamOAuthLinkFilter`] with default values
#[must_use]
pub fn new() -> Self {
Self::default()
}
/// Set the user who owns the upstream OAuth links
#[must_use]
pub fn for_user(mut self, user: &'a User) -> Self {
self.user = Some(user);
self
}
/// Get the user filter
///
/// Returns [`None`] if no filter was set
#[must_use]
pub fn user(&self) -> Option<&User> {
self.user
}
/// Set the upstream OAuth provider for which to list links
#[must_use]
pub fn for_provider(mut self, provider: &'a UpstreamOAuthProvider) -> Self {
self.provider = Some(provider);
self
}
/// Get the upstream OAuth provider filter
///
/// Returns [`None`] if no filter was set
#[must_use]
pub fn provider(&self) -> Option<&UpstreamOAuthProvider> {
self.provider
}
}
/// An [`UpstreamOAuthLinkRepository`] helps interacting with
/// [`UpstreamOAuthLink`] with the storage backend
#[async_trait]
@@ -109,11 +155,42 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync {
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
#[deprecated(note = "Use `list` instead")]
async fn list_paginated(
&mut self,
user: &User,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthLink>, Self::Error> {
self.list(UpstreamOAuthLinkFilter::new().for_user(user), pagination)
.await
}
/// List [`UpstreamOAuthLink`] with the given filter and pagination
///
/// # Parameters
///
/// * `filter`: The filter to apply
/// * `pagination`: The pagination parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn list(
&mut self,
filter: UpstreamOAuthLinkFilter<'_>,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
/// Count the number of [`UpstreamOAuthLink`] with the given filter
///
/// # Parameters
///
/// * `filter`: The filter to apply
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error>;
}
repository_impl!(UpstreamOAuthLinkRepository:
@@ -139,9 +216,11 @@ repository_impl!(UpstreamOAuthLinkRepository:
user: &User,
) -> Result<(), Self::Error>;
async fn list_paginated(
async fn list(
&mut self,
user: &User,
filter: UpstreamOAuthLinkFilter<'_>,
pagination: Pagination,
) -> Result<Page<UpstreamOAuthLink>, Self::Error>;
async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result<usize, Self::Error>;
);

View File

@@ -20,7 +20,7 @@ mod provider;
mod session;
pub use self::{
link::UpstreamOAuthLinkRepository,
link::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository},
provider::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository},
session::UpstreamOAuthSessionRepository,
};