Remove the dependency mas-axum-utils <- mas-graphql

This is done by loading the browser session earlier
Also removes the GraphQL subscription logic
This commit is contained in:
Quentin Gliech
2022-12-15 16:33:15 +01:00
parent 212269d6c8
commit 26e1c34539
9 changed files with 60 additions and 189 deletions

53
Cargo.lock generated
View File

@@ -612,7 +612,6 @@ checksum = "08b108ad2665fa3f6e6a517c3d80ec3e77d224c47d605167aefaa5d7ef97fa48"
dependencies = [
"async-trait",
"axum-core",
"base64",
"bitflags",
"bytes 1.3.0",
"futures-util",
@@ -631,10 +630,8 @@ dependencies = [
"serde_json",
"serde_path_to_error",
"serde_urlencoded",
"sha-1",
"sync_wrapper",
"tokio",
"tokio-tungstenite",
"tower",
"tower-http",
"tower-layer",
@@ -2811,14 +2808,12 @@ dependencies = [
"anyhow",
"async-graphql",
"chrono",
"mas-axum-utils",
"mas-data-model",
"mas-storage",
"oauth2-types",
"serde",
"sqlx",
"thiserror",
"tokio",
"tracing",
"ulid",
"url",
@@ -4753,17 +4748,6 @@ dependencies = [
"unsafe-libyaml",
]
[[package]]
name = "sha-1"
version = "0.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f5058ada175748e33390e40e872bd0fe59a19f265d0158daa551c5a88a76009c"
dependencies = [
"cfg-if",
"cpufeatures",
"digest",
]
[[package]]
name = "sha1"
version = "0.10.5"
@@ -5327,18 +5311,6 @@ dependencies = [
"tokio-stream",
]
[[package]]
name = "tokio-tungstenite"
version = "0.17.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f714dd15bead90401d77e04243611caec13726c2408afd5b31901dfcdcb3b181"
dependencies = [
"futures-util",
"log",
"tokio",
"tungstenite",
]
[[package]]
name = "tokio-util"
version = "0.6.10"
@@ -5598,25 +5570,6 @@ version = "0.2.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59547bce71d9c38b83d9c0e92b6066c4253371f15005def0c30d9657f50c7642"
[[package]]
name = "tungstenite"
version = "0.17.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e27992fd6a8c29ee7eef28fc78349aa244134e10ad447ce3b9f0ac0ed0fa4ce0"
dependencies = [
"base64",
"byteorder",
"bytes 1.3.0",
"http",
"httparse",
"log",
"rand 0.8.5",
"sha-1",
"thiserror",
"url",
"utf-8",
]
[[package]]
name = "typed-builder"
version = "0.9.1"
@@ -5798,12 +5751,6 @@ version = "2.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8db7427f936968176eaa7cdf81b7f98b980b18495ec28f1b5791ac3bfe3eea9"
[[package]]
name = "utf-8"
version = "0.7.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9"
[[package]]
name = "uuid"
version = "1.2.2"

View File

@@ -12,19 +12,13 @@ chrono = "0.4.23"
serde = { version = "1.0.150", features = ["derive"] }
sqlx = { version = "0.6.2", features = ["runtime-tokio-rustls", "postgres"] }
thiserror = "1.0.37"
tokio = { version = "1.23.0", features = ["time"] }
tracing = "0.1.37"
ulid = "1.0.0"
url = "2.3.1"
oauth2-types = { path = "../oauth2-types" }
mas-axum-utils = { path = "../axum-utils" }
mas-data-model = { path = "../data-model" }
mas-storage = { path = "../storage" }
[features]
native-roots = ["mas-axum-utils/native-roots"]
webpki-roots = ["mas-axum-utils/webpki-roots"]
[[bin]]
name = "schema"

View File

@@ -20,13 +20,17 @@
clippy::future_not_send
)]
#![warn(clippy::pedantic)]
#![allow(clippy::module_name_repetitions, clippy::missing_errors_doc)]
#![allow(
clippy::module_name_repetitions,
clippy::missing_errors_doc,
clippy::unused_async
)]
use async_graphql::{
connection::{query, Connection, Edge, OpaqueCursor},
Context, Description, EmptyMutation, EmptySubscription, ID,
};
use mas_axum_utils::SessionInfo;
use model::CreationEvent;
use sqlx::PgPool;
use self::model::{
@@ -43,8 +47,7 @@ pub type SchemaBuilder = async_graphql::SchemaBuilder<RootQuery, EmptyMutation,
pub fn schema_builder() -> SchemaBuilder {
async_graphql::Schema::build(RootQuery::new(), EmptyMutation, EmptySubscription)
.register_output_type::<Node>()
// TODO: ordering of interface implementations is not stable
//.register_output_type::<CreationEvent>()
.register_output_type::<CreationEvent>()
}
/// The query root of the GraphQL interface.
@@ -67,21 +70,13 @@ impl RootQuery {
&self,
ctx: &Context<'_>,
) -> Result<Option<BrowserSession>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?;
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
Ok(session.map(BrowserSession::from))
}
/// Get the current logged in user
async fn current_user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
let database = ctx.data::<PgPool>()?;
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
Ok(session.map(User::from))
}
@@ -103,10 +98,7 @@ impl RootQuery {
/// Fetch a user by its ID.
async fn user(&self, ctx: &Context<'_>, id: ID) -> Result<Option<User>, async_graphql::Error> {
let id = NodeType::User.extract_ulid(&id)?;
let database = ctx.data::<PgPool>()?;
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
@@ -125,10 +117,9 @@ impl RootQuery {
id: ID,
) -> Result<Option<BrowserSession>, async_graphql::Error> {
let id = NodeType::BrowserSession.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?;
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
@@ -153,10 +144,9 @@ impl RootQuery {
id: ID,
) -> Result<Option<UserEmail>, async_graphql::Error> {
let id = NodeType::UserEmail.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?;
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;
@@ -174,10 +164,9 @@ impl RootQuery {
id: ID,
) -> Result<Option<UpstreamOAuth2Link>, async_graphql::Error> {
let id = NodeType::UpstreamOAuth2Link.extract_ulid(&id)?;
let session = ctx.data_opt::<mas_data_model::BrowserSession>().cloned();
let database = ctx.data::<PgPool>()?;
let session_info = ctx.data::<SessionInfo>()?;
let mut conn = database.acquire().await?;
let session = session_info.load_session(&mut conn).await?;
let Some(session) = session else { return Ok(None) };
let current_user = session.user;

View File

@@ -33,6 +33,7 @@ pub use self::{
users::{User, UserEmail},
};
/// An object with a creation date.
#[derive(Interface)]
#[graphql(field(
name = "created_at",

View File

@@ -21,7 +21,7 @@ anyhow = "1.0.66"
hyper = { version = "0.14.23", features = ["full"] }
tower = "0.4.13"
tower-http = { version = "0.3.5", features = ["cors"] }
axum = { version = "0.6.1", features = ["ws"] }
axum = "0.6.1"
axum-macros = "0.3.0"
axum-extra = { version = "0.4.2", features = ["cookie-private"] }

View File

@@ -12,28 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::{borrow::Cow, str::FromStr};
use async_graphql::{
extensions::{ApolloTracing, Tracing},
http::{
playground_source, GraphQLPlaygroundConfig, MultipartOptions, WebSocketProtocols,
WsMessage, ALL_WEBSOCKET_PROTOCOLS,
},
Data,
http::{playground_source, GraphQLPlaygroundConfig, MultipartOptions},
};
use axum::{
extract::{
ws::{CloseFrame, Message},
BodyStream, RawQuery, State, WebSocketUpgrade,
},
response::{Html, IntoResponse, Response},
extract::{BodyStream, RawQuery, State},
response::{Html, IntoResponse},
Json, TypedHeader,
};
use axum_extra::extract::PrivateCookieJar;
use futures_util::{SinkExt, StreamExt, TryStreamExt};
use headers::{ContentType, Header, HeaderValue};
use hyper::header::{CACHE_CONTROL, SEC_WEBSOCKET_PROTOCOL};
use futures_util::{StreamExt, TryStreamExt};
use headers::{ContentType, HeaderValue};
use hyper::header::CACHE_CONTROL;
use mas_axum_utils::{FancyError, SessionInfoExt};
use mas_graphql::Schema;
use mas_keystore::Encrypter;
@@ -67,6 +58,7 @@ fn span_for_graphql_request(request: &async_graphql::Request) -> tracing::Span {
}
pub async fn post(
State(pool): State<PgPool>,
State(schema): State<Schema>,
cookie_jar: PrivateCookieJar<Encrypter>,
content_type: Option<TypedHeader<ContentType>>,
@@ -75,15 +67,19 @@ pub async fn post(
let content_type = content_type.map(|TypedHeader(h)| h.to_string());
let (session_info, _cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&pool).await?;
let request = async_graphql::http::receive_batch_body(
let mut request = async_graphql::http::receive_batch_body(
content_type,
body.map_err(|e| std::io::Error::new(std::io::ErrorKind::Other, e))
.into_async_read(),
MultipartOptions::default(),
)
.await? // XXX: this should probably return another error response?
.data(session_info);
.await?; // XXX: this should probably return another error response?
if let Some(session) = maybe_session {
request = request.data(session);
}
let response = match request {
async_graphql::BatchRequest::Single(request) => {
@@ -114,13 +110,19 @@ pub async fn post(
}
pub async fn get(
State(pool): State<PgPool>,
State(schema): State<Schema>,
cookie_jar: PrivateCookieJar<Encrypter>,
RawQuery(query): RawQuery,
) -> Result<impl IntoResponse, FancyError> {
let (session_info, _cookie_jar) = cookie_jar.session_info();
let request =
async_graphql::http::parse_query_string(&query.unwrap_or_default())?.data(session_info);
let maybe_session = session_info.load_session(&pool).await?;
let mut request = async_graphql::http::parse_query_string(&query.unwrap_or_default())?;
if let Some(session) = maybe_session {
request = request.data(session);
}
let span = span_for_graphql_request(&request);
let response = schema.execute(request).instrument(span).await;
@@ -136,78 +138,8 @@ pub async fn get(
Ok((headers, cache_control, Json(response)))
}
pub struct SecWebsocketProtocol(WebSocketProtocols);
impl Header for SecWebsocketProtocol {
fn name() -> &'static headers::HeaderName {
&SEC_WEBSOCKET_PROTOCOL
}
fn decode<'i, I>(values: &mut I) -> Result<Self, headers::Error>
where
Self: Sized,
I: Iterator<Item = &'i HeaderValue>,
{
values
.filter_map(|value| value.to_str().ok())
.flat_map(|value| value.split(','))
.find_map(|p| WebSocketProtocols::from_str(p.trim()).ok())
.map(Self)
.ok_or_else(headers::Error::invalid)
}
fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
if let Ok(v) = HeaderValue::from_str(self.0.sec_websocket_protocol()) {
values.extend(std::iter::once(v));
}
}
}
pub async fn ws(
State(schema): State<Schema>,
cookie_jar: PrivateCookieJar<Encrypter>,
TypedHeader(SecWebsocketProtocol(protocol)): TypedHeader<SecWebsocketProtocol>,
websocket: WebSocketUpgrade,
) -> Response {
let (session_info, _cookie_jar) = cookie_jar.session_info();
websocket
.protocols(ALL_WEBSOCKET_PROTOCOLS)
.on_upgrade(move |ws| async move {
let (mut sink, stream) = ws.split();
let stream = stream
.take_while(|res| std::future::ready(res.is_ok()))
.map(Result::unwrap)
.filter_map(|msg| {
if let Message::Text(_) | Message::Binary(_) = msg {
std::future::ready(Some(msg.into_data()))
} else {
std::future::ready(None)
}
});
let mut data = Data::default();
data.insert(session_info);
let mut stream = async_graphql::http::WebSocket::new(schema.clone(), stream, protocol)
.connection_data(data)
.map(|msg| match msg {
WsMessage::Text(text) => Message::Text(text),
WsMessage::Close(code, status) => Message::Close(Some(CloseFrame {
code,
reason: Cow::from(status),
})),
});
while let Some(item) = stream.next().await {
let _res = sink.send(item).await;
}
})
}
pub async fn playground() -> impl IntoResponse {
Html(playground_source(
GraphQLPlaygroundConfig::new("/graphql")
.subscription_endpoint("/graphql/ws")
.with_setting("request.credentials", "include"),
GraphQLPlaygroundConfig::new("/graphql").with_setting("request.credentials", "include"),
))
}

View File

@@ -94,14 +94,13 @@ where
<B as HttpBody>::Error: std::error::Error + Send + Sync,
S: Clone + Send + Sync + 'static,
mas_graphql::Schema: FromRef<S>,
PgPool: FromRef<S>,
Encrypter: FromRef<S>,
{
let mut router = Router::new()
.route(
let mut router = Router::new().route(
"/graphql",
get(self::graphql::get).post(self::graphql::post),
)
.route("/graphql/ws", get(self::graphql::ws));
);
if playground {
router = router.route("/graphql/playground", get(self::graphql::playground));

View File

@@ -2,7 +2,7 @@
An authentication records when a user enter their credential in a browser
session.
"""
type Authentication implements Node {
type Authentication implements Node & CreationEvent {
"""
ID of the object.
"""
@@ -16,7 +16,7 @@ type Authentication implements Node {
"""
A browser session represents a logged in user in a browser.
"""
type BrowserSession implements Node {
type BrowserSession implements Node & CreationEvent {
"""
ID of the object.
"""
@@ -68,7 +68,7 @@ type BrowserSessionEdge {
A compat session represents a client session which used the legacy Matrix
login API.
"""
type CompatSession implements Node {
type CompatSession implements Node & CreationEvent {
"""
ID of the object.
"""
@@ -152,6 +152,16 @@ type CompatSsoLoginEdge {
node: CompatSsoLogin!
}
"""
An object with a creation date.
"""
interface CreationEvent {
"""
When the object was created.
"""
createdAt: DateTime!
}
"""
Implement the DateTime<Utc> scalar
@@ -332,7 +342,7 @@ type RootQuery {
node(id: ID!): Node
}
type UpstreamOAuth2Link implements Node {
type UpstreamOAuth2Link implements Node & CreationEvent {
"""
ID of the object.
"""
@@ -384,7 +394,7 @@ type UpstreamOAuth2LinkEdge {
node: UpstreamOAuth2Link!
}
type UpstreamOAuth2Provider implements Node {
type UpstreamOAuth2Provider implements Node & CreationEvent {
"""
ID of the object.
"""
@@ -503,7 +513,7 @@ type User implements Node {
"""
A user email address
"""
type UserEmail implements Node {
type UserEmail implements Node & CreationEvent {
"""
ID of the object.
"""

View File

@@ -8,9 +8,8 @@ CONFIG_SCHEMA="${BASE_DIR}/docs/config.schema.json"
GRAPHQL_SCHEMA="${BASE_DIR}/frontend/schema.graphql"
set -x
# XXX: we shouldn't have to specify this feature
cargo run -p mas-config > "${CONFIG_SCHEMA}"
cargo run -p mas-graphql --features webpki-roots > "${GRAPHQL_SCHEMA}"
cargo run -p mas-graphql > "${GRAPHQL_SCHEMA}"
cd "${BASE_DIR}/frontend"
npm run generate