diff --git a/Cargo.lock b/Cargo.lock index b82a50f0e..c508ab8bf 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -612,7 +612,6 @@ checksum = "08b108ad2665fa3f6e6a517c3d80ec3e77d224c47d605167aefaa5d7ef97fa48" dependencies = [ "async-trait", "axum-core", - "base64", "bitflags", "bytes 1.3.0", "futures-util", @@ -631,10 +630,8 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", - "sha-1", "sync_wrapper", "tokio", - "tokio-tungstenite", "tower", "tower-http", "tower-layer", @@ -2811,14 +2808,12 @@ dependencies = [ "anyhow", "async-graphql", "chrono", - "mas-axum-utils", "mas-data-model", "mas-storage", "oauth2-types", "serde", "sqlx", "thiserror", - "tokio", "tracing", "ulid", "url", @@ -4753,17 +4748,6 @@ dependencies = [ "unsafe-libyaml", ] -[[package]] -name = "sha-1" -version = "0.10.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c" -dependencies = [ - "cfg-if", - "cpufeatures", - "digest", -] - [[package]] name = "sha1" version = "0.10.5" @@ -5327,18 +5311,6 @@ dependencies = [ "tokio-stream", ] -[[package]] -name = "tokio-tungstenite" -version = "0.17.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181" -dependencies = [ - "futures-util", - "log", - "tokio", - "tungstenite", -] - [[package]] name = "tokio-util" version = "0.6.10" @@ -5598,25 +5570,6 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" -[[package]] -name = "tungstenite" -version = "0.17.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0" -dependencies = [ - "base64", - "byteorder", - "bytes 1.3.0", - "http", - "httparse", - "log", - "rand 0.8.5", - "sha-1", - "thiserror", - "url", - "utf-8", -] - [[package]] name = "typed-builder" version = "0.9.1" @@ -5798,12 +5751,6 @@ version = "2.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" -[[package]] -name = "utf-8" -version = "0.7.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" - [[package]] name = "uuid" version = "1.2.2" diff --git a/crates/graphql/Cargo.toml b/crates/graphql/Cargo.toml index a89364ba2..59b2a8fc1 100644 --- a/crates/graphql/Cargo.toml +++ b/crates/graphql/Cargo.toml @@ -12,19 +12,13 @@ chrono = "0.4.23" serde = { version = "1.0.150", features = ["derive"] } sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } thiserror = "1.0.37" -tokio = { version = "1.23.0", features = ["time"] } tracing = "0.1.37" ulid = "1.0.0" url = "2.3.1" oauth2-types = { path = "../oauth2-types" } -mas-axum-utils = { path = "../axum-utils" } mas-data-model = { path = "../data-model" } mas-storage = { path = "../storage" } -[features] -native-roots = ["mas-axum-utils/native-roots"] -webpki-roots = ["mas-axum-utils/webpki-roots"] - [[bin]] name = "schema" diff --git a/crates/graphql/src/lib.rs b/crates/graphql/src/lib.rs index e04ae7632..1e691a96e 100644 --- a/crates/graphql/src/lib.rs +++ b/crates/graphql/src/lib.rs @@ -20,13 +20,17 @@ clippy::future_not_send )] #![warn(clippy::pedantic)] -#![allow(clippy::module_name_repetitions, clippy::missing_errors_doc)] +#![allow( + clippy::module_name_repetitions, + clippy::missing_errors_doc, + clippy::unused_async +)] use async_graphql::{ connection::{query, Connection, Edge, OpaqueCursor}, Context, Description, EmptyMutation, EmptySubscription, ID, }; -use mas_axum_utils::SessionInfo; +use model::CreationEvent; use sqlx::PgPool; use self::model::{ @@ -43,8 +47,7 @@ pub type SchemaBuilder = async_graphql::SchemaBuilder SchemaBuilder { async_graphql::Schema::build(RootQuery::new(), EmptyMutation, EmptySubscription) .register_output_type::() - // TODO: ordering of interface implementations is not stable - //.register_output_type::() + .register_output_type::() } /// The query root of the GraphQL interface. @@ -67,21 +70,13 @@ impl RootQuery { &self, ctx: &Context<'_>, ) -> Result, async_graphql::Error> { - let database = ctx.data::()?; - let session_info = ctx.data::()?; - let mut conn = database.acquire().await?; - let session = session_info.load_session(&mut conn).await?; - + let session = ctx.data_opt::().cloned(); Ok(session.map(BrowserSession::from)) } /// Get the current logged in user async fn current_user(&self, ctx: &Context<'_>) -> Result, async_graphql::Error> { - let database = ctx.data::()?; - let session_info = ctx.data::()?; - let mut conn = database.acquire().await?; - let session = session_info.load_session(&mut conn).await?; - + let session = ctx.data_opt::().cloned(); Ok(session.map(User::from)) } @@ -103,10 +98,7 @@ impl RootQuery { /// Fetch a user by its ID. async fn user(&self, ctx: &Context<'_>, id: ID) -> Result, async_graphql::Error> { let id = NodeType::User.extract_ulid(&id)?; - let database = ctx.data::()?; - let session_info = ctx.data::()?; - let mut conn = database.acquire().await?; - let session = session_info.load_session(&mut conn).await?; + let session = ctx.data_opt::().cloned(); let Some(session) = session else { return Ok(None) }; let current_user = session.user; @@ -125,10 +117,9 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::BrowserSession.extract_ulid(&id)?; + let session = ctx.data_opt::().cloned(); let database = ctx.data::()?; - let session_info = ctx.data::()?; let mut conn = database.acquire().await?; - let session = session_info.load_session(&mut conn).await?; let Some(session) = session else { return Ok(None) }; let current_user = session.user; @@ -153,10 +144,9 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::UserEmail.extract_ulid(&id)?; + let session = ctx.data_opt::().cloned(); let database = ctx.data::()?; - let session_info = ctx.data::()?; let mut conn = database.acquire().await?; - let session = session_info.load_session(&mut conn).await?; let Some(session) = session else { return Ok(None) }; let current_user = session.user; @@ -174,10 +164,9 @@ impl RootQuery { id: ID, ) -> Result, async_graphql::Error> { let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; + let session = ctx.data_opt::().cloned(); let database = ctx.data::()?; - let session_info = ctx.data::()?; let mut conn = database.acquire().await?; - let session = session_info.load_session(&mut conn).await?; let Some(session) = session else { return Ok(None) }; let current_user = session.user; diff --git a/crates/graphql/src/model/mod.rs b/crates/graphql/src/model/mod.rs index 7b923b3c4..f5d018e77 100644 --- a/crates/graphql/src/model/mod.rs +++ b/crates/graphql/src/model/mod.rs @@ -33,6 +33,7 @@ pub use self::{ users::{User, UserEmail}, }; +/// An object with a creation date. #[derive(Interface)] #[graphql(field( name = "created_at", diff --git a/crates/handlers/Cargo.toml b/crates/handlers/Cargo.toml index 4ba1ce7f5..74239e61c 100644 --- a/crates/handlers/Cargo.toml +++ b/crates/handlers/Cargo.toml @@ -21,7 +21,7 @@ anyhow = "1.0.66" hyper = { version = "0.14.23", features = ["full"] } tower = "0.4.13" tower-http = { version = "0.3.5", features = ["cors"] } -axum = { version = "0.6.1", features = ["ws"] } +axum = "0.6.1" axum-macros = "0.3.0" axum-extra = { version = "0.4.2", features = ["cookie-private"] } diff --git a/crates/handlers/src/graphql.rs b/crates/handlers/src/graphql.rs index 4adf1fe4f..ba6919940 100644 --- a/crates/handlers/src/graphql.rs +++ b/crates/handlers/src/graphql.rs @@ -12,28 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{borrow::Cow, str::FromStr}; - use async_graphql::{ extensions::{ApolloTracing, Tracing}, - http::{ - playground_source, GraphQLPlaygroundConfig, MultipartOptions, WebSocketProtocols, - WsMessage, ALL_WEBSOCKET_PROTOCOLS, - }, - Data, + http::{playground_source, GraphQLPlaygroundConfig, MultipartOptions}, }; use axum::{ - extract::{ - ws::{CloseFrame, Message}, - BodyStream, RawQuery, State, WebSocketUpgrade, - }, - response::{Html, IntoResponse, Response}, + extract::{BodyStream, RawQuery, State}, + response::{Html, IntoResponse}, Json, TypedHeader, }; use axum_extra::extract::PrivateCookieJar; -use futures_util::{SinkExt, StreamExt, TryStreamExt}; -use headers::{ContentType, Header, HeaderValue}; -use hyper::header::{CACHE_CONTROL, SEC_WEBSOCKET_PROTOCOL}; +use futures_util::{StreamExt, TryStreamExt}; +use headers::{ContentType, HeaderValue}; +use hyper::header::CACHE_CONTROL; use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_graphql::Schema; use mas_keystore::Encrypter; @@ -67,6 +58,7 @@ fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span { } pub async fn post( + State(pool): State, State(schema): State, cookie_jar: PrivateCookieJar, content_type: Option>, @@ -75,15 +67,19 @@ pub async fn post( let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let (session_info, _cookie_jar) = cookie_jar.session_info(); + let maybe_session = session_info.load_session(&pool).await?; - let request = async_graphql::http::receive_batch_body( + let mut request = async_graphql::http::receive_batch_body( content_type, body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) .into_async_read(), MultipartOptions::default(), ) - .await? // XXX: this should probably return another error response? - .data(session_info); + .await?; // XXX: this should probably return another error response? + + if let Some(session) = maybe_session { + request = request.data(session); + } let response = match request { async_graphql::BatchRequest::Single(request) => { @@ -114,13 +110,19 @@ pub async fn post( } pub async fn get( + State(pool): State, State(schema): State, cookie_jar: PrivateCookieJar, RawQuery(query): RawQuery, ) -> Result { let (session_info, _cookie_jar) = cookie_jar.session_info(); - let request = - async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(session_info); + let maybe_session = session_info.load_session(&pool).await?; + + let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?; + + if let Some(session) = maybe_session { + request = request.data(session); + } let span = span_for_graphql_request(&request); let response = schema.execute(request).instrument(span).await; @@ -136,78 +138,8 @@ pub async fn get( Ok((headers, cache_control, Json(response))) } -pub struct SecWebsocketProtocol(WebSocketProtocols); - -impl Header for SecWebsocketProtocol { - fn name() -> &'static headers::HeaderName { - &SEC_WEBSOCKET_PROTOCOL - } - - fn decode<'i, I>(values: &mut I) -> Result - where - Self: Sized, - I: Iterator, - { - values - .filter_map(|value| value.to_str().ok()) - .flat_map(|value| value.split(',')) - .find_map(|p| WebSocketProtocols::from_str(p.trim()).ok()) - .map(Self) - .ok_or_else(headers::Error::invalid) - } - - fn encode>(&self, values: &mut E) { - if let Ok(v) = HeaderValue::from_str(self.0.sec_websocket_protocol()) { - values.extend(std::iter::once(v)); - } - } -} - -pub async fn ws( - State(schema): State, - cookie_jar: PrivateCookieJar, - TypedHeader(SecWebsocketProtocol(protocol)): TypedHeader, - websocket: WebSocketUpgrade, -) -> Response { - let (session_info, _cookie_jar) = cookie_jar.session_info(); - websocket - .protocols(ALL_WEBSOCKET_PROTOCOLS) - .on_upgrade(move |ws| async move { - let (mut sink, stream) = ws.split(); - let stream = stream - .take_while(|res| std::future::ready(res.is_ok())) - .map(Result::unwrap) - .filter_map(|msg| { - if let Message::Text(_) | Message::Binary(_) = msg { - std::future::ready(Some(msg.into_data())) - } else { - std::future::ready(None) - } - }); - - let mut data = Data::default(); - data.insert(session_info); - - let mut stream = async_graphql::http::WebSocket::new(schema.clone(), stream, protocol) - .connection_data(data) - .map(|msg| match msg { - WsMessage::Text(text) => Message::Text(text), - WsMessage::Close(code, status) => Message::Close(Some(CloseFrame { - code, - reason: Cow::from(status), - })), - }); - - while let Some(item) = stream.next().await { - let _res = sink.send(item).await; - } - }) -} - pub async fn playground() -> impl IntoResponse { Html(playground_source( - GraphQLPlaygroundConfig::new("/graphql") - .subscription_endpoint("/graphql/ws") - .with_setting("request.credentials", "include"), + GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"), )) } diff --git a/crates/handlers/src/lib.rs b/crates/handlers/src/lib.rs index 927d20dfb..501f3176f 100644 --- a/crates/handlers/src/lib.rs +++ b/crates/handlers/src/lib.rs @@ -94,14 +94,13 @@ where ::Error: std::error::Error + Send + Sync, S: Clone + Send + Sync + 'static, mas_graphql::Schema: FromRef, + PgPool: FromRef, Encrypter: FromRef, { - let mut router = Router::new() - .route( - "/graphql", - get(self::graphql::get).post(self::graphql::post), - ) - .route("/graphql/ws", get(self::graphql::ws)); + let mut router = Router::new().route( + "/graphql", + get(self::graphql::get).post(self::graphql::post), + ); if playground { router = router.route("/graphql/playground", get(self::graphql::playground)); diff --git a/frontend/schema.graphql b/frontend/schema.graphql index d1efed688..8f98b93df 100644 --- a/frontend/schema.graphql +++ b/frontend/schema.graphql @@ -2,7 +2,7 @@ An authentication records when a user enter their credential in a browser session. """ -type Authentication implements Node { +type Authentication implements Node & CreationEvent { """ ID of the object. """ @@ -16,7 +16,7 @@ type Authentication implements Node { """ A browser session represents a logged in user in a browser. """ -type BrowserSession implements Node { +type BrowserSession implements Node & CreationEvent { """ ID of the object. """ @@ -68,7 +68,7 @@ type BrowserSessionEdge { A compat session represents a client session which used the legacy Matrix login API. """ -type CompatSession implements Node { +type CompatSession implements Node & CreationEvent { """ ID of the object. """ @@ -152,6 +152,16 @@ type CompatSsoLoginEdge { node: CompatSsoLogin! } +""" +An object with a creation date. +""" +interface CreationEvent { + """ + When the object was created. + """ + createdAt: DateTime! +} + """ Implement the DateTime scalar @@ -332,7 +342,7 @@ type RootQuery { node(id: ID!): Node } -type UpstreamOAuth2Link implements Node { +type UpstreamOAuth2Link implements Node & CreationEvent { """ ID of the object. """ @@ -384,7 +394,7 @@ type UpstreamOAuth2LinkEdge { node: UpstreamOAuth2Link! } -type UpstreamOAuth2Provider implements Node { +type UpstreamOAuth2Provider implements Node & CreationEvent { """ ID of the object. """ @@ -503,7 +513,7 @@ type User implements Node { """ A user email address """ -type UserEmail implements Node { +type UserEmail implements Node & CreationEvent { """ ID of the object. """ diff --git a/misc/update.sh b/misc/update.sh index b9ec5a483..52d7ac36b 100644 --- a/misc/update.sh +++ b/misc/update.sh @@ -8,9 +8,8 @@ CONFIG_SCHEMA="${BASE_DIR}/docs/config.schema.json" GRAPHQL_SCHEMA="${BASE_DIR}/frontend/schema.graphql" set -x -# XXX: we shouldn't have to specify this feature cargo run -p mas-config > "${CONFIG_SCHEMA}" -cargo run -p mas-graphql --features webpki-roots > "${GRAPHQL_SCHEMA}" +cargo run -p mas-graphql > "${GRAPHQL_SCHEMA}" cd "${BASE_DIR}/frontend" npm run generate