From 01bc5802c326eec792b30c5dd59325848ab710eb Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Fri, 21 Jul 2023 19:06:21 +0200 Subject: [PATCH] Better upstream OAuth links pagination and filtering --- crates/graphql/src/model/users.rs | 25 +++- crates/storage-pg/src/compat/session.rs | 4 +- crates/storage-pg/src/compat/sso_login.rs | 4 +- crates/storage-pg/src/iden.rs | 13 ++ crates/storage-pg/src/oauth2/session.rs | 4 +- crates/storage-pg/src/pagination.rs | 72 ++-------- crates/storage-pg/src/upstream_oauth2/link.rs | 136 ++++++++++++++---- crates/storage-pg/src/upstream_oauth2/mod.rs | 11 +- .../src/upstream_oauth2/provider.rs | 5 +- crates/storage-pg/src/user/email.rs | 7 +- crates/storage-pg/src/user/session.rs | 4 +- crates/storage/src/upstream_oauth2/link.rs | 83 ++++++++++- crates/storage/src/upstream_oauth2/mod.rs | 2 +- 13 files changed, 256 insertions(+), 114 deletions(-) diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 45b2c922c..7fbfdd8d0 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -20,7 +20,7 @@ use chrono::{DateTime, Utc}; use mas_storage::{ compat::{CompatSessionFilter, CompatSsoLoginFilter, CompatSsoLoginRepository}, oauth2::{OAuth2SessionFilter, OAuth2SessionRepository}, - upstream_oauth2::UpstreamOAuthLinkRepository, + upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository}, user::{BrowserSessionFilter, BrowserSessionRepository, UserEmailFilter, UserEmailRepository}, Pagination, RepositoryAccess, }; @@ -462,7 +462,8 @@ impl User { before: Option, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, - ) -> Result, async_graphql::Error> { + ) -> Result, async_graphql::Error> + { let state = ctx.state(); let mut repo = state.repository().await?; @@ -484,14 +485,24 @@ impl User { .transpose()?; let pagination = Pagination::try_new(before_id, after_id, first, last)?; - let page = repo - .upstream_oauth_link() - .list_paginated(&self.0, pagination) - .await?; + let filter = UpstreamOAuthLinkFilter::new().for_user(&self.0); + + let page = repo.upstream_oauth_link().list(filter, pagination).await?; + + // Preload the total count if requested + let count = if ctx.look_ahead().field("totalCount").exists() { + Some(repo.upstream_oauth_link().count(filter).await?) + } else { + None + }; repo.cancel().await?; - let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + let mut connection = Connection::with_additional_fields( + page.has_previous_page, + page.has_next_page, + PreloadedTotalCount(count), + ); connection.edges.extend(page.edges.into_iter().map(|s| { Edge::new( OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, s.id)), diff --git a/crates/storage-pg/src/compat/session.rs b/crates/storage-pg/src/compat/session.rs index 91b8d33a5..de3e673c4 100644 --- a/crates/storage-pg/src/compat/session.rs +++ b/crates/storage-pg/src/compat/session.rs @@ -22,7 +22,7 @@ use mas_storage::{ Clock, Page, Pagination, }; use rand::RngCore; -use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query}; +use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query}; use sqlx::PgConnection; use ulid::Ulid; use url::Url; @@ -397,7 +397,7 @@ impl<'c> CompatSessionRepository for PgCompatSessionRepository<'c> { } })) .generate_pagination( - (CompatSessions::Table, CompatSessions::CompatSessionId).into_column_ref(), + (CompatSessions::Table, CompatSessions::CompatSessionId), pagination, ) .build(PostgresQueryBuilder); diff --git a/crates/storage-pg/src/compat/sso_login.rs b/crates/storage-pg/src/compat/sso_login.rs index b201ed225..11ab2fa44 100644 --- a/crates/storage-pg/src/compat/sso_login.rs +++ b/crates/storage-pg/src/compat/sso_login.rs @@ -20,7 +20,7 @@ use mas_storage::{ Clock, Page, Pagination, }; use rand::RngCore; -use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query}; +use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query}; use sqlx::PgConnection; use ulid::Ulid; use url::Url; @@ -377,7 +377,7 @@ impl<'c> CompatSsoLoginRepository for PgCompatSsoLoginRepository<'c> { } })) .generate_pagination( - (CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId).into_column_ref(), + (CompatSsoLogins::Table, CompatSsoLogins::CompatSsoLoginId), pagination, ) .build(PostgresQueryBuilder); diff --git a/crates/storage-pg/src/iden.rs b/crates/storage-pg/src/iden.rs index 42856d2b8..29978ab8a 100644 --- a/crates/storage-pg/src/iden.rs +++ b/crates/storage-pg/src/iden.rs @@ -93,3 +93,16 @@ pub enum UpstreamOAuthProviders { CreatedAt, ClaimsImports, } + +#[derive(sea_query::Iden)] +#[iden = "upstream_oauth_links"] +pub enum UpstreamOAuthLinks { + Table, + #[iden = "upstream_oauth_link_id"] + UpstreamOAuthLinkId, + #[iden = "upstream_oauth_provider_id"] + UpstreamOAuthProviderId, + UserId, + Subject, + CreatedAt, +} diff --git a/crates/storage-pg/src/oauth2/session.rs b/crates/storage-pg/src/oauth2/session.rs index 825008550..e1f59835b 100644 --- a/crates/storage-pg/src/oauth2/session.rs +++ b/crates/storage-pg/src/oauth2/session.rs @@ -21,7 +21,7 @@ use mas_storage::{ }; use oauth2_types::scope::Scope; use rand::RngCore; -use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query}; +use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query}; use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; @@ -288,7 +288,7 @@ impl<'c> OAuth2SessionRepository for PgOAuth2SessionRepository<'c> { } })) .generate_pagination( - (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId).into_column_ref(), + (OAuth2Sessions::Table, OAuth2Sessions::OAuth2SessionId), pagination, ) .build(PostgresQueryBuilder); diff --git a/crates/storage-pg/src/pagination.rs b/crates/storage-pg/src/pagination.rs index 06c8aff62..1c077af59 100644 --- a/crates/storage-pg/src/pagination.rs +++ b/crates/storage-pg/src/pagination.rs @@ -15,75 +15,29 @@ //! Utilities to manage paginated queries. use mas_storage::{pagination::PaginationDirection, Pagination}; -use sqlx::{Database, QueryBuilder}; +use sea_query::IntoColumnRef; use uuid::Uuid; /// An extension trait to the `sqlx` [`QueryBuilder`], to help adding pagination /// to a query pub trait QueryBuilderExt { - type Iden; - /// Add cursor-based pagination to a query, as used in paginated GraphQL /// connections - fn generate_pagination(&mut self, id_field: Self::Iden, pagination: Pagination) -> &mut Self; -} - -impl<'a, DB> QueryBuilderExt for QueryBuilder<'a, DB> -where - DB: Database, - Uuid: sqlx::Type + sqlx::Encode<'a, DB>, - i64: sqlx::Type + sqlx::Encode<'a, DB>, -{ - type Iden = &'static str; - - fn generate_pagination(&mut self, id_field: &'static str, pagination: Pagination) -> &mut Self { - // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 - // 1. Start from the greedy query: SELECT * FROM table - - // 2. If the after argument is provided, add `id > parsed_cursor` to the `WHERE` - // clause - if let Some(after) = pagination.after { - self.push(" AND ") - .push(id_field) - .push(" > ") - .push_bind(Uuid::from(after)); - } - - // 3. If the before argument is provided, add `id < parsed_cursor` to the - // `WHERE` clause - if let Some(before) = pagination.before { - self.push(" AND ") - .push(id_field) - .push(" < ") - .push_bind(Uuid::from(before)); - } - - match pagination.direction { - // 4. If the first argument is provided, add `ORDER BY id ASC LIMIT first+1` to the - // query - PaginationDirection::Forward => { - self.push(" ORDER BY ") - .push(id_field) - .push(" ASC LIMIT ") - .push_bind((pagination.count + 1) as i64); - } - // 5. If the first argument is provided, add `ORDER BY id DESC LIMIT last+1` to the - // query - PaginationDirection::Backward => { - self.push(" ORDER BY ") - .push(id_field) - .push(" DESC LIMIT ") - .push_bind((pagination.count + 1) as i64); - } - }; - - self - } + fn generate_pagination( + &mut self, + column: C, + pagination: Pagination, + ) -> &mut Self; } impl QueryBuilderExt for sea_query::SelectStatement { - type Iden = sea_query::ColumnRef; - fn generate_pagination(&mut self, id_field: Self::Iden, pagination: Pagination) -> &mut Self { + fn generate_pagination( + &mut self, + column: C, + pagination: Pagination, + ) -> &mut Self { + let id_field = column.into_column_ref(); + // ref: https://github.com/graphql/graphql-relay-js/issues/94#issuecomment-232410564 // 1. Start from the greedy query: SELECT * FROM table diff --git a/crates/storage-pg/src/upstream_oauth2/link.rs b/crates/storage-pg/src/upstream_oauth2/link.rs index b1ba31213..a0f6a2f8e 100644 --- a/crates/storage-pg/src/upstream_oauth2/link.rs +++ b/crates/storage-pg/src/upstream_oauth2/link.rs @@ -15,13 +15,20 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthLink, UpstreamOAuthProvider, User}; -use mas_storage::{upstream_oauth2::UpstreamOAuthLinkRepository, Clock, Page, Pagination}; +use mas_storage::{ + upstream_oauth2::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository}, + Clock, Page, Pagination, +}; use rand::RngCore; -use sqlx::{PgConnection, QueryBuilder}; +use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query}; +use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; -use crate::{pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError}; +use crate::{ + iden::UpstreamOAuthLinks, pagination::QueryBuilderExt, sea_query_sqlx::map_values, + tracing::ExecuteExt, DatabaseError, +}; /// An implementation of [`UpstreamOAuthLinkRepository`] for a PostgreSQL /// connection @@ -38,6 +45,7 @@ impl<'c> PgUpstreamOAuthLinkRepository<'c> { } #[derive(sqlx::FromRow)] +#[enum_def] struct LinkLookup { upstream_oauth_link_id: Uuid, upstream_oauth_provider_id: Uuid, @@ -221,44 +229,118 @@ impl<'c> UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'c> { } #[tracing::instrument( - name = "db.upstream_oauth_link.list_paginated", + name = "db.upstream_oauth_link.list", skip_all, fields( db.statement, - %user.id, - %user.username, ), - err + err, )] - async fn list_paginated( + async fn list( &mut self, - user: &User, + filter: UpstreamOAuthLinkFilter<'_>, pagination: Pagination, - ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_link_id, - upstream_oauth_provider_id, - user_id, - subject, - created_at - FROM upstream_oauth_links - "#, - ); + ) -> Result, DatabaseError> { + let (sql, values) = Query::select() + .expr_as( + Expr::col(( + UpstreamOAuthLinks::Table, + UpstreamOAuthLinks::UpstreamOAuthLinkId, + )), + LinkLookupIden::UpstreamOauthLinkId, + ) + .expr_as( + Expr::col(( + UpstreamOAuthLinks::Table, + UpstreamOAuthLinks::UpstreamOAuthProviderId, + )), + LinkLookupIden::UpstreamOauthProviderId, + ) + .expr_as( + Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)), + LinkLookupIden::UserId, + ) + .expr_as( + Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::Subject)), + LinkLookupIden::Subject, + ) + .expr_as( + Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::CreatedAt)), + LinkLookupIden::CreatedAt, + ) + .from(UpstreamOAuthLinks::Table) + .and_where_option(filter.user().map(|user| { + Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)) + .eq(Uuid::from(user.id)) + })) + .and_where_option(filter.provider().map(|provider| { + Expr::col(( + UpstreamOAuthLinks::Table, + UpstreamOAuthLinks::UpstreamOAuthProviderId, + )) + .eq(Uuid::from(provider.id)) + })) + .generate_pagination( + ( + UpstreamOAuthLinks::Table, + UpstreamOAuthLinks::UpstreamOAuthLinkId, + ), + pagination, + ) + .build(PostgresQueryBuilder); - query - .push(" WHERE user_id = ") - .push_bind(Uuid::from(user.id)) - .generate_pagination("upstream_oauth_link_id", pagination); + let arguments = map_values(values); - let edges: Vec = query - .build_query_as() + let edges: Vec = sqlx::query_as_with(&sql, arguments) .traced() .fetch_all(&mut *self.conn) .await?; let page = pagination.process(edges).map(UpstreamOAuthLink::from); + Ok(page) } + + #[tracing::instrument( + name = "db.upstream_oauth_link.count", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result { + let (sql, values) = Query::select() + .expr( + Expr::col(( + UpstreamOAuthLinks::Table, + UpstreamOAuthLinks::UpstreamOAuthLinkId, + )) + .count(), + ) + .from(UpstreamOAuthLinks::Table) + .and_where_option(filter.user().map(|user| { + Expr::col((UpstreamOAuthLinks::Table, UpstreamOAuthLinks::UserId)) + .eq(Uuid::from(user.id)) + })) + .and_where_option(filter.provider().map(|provider| { + Expr::col(( + UpstreamOAuthLinks::Table, + UpstreamOAuthLinks::UpstreamOAuthProviderId, + )) + .eq(Uuid::from(provider.id)) + })) + .build(PostgresQueryBuilder); + + let arguments = map_values(values); + + let count: i64 = sqlx::query_scalar_with(&sql, arguments) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + count + .try_into() + .map_err(DatabaseError::to_invalid_operation) + } } diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index 80b084d0b..02455684b 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -31,7 +31,7 @@ mod tests { use mas_storage::{ clock::MockClock, upstream_oauth2::{ - UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter, + UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, user::UserRepository, @@ -177,9 +177,14 @@ mod tests { .await .unwrap(); + // XXX: we should also try other combinations of the filter + let filter = UpstreamOAuthLinkFilter::new() + .for_user(&user) + .for_provider(&provider); + let links = repo .upstream_oauth_link() - .list_paginated(&user, Pagination::first(10)) + .list(filter, Pagination::first(10)) .await .unwrap(); assert!(!links.has_previous_page); @@ -188,6 +193,8 @@ mod tests { assert_eq!(links.edges[0].id, link.id); assert_eq!(links.edges[0].user_id, Some(user.id)); + assert_eq!(repo.upstream_oauth_link().count(filter).await.unwrap(), 1); + // Try deleting the provider repo.upstream_oauth_provider() .delete(provider) diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index 0954cadfa..2bc713d4b 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -22,7 +22,7 @@ use mas_storage::{ }; use oauth2_types::scope::Scope; use rand::RngCore; -use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query}; +use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query}; use sqlx::{types::Json, PgConnection}; use tracing::{info_span, Instrument}; use ulid::Ulid; @@ -439,8 +439,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' ( UpstreamOAuthProviders::Table, UpstreamOAuthProviders::UpstreamOAuthProviderId, - ) - .into_column_ref(), + ), pagination, ) .build(PostgresQueryBuilder); diff --git a/crates/storage-pg/src/user/email.rs b/crates/storage-pg/src/user/email.rs index a9a363147..83106aaf5 100644 --- a/crates/storage-pg/src/user/email.rs +++ b/crates/storage-pg/src/user/email.rs @@ -20,7 +20,7 @@ use mas_storage::{ Clock, Page, Pagination, }; use rand::RngCore; -use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query}; +use sea_query::{enum_def, Expr, PostgresQueryBuilder, Query}; use sqlx::PgConnection; use tracing::{info_span, Instrument}; use ulid::Ulid; @@ -275,10 +275,7 @@ impl<'c> UserEmailRepository for PgUserEmailRepository<'c> { Expr::col((UserEmails::Table, UserEmails::ConfirmedAt)).is_null() } })) - .generate_pagination( - (UserEmails::Table, UserEmails::UserEmailId).into_column_ref(), - pagination, - ) + .generate_pagination((UserEmails::Table, UserEmails::UserEmailId), pagination) .build(PostgresQueryBuilder); let arguments = map_values(values); diff --git a/crates/storage-pg/src/user/session.rs b/crates/storage-pg/src/user/session.rs index a71ace316..b6458a933 100644 --- a/crates/storage-pg/src/user/session.rs +++ b/crates/storage-pg/src/user/session.rs @@ -17,7 +17,7 @@ use chrono::{DateTime, Utc}; use mas_data_model::{Authentication, BrowserSession, Password, UpstreamOAuthLink, User}; use mas_storage::{user::BrowserSessionRepository, Clock, Page, Pagination}; use rand::RngCore; -use sea_query::{Expr, IntoColumnRef, PostgresQueryBuilder}; +use sea_query::{Expr, PostgresQueryBuilder}; use sqlx::PgConnection; use ulid::Ulid; use uuid::Uuid; @@ -251,7 +251,7 @@ impl<'c> BrowserSessionRepository for PgBrowserSessionRepository<'c> { } })) .generate_pagination( - (UserSessions::Table, UserSessions::UserSessionId).into_column_ref(), + (UserSessions::Table, UserSessions::UserSessionId), pagination, ) .build(PostgresQueryBuilder); diff --git a/crates/storage/src/upstream_oauth2/link.rs b/crates/storage/src/upstream_oauth2/link.rs index 9b8a4f1cd..21c0f41e3 100644 --- a/crates/storage/src/upstream_oauth2/link.rs +++ b/crates/storage/src/upstream_oauth2/link.rs @@ -19,6 +19,52 @@ use ulid::Ulid; use crate::{pagination::Page, repository_impl, Clock, Pagination}; +/// Filter parameters for listing upstream OAuth links +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] +pub struct UpstreamOAuthLinkFilter<'a> { + // XXX: we might also want to filter for links without a user linked to them + user: Option<&'a User>, + provider: Option<&'a UpstreamOAuthProvider>, +} + +impl<'a> UpstreamOAuthLinkFilter<'a> { + /// Create a new [`UpstreamOAuthLinkFilter`] with default values + #[must_use] + pub fn new() -> Self { + Self::default() + } + + /// Set the user who owns the upstream OAuth links + #[must_use] + pub fn for_user(mut self, user: &'a User) -> Self { + self.user = Some(user); + self + } + + /// Get the user filter + /// + /// Returns [`None`] if no filter was set + #[must_use] + pub fn user(&self) -> Option<&User> { + self.user + } + + /// Set the upstream OAuth provider for which to list links + #[must_use] + pub fn for_provider(mut self, provider: &'a UpstreamOAuthProvider) -> Self { + self.provider = Some(provider); + self + } + + /// Get the upstream OAuth provider filter + /// + /// Returns [`None`] if no filter was set + #[must_use] + pub fn provider(&self) -> Option<&UpstreamOAuthProvider> { + self.provider + } +} + /// An [`UpstreamOAuthLinkRepository`] helps interacting with /// [`UpstreamOAuthLink`] with the storage backend #[async_trait] @@ -109,11 +155,42 @@ pub trait UpstreamOAuthLinkRepository: Send + Sync { /// # Errors /// /// Returns [`Self::Error`] if the underlying repository fails + #[deprecated(note = "Use `list` instead")] async fn list_paginated( &mut self, user: &User, pagination: Pagination, + ) -> Result, Self::Error> { + self.list(UpstreamOAuthLinkFilter::new().for_user(user), pagination) + .await + } + + /// List [`UpstreamOAuthLink`] with the given filter and pagination + /// + /// # Parameters + /// + /// * `filter`: The filter to apply + /// * `pagination`: The pagination parameters + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn list( + &mut self, + filter: UpstreamOAuthLinkFilter<'_>, + pagination: Pagination, ) -> Result, Self::Error>; + + /// Count the number of [`UpstreamOAuthLink`] with the given filter + /// + /// # Parameters + /// + /// * `filter`: The filter to apply + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result; } repository_impl!(UpstreamOAuthLinkRepository: @@ -139,9 +216,11 @@ repository_impl!(UpstreamOAuthLinkRepository: user: &User, ) -> Result<(), Self::Error>; - async fn list_paginated( + async fn list( &mut self, - user: &User, + filter: UpstreamOAuthLinkFilter<'_>, pagination: Pagination, ) -> Result, Self::Error>; + + async fn count(&mut self, filter: UpstreamOAuthLinkFilter<'_>) -> Result; ); diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 6cf65945b..0d9cf10bf 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -20,7 +20,7 @@ mod provider; mod session; pub use self::{ - link::UpstreamOAuthLinkRepository, + link::{UpstreamOAuthLinkFilter, UpstreamOAuthLinkRepository}, provider::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository}, session::UpstreamOAuthSessionRepository, };