Remove the dependency mas-axum-utils <- mas-graphql

This is done by loading the browser session earlier
Also removes the GraphQL subscription logic
This commit is contained in:
Quentin Gliech
2022-12-15 16:33:15 +01:00
parent 212269d6c8
commit 26e1c34539
9 changed files with 60 additions and 189 deletions

53
Cargo.lock generated
View File

@@ -612,7 +612,6 @@ checksum = "08b108ad2665fa3f6e6a517c3d80ec3e77d224c47d605167aefaa5d7ef97fa48"
dependencies = [ dependencies = [
"async-trait", "async-trait",
"axum-core", "axum-core",
"base64",
"bitflags", "bitflags",
"bytes 1.3.0", "bytes 1.3.0",
"futures-util", "futures-util",
@@ -631,10 +630,8 @@ dependencies = [
"serde_json", "serde_json",
"serde_path_to_error", "serde_path_to_error",
"serde_urlencoded", "serde_urlencoded",
"sha-1",
"sync_wrapper", "sync_wrapper",
"tokio", "tokio",
"tokio-tungstenite",
"tower", "tower",
"tower-http", "tower-http",
"tower-layer", "tower-layer",
@@ -2811,14 +2808,12 @@ dependencies = [
"anyhow", "anyhow",
"async-graphql", "async-graphql",
"chrono", "chrono",
"mas-axum-utils",
"mas-data-model", "mas-data-model",
"mas-storage", "mas-storage",
"oauth2-types", "oauth2-types",
"serde", "serde",
"sqlx", "sqlx",
"thiserror", "thiserror",
"tokio",
"tracing", "tracing",
"ulid", "ulid",
"url", "url",
@@ -4753,17 +4748,6 @@ dependencies = [
"unsafe-libyaml", "unsafe-libyaml",
] ]
[[package]]
name = "sha-1"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]] [[package]]
name = "sha1" name = "sha1"
version = "0.10.5" version = "0.10.5"
@@ -5327,18 +5311,6 @@ dependencies = [
"tokio-stream", "tokio-stream",
] ]
[[package]]
name = "tokio-tungstenite"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]] [[package]]
name = "tokio-util" name = "tokio-util"
version = "0.6.10" version = "0.6.10"
@@ -5598,25 +5570,6 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642" checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "tungstenite"
version = "0.17.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0"
dependencies = [
"base64",
"byteorder",
"bytes 1.3.0",
"http",
"httparse",
"log",
"rand 0.8.5",
"sha-1",
"thiserror",
"url",
"utf-8",
]
[[package]] [[package]]
name = "typed-builder" name = "typed-builder"
version = "0.9.1" version = "0.9.1"
@@ -5798,12 +5751,6 @@ version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9" checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9"
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]] [[package]]
name = "uuid" name = "uuid"
version = "1.2.2" version = "1.2.2"

View File

@@ -12,19 +12,13 @@ chrono = "0.4.23"
serde = { version = "1.0.150", features = ["derive"] } serde = { version = "1.0.150", features = ["derive"] }
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] } sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] }
thiserror = "1.0.37" thiserror = "1.0.37"
tokio = { version = "1.23.0", features = ["time"] }
tracing = "0.1.37" tracing = "0.1.37"
ulid = "1.0.0" ulid = "1.0.0"
url = "2.3.1" url = "2.3.1"
oauth2-types = { path = "../oauth2-types" } oauth2-types = { path = "../oauth2-types" }
mas-axum-utils = { path = "../axum-utils" }
mas-data-model = { path = "../data-model" } mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" } mas-storage = { path = "../storage" }
[features]
native-roots = ["mas-axum-utils/native-roots"]
webpki-roots = ["mas-axum-utils/webpki-roots"]
[[bin]] [[bin]]
name = "schema" name = "schema"

View File

