diff --git a/crates/graphql/src/model/mod.rs b/crates/graphql/src/model/mod.rs index 0d1bfd161..9e01aaf45 100644 --- a/crates/graphql/src/model/mod.rs +++ b/crates/graphql/src/model/mod.rs @@ -12,7 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -use async_graphql::Interface; +use async_graphql::{Interface, Object}; use chrono::{DateTime, Utc}; mod browser_sessions; @@ -52,3 +52,14 @@ pub enum CreationEvent { UpstreamOAuth2Link(Box), OAuth2Session(Box), } + +pub struct PreloadedTotalCount(pub Option); + +#[Object] +impl PreloadedTotalCount { + /// Identifies the total count of items in the connection. + async fn total_count(&self) -> Result { + self.0 + .ok_or_else(|| async_graphql::Error::new("total count not preloaded")) + } +} diff --git a/crates/graphql/src/model/users.rs b/crates/graphql/src/model/users.rs index 796844205..45b2c922c 100644 --- a/crates/graphql/src/model/users.rs +++ b/crates/graphql/src/model/users.rs @@ -26,19 +26,14 @@ use mas_storage::{ }; use super::{ - compat_sessions::CompatSsoLogin, BrowserSession, Cursor, NodeCursor, NodeType, OAuth2Session, - UpstreamOAuth2Link, -}; -use crate::{ - model::{ - browser_sessions::BrowserSessionState, - compat_sessions::{CompatSessionState, CompatSessionType}, - matrix::MatrixUser, - oauth::OAuth2SessionState, - CompatSession, - }, - state::ContextExt, + browser_sessions::BrowserSessionState, + compat_sessions::{CompatSessionState, CompatSessionType, CompatSsoLogin}, + matrix::MatrixUser, + oauth::OAuth2SessionState, + BrowserSession, CompatSession, Cursor, NodeCursor, NodeType, OAuth2Session, + PreloadedTotalCount, UpstreamOAuth2Link, }; +use crate::state::ContextExt; #[derive(Description)] /// A user is an individual's account. @@ -511,17 +506,6 @@ impl User { } } -pub struct PreloadedTotalCount(Option); - -#[Object] -impl PreloadedTotalCount { - /// Identifies the total count of items in the connection. - async fn total_count(&self) -> Result { - self.0 - .ok_or_else(|| async_graphql::Error::new("total count not preloaded")) - } -} - /// A user email address #[derive(Description)] pub struct UserEmail(pub mas_data_model::UserEmail); diff --git a/crates/graphql/src/query/upstream_oauth.rs b/crates/graphql/src/query/upstream_oauth.rs index ed9c5dd0c..3ecc347d9 100644 --- a/crates/graphql/src/query/upstream_oauth.rs +++ b/crates/graphql/src/query/upstream_oauth.rs @@ -16,10 +16,13 @@ use async_graphql::{ connection::{query, Connection, Edge, OpaqueCursor}, Context, Object, ID, }; -use mas_storage::Pagination; +use mas_storage::{upstream_oauth2::UpstreamOAuthProviderFilter, Pagination, RepositoryAccess}; use crate::{ - model::{Cursor, NodeCursor, NodeType, UpstreamOAuth2Link, UpstreamOAuth2Provider}, + model::{ + Cursor, NodeCursor, NodeType, PreloadedTotalCount, UpstreamOAuth2Link, + UpstreamOAuth2Provider, + }, state::ContextExt, }; @@ -78,7 +81,8 @@ impl UpstreamOAuthQuery { before: Option, #[graphql(desc = "Returns the first *n* elements from the list.")] first: Option, #[graphql(desc = "Returns the last *n* elements from the list.")] last: Option, - ) -> Result, async_graphql::Error> { + ) -> Result, async_graphql::Error> + { let state = ctx.state(); let mut repo = state.repository().await?; @@ -100,14 +104,27 @@ impl UpstreamOAuthQuery { .transpose()?; let pagination = Pagination::try_new(before_id, after_id, first, last)?; + let filter = UpstreamOAuthProviderFilter::new(); + let page = repo .upstream_oauth_provider() - .list_paginated(pagination) + .list(filter, pagination) .await?; + // Preload the total count if requested + let count = if ctx.look_ahead().field("totalCount").exists() { + Some(repo.upstream_oauth_provider().count(filter).await?) + } else { + None + }; + repo.cancel().await?; - let mut connection = Connection::new(page.has_previous_page, page.has_next_page); + let mut connection = Connection::with_additional_fields( + page.has_previous_page, + page.has_next_page, + PreloadedTotalCount(count), + ); connection.edges.extend(page.edges.into_iter().map(|p| { Edge::new( OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)), diff --git a/crates/storage-pg/.sqlx/query-d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c.json b/crates/storage-pg/.sqlx/query-d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c.json deleted file mode 100644 index 5a447ffe6..000000000 --- a/crates/storage-pg/.sqlx/query-d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c.json +++ /dev/null @@ -1,22 +0,0 @@ -{ - "db_name": "PostgreSQL", - "query": "\n SELECT COUNT(*)\n FROM user_emails\n WHERE user_id = $1\n ", - "describe": { - "columns": [ - { - "ordinal": 0, - "name": "count", - "type_info": "Int8" - } - ], - "parameters": { - "Left": [ - "Uuid" - ] - }, - "nullable": [ - null - ] - }, - "hash": "d12a513b81b3ef658eae1f0a719933323f28c6ee260b52cafe337dd3d19e865c" -} diff --git a/crates/storage-pg/src/iden.rs b/crates/storage-pg/src/iden.rs index 06b1ab877..42856d2b8 100644 --- a/crates/storage-pg/src/iden.rs +++ b/crates/storage-pg/src/iden.rs @@ -77,3 +77,19 @@ pub enum OAuth2Sessions { CreatedAt, FinishedAt, } + +#[derive(sea_query::Iden)] +#[iden = "upstream_oauth_providers"] +pub enum UpstreamOAuthProviders { + Table, + #[iden = "upstream_oauth_provider_id"] + UpstreamOAuthProviderId, + Issuer, + Scope, + ClientId, + EncryptedClientSecret, + TokenEndpointSigningAlg, + TokenEndpointAuthMethod, + CreatedAt, + ClaimsImports, +} diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index 48ecab7c9..80b084d0b 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -31,8 +31,8 @@ mod tests { use mas_storage::{ clock::MockClock, upstream_oauth2::{ - UpstreamOAuthLinkRepository, UpstreamOAuthProviderRepository, - UpstreamOAuthSessionRepository, + UpstreamOAuthLinkRepository, UpstreamOAuthProviderFilter, + UpstreamOAuthProviderRepository, UpstreamOAuthSessionRepository, }, user::UserRepository, Pagination, RepositoryAccess, @@ -208,6 +208,14 @@ mod tests { let clock = MockClock::default(); let mut repo = PgRepository::from_pool(&pool).await.unwrap(); + let filter = UpstreamOAuthProviderFilter::new(); + + // Count the number of providers before we start + assert_eq!( + repo.upstream_oauth_provider().count(filter).await.unwrap(), + 0 + ); + let mut ids = Vec::with_capacity(20); // Create 20 providers for idx in 0..20 { @@ -231,10 +239,16 @@ mod tests { clock.advance(Duration::seconds(10)); } + // Now we have 20 providers + assert_eq!( + repo.upstream_oauth_provider().count(filter).await.unwrap(), + 20 + ); + // Lookup the first 10 items let page = repo .upstream_oauth_provider() - .list_paginated(Pagination::first(10)) + .list(filter, Pagination::first(10)) .await .unwrap(); @@ -246,7 +260,7 @@ mod tests { // Lookup the next 10 items let page = repo .upstream_oauth_provider() - .list_paginated(Pagination::first(10).after(ids[9])) + .list(filter, Pagination::first(10).after(ids[9])) .await .unwrap(); @@ -258,7 +272,7 @@ mod tests { // Lookup the last 10 items let page = repo .upstream_oauth_provider() - .list_paginated(Pagination::last(10)) + .list(filter, Pagination::last(10)) .await .unwrap(); @@ -270,7 +284,7 @@ mod tests { // Lookup the previous 10 items let page = repo .upstream_oauth_provider() - .list_paginated(Pagination::last(10).before(ids[10])) + .list(filter, Pagination::last(10).before(ids[10])) .await .unwrap(); @@ -282,7 +296,7 @@ mod tests { // Lookup 10 items between two IDs let page = repo .upstream_oauth_provider() - .list_paginated(Pagination::first(10).after(ids[5]).before(ids[8])) + .list(filter, Pagination::first(10).after(ids[5]).before(ids[8])) .await .unwrap(); diff --git a/crates/storage-pg/src/upstream_oauth2/provider.rs b/crates/storage-pg/src/upstream_oauth2/provider.rs index 508b969b8..0954cadfa 100644 --- a/crates/storage-pg/src/upstream_oauth2/provider.rs +++ b/crates/storage-pg/src/upstream_oauth2/provider.rs @@ -16,16 +16,21 @@ use async_trait::async_trait; use chrono::{DateTime, Utc}; use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; -use mas_storage::{upstream_oauth2::UpstreamOAuthProviderRepository, Clock, Page, Pagination}; +use mas_storage::{ + upstream_oauth2::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository}, + Clock, Page, Pagination, +}; use oauth2_types::scope::Scope; use rand::RngCore; -use sqlx::{types::Json, PgConnection, QueryBuilder}; +use sea_query::{enum_def, Expr, IntoColumnRef, PostgresQueryBuilder, Query}; +use sqlx::{types::Json, PgConnection}; use tracing::{info_span, Instrument}; use ulid::Ulid; use uuid::Uuid; use crate::{ - pagination::QueryBuilderExt, tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, + iden::UpstreamOAuthProviders, pagination::QueryBuilderExt, sea_query_sqlx::map_values, + tracing::ExecuteExt, DatabaseError, DatabaseInconsistencyError, }; /// An implementation of [`UpstreamOAuthProviderRepository`] for a PostgreSQL @@ -43,6 +48,7 @@ impl<'c> PgUpstreamOAuthProviderRepository<'c> { } #[derive(sqlx::FromRow)] +#[enum_def] struct ProviderLookup { upstream_oauth_provider_id: Uuid, issuer: String, @@ -209,6 +215,72 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' }) } + #[tracing::instrument( + name = "db.upstream_oauth_provider.delete_by_id", + skip_all, + fields( + db.statement, + upstream_oauth_provider.id = %id, + ), + err, + )] + async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> { + // Delete the authorization sessions first, as they have a foreign key + // constraint on the links and the providers. + { + let span = info_span!( + "db.oauth2_client.delete_by_id.authorization_sessions", + upstream_oauth_provider.id = %id, + db.statement = tracing::field::Empty, + ); + sqlx::query!( + r#" + DELETE FROM upstream_oauth_authorization_sessions + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + // Delete the links next, as they have a foreign key constraint on the + // providers. + { + let span = info_span!( + "db.oauth2_client.delete_by_id.links", + upstream_oauth_provider.id = %id, + db.statement = tracing::field::Empty, + ); + sqlx::query!( + r#" + DELETE FROM upstream_oauth_links + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .record(&span) + .execute(&mut *self.conn) + .instrument(span) + .await?; + } + + let res = sqlx::query!( + r#" + DELETE FROM upstream_oauth_providers + WHERE upstream_oauth_provider_id = $1 + "#, + Uuid::from(id), + ) + .traced() + .execute(&mut *self.conn) + .await?; + + DatabaseError::ensure_affected_rows(&res, 1) + } + #[tracing::instrument( name = "db.upstream_oauth_provider.add", skip_all, @@ -288,110 +360,139 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<' } #[tracing::instrument( - name = "db.upstream_oauth_provider.delete_by_id", - skip_all, - fields( - db.statement, - upstream_oauth_provider.id = %id, - ), - err, - )] - async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error> { - // Delete the authorization sessions first, as they have a foreign key - // constraint on the links and the providers. - { - let span = info_span!( - "db.oauth2_client.delete_by_id.authorization_sessions", - upstream_oauth_provider.id = %id, - db.statement = tracing::field::Empty, - ); - sqlx::query!( - r#" - DELETE FROM upstream_oauth_authorization_sessions - WHERE upstream_oauth_provider_id = $1 - "#, - Uuid::from(id), - ) - .record(&span) - .execute(&mut *self.conn) - .instrument(span) - .await?; - } - - // Delete the links next, as they have a foreign key constraint on the - // providers. - { - let span = info_span!( - "db.oauth2_client.delete_by_id.links", - upstream_oauth_provider.id = %id, - db.statement = tracing::field::Empty, - ); - sqlx::query!( - r#" - DELETE FROM upstream_oauth_links - WHERE upstream_oauth_provider_id = $1 - "#, - Uuid::from(id), - ) - .record(&span) - .execute(&mut *self.conn) - .instrument(span) - .await?; - } - - let res = sqlx::query!( - r#" - DELETE FROM upstream_oauth_providers - WHERE upstream_oauth_provider_id = $1 - "#, - Uuid::from(id), - ) - .traced() - .execute(&mut *self.conn) - .await?; - - DatabaseError::ensure_affected_rows(&res, 1) - } - - #[tracing::instrument( - name = "db.upstream_oauth_provider.list_paginated", + name = "db.upstream_oauth_provider.list", skip_all, fields( db.statement, ), err, )] - async fn list_paginated( + async fn list( &mut self, + _filter: UpstreamOAuthProviderFilter<'_>, pagination: Pagination, ) -> Result, Self::Error> { - let mut query = QueryBuilder::new( - r#" - SELECT - upstream_oauth_provider_id, - issuer, - scope, - client_id, - encrypted_client_secret, - token_endpoint_signing_alg, - token_endpoint_auth_method, - created_at, - claims_imports - FROM upstream_oauth_providers - WHERE 1 = 1 - "#, - ); + // XXX: the filter is currently ignored, as it does not have any fields + let (sql, values) = Query::select() + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::UpstreamOAuthProviderId, + )), + ProviderLookupIden::UpstreamOauthProviderId, + ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::Issuer, + )), + ProviderLookupIden::Issuer, + ) + .expr_as( + Expr::col((UpstreamOAuthProviders::Table, UpstreamOAuthProviders::Scope)), + ProviderLookupIden::Scope, + ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::ClientId, + )), + ProviderLookupIden::ClientId, + ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::EncryptedClientSecret, + )), + ProviderLookupIden::EncryptedClientSecret, + ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::TokenEndpointSigningAlg, + )), + ProviderLookupIden::TokenEndpointSigningAlg, + ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::TokenEndpointAuthMethod, + )), + ProviderLookupIden::TokenEndpointAuthMethod, + ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::CreatedAt, + )), + ProviderLookupIden::CreatedAt, + ) + .expr_as( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::ClaimsImports, + )), + ProviderLookupIden::ClaimsImports, + ) + .from(UpstreamOAuthProviders::Table) + .generate_pagination( + ( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::UpstreamOAuthProviderId, + ) + .into_column_ref(), + pagination, + ) + .build(PostgresQueryBuilder); - query.generate_pagination("upstream_oauth_provider_id", pagination); + let arguments = map_values(values); - let edges: Vec = query - .build_query_as() + let edges: Vec = sqlx::query_as_with(&sql, arguments) .traced() .fetch_all(&mut *self.conn) .await?; - let page = pagination.process(edges).try_map(TryInto::try_into)?; - Ok(page) + let page = pagination + .process(edges) + .try_map(UpstreamOAuthProvider::try_from)?; + + return Ok(page); + } + + #[tracing::instrument( + name = "db.upstream_oauth_provider.count", + skip_all, + fields( + db.statement, + ), + err, + )] + async fn count( + &mut self, + _filter: UpstreamOAuthProviderFilter<'_>, + ) -> Result { + // XXX: the filter is currently ignored, as it does not have any fields + let (sql, values) = Query::select() + .expr( + Expr::col(( + UpstreamOAuthProviders::Table, + UpstreamOAuthProviders::UpstreamOAuthProviderId, + )) + .count(), + ) + .from(UpstreamOAuthProviders::Table) + .build(PostgresQueryBuilder); + + let arguments = map_values(values); + + let count: i64 = sqlx::query_scalar_with(&sql, arguments) + .traced() + .fetch_one(&mut *self.conn) + .await?; + + count + .try_into() + .map_err(DatabaseError::to_invalid_operation) } #[tracing::instrument( diff --git a/crates/storage/src/upstream_oauth2/mod.rs b/crates/storage/src/upstream_oauth2/mod.rs index 252217527..6cf65945b 100644 --- a/crates/storage/src/upstream_oauth2/mod.rs +++ b/crates/storage/src/upstream_oauth2/mod.rs @@ -20,6 +20,7 @@ mod provider; mod session; pub use self::{ - link::UpstreamOAuthLinkRepository, provider::UpstreamOAuthProviderRepository, + link::UpstreamOAuthLinkRepository, + provider::{UpstreamOAuthProviderFilter, UpstreamOAuthProviderRepository}, session::UpstreamOAuthSessionRepository, }; diff --git a/crates/storage/src/upstream_oauth2/provider.rs b/crates/storage/src/upstream_oauth2/provider.rs index 9624a40a2..ff4f5850b 100644 --- a/crates/storage/src/upstream_oauth2/provider.rs +++ b/crates/storage/src/upstream_oauth2/provider.rs @@ -12,6 +12,8 @@ // See the License for the specific language governing permissions and // limitations under the License. +use std::marker::PhantomData; + use async_trait::async_trait; use mas_data_model::{UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports}; use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod}; @@ -21,6 +23,20 @@ use ulid::Ulid; use crate::{pagination::Page, repository_impl, Clock, Pagination}; +/// Filter parameters for listing upstream OAuth 2.0 providers +#[derive(Clone, Copy, Debug, PartialEq, Eq, Default)] +pub struct UpstreamOAuthProviderFilter<'a> { + _lifetime: PhantomData<&'a ()>, +} + +impl<'a> UpstreamOAuthProviderFilter<'a> { + /// Create a new [`UpstreamOAuthProviderFilter`] with default values + #[must_use] + pub fn new() -> Self { + Self::default() + } +} + /// An [`UpstreamOAuthProviderRepository`] helps interacting with /// [`UpstreamOAuthProvider`] saved in the storage backend #[async_trait] @@ -137,20 +153,36 @@ pub trait UpstreamOAuthProviderRepository: Send + Sync { claims_imports: UpstreamOAuthProviderClaimsImports, ) -> Result; - /// Get a paginated list of upstream OAuth providers + /// List [`UpstreamOAuthProvider`] with the given filter and pagination /// /// # Parameters /// + /// * `filter`: The filter to apply /// * `pagination`: The pagination parameters /// /// # Errors /// /// Returns [`Self::Error`] if the underlying repository fails - async fn list_paginated( + async fn list( &mut self, + filter: UpstreamOAuthProviderFilter<'_>, pagination: Pagination, ) -> Result, Self::Error>; + /// Count the number of [`UpstreamOAuthProvider`] with the given filter + /// + /// # Parameters + /// + /// * `filter`: The filter to apply + /// + /// # Errors + /// + /// Returns [`Self::Error`] if the underlying repository fails + async fn count( + &mut self, + filter: UpstreamOAuthProviderFilter<'_>, + ) -> Result; + /// Get all upstream OAuth providers /// /// # Errors @@ -192,10 +224,16 @@ repository_impl!(UpstreamOAuthProviderRepository: async fn delete_by_id(&mut self, id: Ulid) -> Result<(), Self::Error>; - async fn list_paginated( + async fn list( &mut self, + filter: UpstreamOAuthProviderFilter<'_>, pagination: Pagination ) -> Result, Self::Error>; + async fn count( + &mut self, + filter: UpstreamOAuthProviderFilter<'_> + ) -> Result; + async fn all(&mut self) -> Result, Self::Error>; ); diff --git a/frontend/schema.graphql b/frontend/schema.graphql index 1a8a0dc82..6b16fe6ba 100644 --- a/frontend/schema.graphql +++ b/frontend/schema.graphql @@ -880,6 +880,10 @@ type UpstreamOAuth2ProviderConnection { A list of nodes. """ nodes: [UpstreamOAuth2Provider!]! + """ + Identifies the total count of items in the connection. + """ + totalCount: Int! } """ diff --git a/frontend/src/gql/graphql.ts b/frontend/src/gql/graphql.ts index 1c5fd8bad..ffc3d4eb4 100644 --- a/frontend/src/gql/graphql.ts +++ b/frontend/src/gql/graphql.ts @@ -668,6 +668,8 @@ export type UpstreamOAuth2ProviderConnection = { nodes: Array; /** Information to aid in pagination. */ pageInfo: PageInfo; + /** Identifies the total count of items in the connection. */ + totalCount: Scalars["Int"]["output"]; }; /** An edge in a connection. */ diff --git a/frontend/src/gql/schema.ts b/frontend/src/gql/schema.ts index 8f2b80a51..678c5780a 100644 --- a/frontend/src/gql/schema.ts +++ b/frontend/src/gql/schema.ts @@ -1938,6 +1938,17 @@ export default { }, args: [], }, + { + name: "totalCount", + type: { + kind: "NON_NULL", + ofType: { + kind: "SCALAR", + name: "Any", + }, + }, + args: [], + }, ], interfaces: [], },