@@ -20,13 +20,17 @@
clippy::future_not_send clippy::future_not_send
)] )]
#![warn(clippy::pedantic)] #![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions, clippy::missing_errors_doc)] #![allow(
clippy::module_name_repetitions,
clippy::missing_errors_doc,
clippy::unused_async
)]
use async_graphql::{ use async_graphql::{
connection::{query, Connection, Edge, OpaqueCursor}, connection::{query, Connection, Edge, OpaqueCursor},
Context, Description, EmptyMutation, EmptySubscription, ID, Context, Description, EmptyMutation, EmptySubscription, ID,
}; };
use mas_axum_utils::SessionInfo; use model::CreationEvent;
use sqlx::PgPool; use sqlx::PgPool;
use self::model::{ use self::model::{
@@ -43,8 +47,7 @@ pub type SchemaBuilder = async_graphql::SchemaBuilder<RootQuery, EmptyMutation,
pub fn schema_builder() -> SchemaBuilder { pub fn schema_builder() -> SchemaBuilder {
async_graphql::Schema::build(RootQuery::new(), EmptyMutation, EmptySubscription) async_graphql::Schema::build(RootQuery::new(), EmptyMutation, EmptySubscription)
.register_output_type::<Node>() .register_output_type::<Node>()
// TODO: ordering of interface implementations is not stable .register_output_type::<CreationEvent>()
//.register_output_type::<CreationEvent>()
} }
/// The query root of the GraphQL interface. /// The query root of the GraphQL interface.
@@ -67,21 +70,13 @@ impl RootQuery {
&self, &self,
ctx: &Context<'_>, ctx: &Context<'_>,
) -> Result<Option<BrowserSession>, async_graphql::Error> { ) -> Result<Option<BrowserSession>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?; let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
Ok(session.map(BrowserSession::from)) Ok(session.map(BrowserSession::from))
} }
/// Get the current logged in user /// Get the current logged in user
async fn current_user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> { async fn current_user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?; let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
Ok(session.map(User::from)) Ok(session.map(User::from))
} }
@@ -103,10 +98,7 @@ impl RootQuery {
/// Fetch a user by its ID. /// Fetch a user by its ID.
async fn user(&self, ctx: &Context<'_>, id: ID) -> Result<Option<User>, async_graphql::Error> { async fn user(&self, ctx: &Context<'_>, id: ID) -> Result<Option<User>, async_graphql::Error> {
let id = NodeType::User.extract_ulid(&id)?; let id = NodeType::User.extract_ulid(&id)?;
let database = ctx.data::<PgPool>()?; let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
@@ -125,10 +117,9 @@ impl RootQuery {
id: ID, id: ID,
) -> Result<Option<BrowserSession>, async_graphql::Error> { ) -> Result<Option<BrowserSession>, async_graphql::Error> {
let id = NodeType::BrowserSession.extract_ulid(&id)?; let id = NodeType::BrowserSession.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?; let database = ctx.data::<PgPool>()?;
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?; let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
@@ -153,10 +144,9 @@ impl RootQuery {
id: ID, id: ID,
) -> Result<Option<UserEmail>, async_graphql::Error> { ) -> Result<Option<UserEmail>, async_graphql::Error> {
let id = NodeType::UserEmail.extract_ulid(&id)?; let id = NodeType::UserEmail.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?; let database = ctx.data::<PgPool>()?;
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?; let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;
@@ -174,10 +164,9 @@ impl RootQuery {
id: ID, id: ID,
) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> { ) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> {
let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?; let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?; let database = ctx.data::<PgPool>()?;
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?; let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
let Some(session) = session else { return Ok(None) }; let Some(session) = session else { return Ok(None) };
let current_user = session.user; let current_user = session.user;

View File

@@ -33,6 +33,7 @@ pub use self::{
users::{User, UserEmail}, users::{User, UserEmail},
}; };
/// An object with a creation date.
#[derive(Interface)] #[derive(Interface)]
#[graphql(field( #[graphql(field(
name = "created_at", name = "created_at",

View File

@@ -21,7 +21,7 @@ anyhow = "1.0.66"
hyper = { version = "0.14.23", features = ["full"] } hyper = { version = "0.14.23", features = ["full"] }
tower = "0.4.13" tower = "0.4.13"
tower-http = { version = "0.3.5", features = ["cors"] } tower-http = { version = "0.3.5", features = ["cors"] }
axum = { version = "0.6.1", features = ["ws"] } axum = "0.6.1"
axum-macros = "0.3.0" axum-macros = "0.3.0"
axum-extra = { version = "0.4.2", features = ["cookie-private"] } axum-extra = { version = "0.4.2", features = ["cookie-private"] }

View File

@@ -12,28 +12,19 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
use std::{borrow::Cow, str::FromStr};
use async_graphql::{ use async_graphql::{
extensions::{ApolloTracing, Tracing}, extensions::{ApolloTracing, Tracing},
http::{ http::{playground_source, GraphQLPlaygroundConfig, MultipartOptions},
playground_source, GraphQLPlaygroundConfig, MultipartOptions, WebSocketProtocols,
WsMessage, ALL_WEBSOCKET_PROTOCOLS,
},
Data,
}; };
use axum::{ use axum::{
extract::{ extract::{BodyStream, RawQuery, State},
ws::{CloseFrame, Message}, response::{Html, IntoResponse},
BodyStream, RawQuery, State, WebSocketUpgrade,
},
response::{Html, IntoResponse, Response},
Json, TypedHeader, Json, TypedHeader,
}; };
use axum_extra::extract::PrivateCookieJar; use axum_extra::extract::PrivateCookieJar;
use futures_util::{SinkExt, StreamExt, TryStreamExt}; use futures_util::{StreamExt, TryStreamExt};
use headers::{ContentType, Header, HeaderValue}; use headers::{ContentType, HeaderValue};
use hyper::header::{CACHE_CONTROL, SEC_WEBSOCKET_PROTOCOL}; use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{FancyError, SessionInfoExt}; use mas_axum_utils::{FancyError, SessionInfoExt};
use mas_graphql::Schema; use mas_graphql::Schema;
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
@@ -67,6 +58,7 @@ fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
} }
pub async fn post( pub async fn post(
State(pool): State<PgPool>,
State(schema): State<Schema>, State(schema): State<Schema>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
content_type: Option<TypedHeader<ContentType>>, content_type: Option<TypedHeader<ContentType>>,
@@ -75,15 +67,19 @@ pub async fn post(
let content_type = content_type.map(|TypedHeader(h)| h.to_string()); let content_type = content_type.map(|TypedHeader(h)| h.to_string());
let (session_info, _cookie_jar) = cookie_jar.session_info(); let (session_info, _cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&pool).await?;
let request = async_graphql::http::receive_batch_body( let mut request = async_graphql::http::receive_batch_body(
content_type, content_type,
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e)) body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_async_read(), .into_async_read(),
MultipartOptions::default(), MultipartOptions::default(),
) )
.await? // XXX: this should probably return another error response? .await?; // XXX: this should probably return another error response?
.data(session_info);
if let Some(session) = maybe_session {
request = request.data(session);
}
let response = match request { let response = match request {
async_graphql::BatchRequest::Single(request) => { async_graphql::BatchRequest::Single(request) => {
@@ -114,13 +110,19 @@ pub async fn post(
} }
pub async fn get( pub async fn get(
State(pool): State<PgPool>,
State(schema): State<Schema>, State(schema): State<Schema>,
cookie_jar: PrivateCookieJar<Encrypter>, cookie_jar: PrivateCookieJar<Encrypter>,
RawQuery(query): RawQuery, RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> { ) -> Result<impl IntoResponse, FancyError> {
let (session_info, _cookie_jar) = cookie_jar.session_info(); let (session_info, _cookie_jar) = cookie_jar.session_info();
let request = let maybe_session = session_info.load_session(&pool).await?;
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(session_info);
let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?;
if let Some(session) = maybe_session {
request = request.data(session);
}
let span = span_for_graphql_request(&request); let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await; let response = schema.execute(request).instrument(span).await;
@@ -136,78 +138,8 @@ pub async fn get(
Ok((headers, cache_control, Json(response))) Ok((headers, cache_control, Json(response)))
} }
pub struct SecWebsocketProtocol(WebSocketProtocols);
impl Header for SecWebsocketProtocol {
fn name() -> &'static headers::HeaderName {
&SEC_WEBSOCKET_PROTOCOL
}
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i HeaderValue>,
{
values
.filter_map(|value| value.to_str().ok())
.flat_map(|value| value.split(','))
.find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
.map(Self)
.ok_or_else(headers::Error::invalid)
}
fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
if let Ok(v) = HeaderValue::from_str(self.0.sec_websocket_protocol()) {
values.extend(std::iter::once(v));
}
}
}
pub async fn ws(
State(schema): State<Schema>,
cookie_jar: PrivateCookieJar<Encrypter>,
TypedHeader(SecWebsocketProtocol(protocol)): TypedHeader<SecWebsocketProtocol>,
websocket: WebSocketUpgrade,
) -> Response {
let (session_info, _cookie_jar) = cookie_jar.session_info();
websocket
.protocols(ALL_WEBSOCKET_PROTOCOLS)
.on_upgrade(move |ws| async move {
let (mut sink, stream) = ws.split();
let stream = stream
.take_while(|res| std::future::ready(res.is_ok()))
.map(Result::unwrap)
.filter_map(|msg| {
if let Message::Text(_) | Message::Binary(_) = msg {
std::future::ready(Some(msg.into_data()))
} else {
std::future::ready(None)
}
});
let mut data = Data::default();
data.insert(session_info);
let mut stream = async_graphql::http::WebSocket::new(schema.clone(), stream, protocol)
.connection_data(data)
.map(|msg| match msg {
WsMessage::Text(text) => Message::Text(text),
WsMessage::Close(code, status) => Message::Close(Some(CloseFrame {
code,
reason: Cow::from(status),
})),
});
while let Some(item) = stream.next().await {
let _res = sink.send(item).await;
}
})
}
pub async fn playground() -> impl IntoResponse { pub async fn playground() -> impl IntoResponse {
Html(playground_source( Html(playground_source(
GraphQLPlaygroundConfig::new("/graphql") GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
.subscription_endpoint("/graphql/ws")
.with_setting("request.credentials", "include"),
)) ))
} }

View File

@@ -94,14 +94,13 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync, <B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
mas_graphql::Schema: FromRef<S>, mas_graphql::Schema: FromRef<S>,
PgPool: FromRef<S>,
Encrypter: FromRef<S>, Encrypter: FromRef<S>,
{ {
let mut router = Router::new() let mut router = Router::new().route(
.route( "/graphql",
"/graphql", get(self::graphql::get).post(self::graphql::post),
get(self::graphql::get).post(self::graphql::post), );
)
.route("/graphql/ws", get(self::graphql::ws));
if playground { if playground {
router = router.route("/graphql/playground", get(self::graphql::playground)); router = router.route("/graphql/playground", get(self::graphql::playground));

View File

@@ -2,7 +2,7 @@
An authentication records when a user enter their credential in a browser An authentication records when a user enter their credential in a browser
session. session.
""" """
type Authentication implements Node { type Authentication implements Node & CreationEvent {
""" """
ID of the object. ID of the object.
""" """
@@ -16,7 +16,7 @@ type Authentication implements Node {
""" """
A browser session represents a logged in user in a browser. A browser session represents a logged in user in a browser.
""" """
type BrowserSession implements Node { type BrowserSession implements Node & CreationEvent {
""" """
ID of the object. ID of the object.
""" """
@@ -68,7 +68,7 @@ type BrowserSessionEdge {
A compat session represents a client session which used the legacy Matrix A compat session represents a client session which used the legacy Matrix
login API. login API.
""" """
type CompatSession implements Node { type CompatSession implements Node & CreationEvent {
""" """
ID of the object. ID of the object.
""" """
@@ -152,6 +152,16 @@ type CompatSsoLoginEdge {
node: CompatSsoLogin! node: CompatSsoLogin!
} }
"""
An object with a creation date.
"""
interface CreationEvent {
"""
When the object was created.
"""
createdAt: DateTime!
}
""" """
Implement the DateTime<Utc> scalar Implement the DateTime<Utc> scalar
@@ -332,7 +342,7 @@ type RootQuery {
node(id: ID!): Node node(id: ID!): Node
} }
type UpstreamOAuth2Link implements Node { type UpstreamOAuth2Link implements Node & CreationEvent {
""" """
ID of the object. ID of the object.
""" """
@@ -384,7 +394,7 @@ type UpstreamOAuth2LinkEdge {
node: UpstreamOAuth2Link! node: UpstreamOAuth2Link!
} }
type UpstreamOAuth2Provider implements Node { type UpstreamOAuth2Provider implements Node & CreationEvent {
""" """
ID of the object. ID of the object.
""" """
@@ -503,7 +513,7 @@ type User implements Node {
""" """
A user email address A user email address
""" """
type UserEmail implements Node { type UserEmail implements Node & CreationEvent {
""" """
ID of the object. ID of the object.
""" """

View File

@@ -8,9 +8,8 @@ CONFIG_SCHEMA="${BASE_DIR}/docs/config.schema.json"
GRAPHQL_SCHEMA="${BASE_DIR}/frontend/schema.graphql" GRAPHQL_SCHEMA="${BASE_DIR}/frontend/schema.graphql"
set -x set -x
# XXX: we shouldn't have to specify this feature
cargo run -p mas-config > "${CONFIG_SCHEMA}" cargo run -p mas-config > "${CONFIG_SCHEMA}"
cargo run -p mas-graphql --features webpki-roots > "${GRAPHQL_SCHEMA}" cargo run -p mas-graphql > "${GRAPHQL_SCHEMA}"
cd "${BASE_DIR}/frontend" cd "${BASE_DIR}/frontend"
npm run generate npm run generate