diff --git a/Cargo.lock b/Cargo.lock index e29c28f3d..154c833db 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3311,6 +3311,7 @@ dependencies = [ "regex", "ruma-common", "serde", + "serde_json", "thiserror", "ulid", "url", diff --git a/crates/data-model/Cargo.toml b/crates/data-model/Cargo.toml index 16cab034f..2c648b1ff 100644 --- a/crates/data-model/Cargo.toml +++ b/crates/data-model/Cargo.toml @@ -15,6 +15,7 @@ workspace = true chrono.workspace = true thiserror.workspace = true serde.workspace = true +serde_json.workspace = true url.workspace = true crc = "3.2.1" ulid.workspace = true diff --git a/crates/data-model/src/upstream_oauth2/session.rs b/crates/data-model/src/upstream_oauth2/session.rs index eb59000a0..38b622987 100644 --- a/crates/data-model/src/upstream_oauth2/session.rs +++ b/crates/data-model/src/upstream_oauth2/session.rs @@ -19,12 +19,14 @@ pub enum UpstreamOAuthAuthorizationSessionState { completed_at: DateTime, link_id: Ulid, id_token: Option, + extra_callback_parameters: Option, }, Consumed { completed_at: DateTime, consumed_at: DateTime, link_id: Ulid, id_token: Option, + extra_callback_parameters: Option, }, } @@ -42,12 +44,14 @@ impl UpstreamOAuthAuthorizationSessionState { completed_at: DateTime, link: &UpstreamOAuthLink, id_token: Option, + extra_callback_parameters: Option, ) -> Result { match self { Self::Pending => Ok(Self::Completed { completed_at, link_id: link.id, id_token, + extra_callback_parameters, }), Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError), } @@ -67,11 +71,13 @@ impl UpstreamOAuthAuthorizationSessionState { completed_at, link_id, id_token, + extra_callback_parameters, } => Ok(Self::Consumed { completed_at, link_id, consumed_at, id_token, + extra_callback_parameters, }), Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError), } @@ -124,6 +130,27 @@ impl UpstreamOAuthAuthorizationSessionState { } } + /// Get the extra query parameters that were sent to the upstream provider. + /// + /// Returns `None` if the upstream OAuth 2.0 authorization session state is + /// not [`Pending`]. + /// + /// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending + #[must_use] + pub fn extra_callback_parameters(&self) -> Option<&serde_json::Value> { + match self { + Self::Pending => None, + Self::Completed { + extra_callback_parameters, + .. + } + | Self::Consumed { + extra_callback_parameters, + .. + } => extra_callback_parameters.as_ref(), + } + } + /// Get the time at which the upstream OAuth 2.0 authorization session was /// consumed. /// @@ -201,8 +228,11 @@ impl UpstreamOAuthAuthorizationSession { completed_at: DateTime, link: &UpstreamOAuthLink, id_token: Option, + extra_callback_parameters: Option, ) -> Result { - self.state = self.state.complete(completed_at, link, id_token)?; + self.state = + self.state + .complete(completed_at, link, id_token, extra_callback_parameters)?; Ok(self) } diff --git a/crates/handlers/src/upstream_oauth2/callback.rs b/crates/handlers/src/upstream_oauth2/callback.rs index d0cc7b848..734194142 100644 --- a/crates/handlers/src/upstream_oauth2/callback.rs +++ b/crates/handlers/src/upstream_oauth2/callback.rs @@ -48,6 +48,9 @@ pub struct Params { enum CodeOrError { Code { code: String, + + #[serde(flatten)] + extra_callback_parameters: Option, }, Error { error: ClientErrorCode, @@ -201,7 +204,7 @@ pub(crate) async fn handler( } // Let's extract the code from the params, and return if there was an error - let code = match params.code_or_error { + let (code, extra_callback_parameters) = match params.code_or_error { CodeOrError::Error { error, error_description, @@ -212,7 +215,10 @@ pub(crate) async fn handler( error_description, }) } - CodeOrError::Code { code } => code, + CodeOrError::Code { + code, + extra_callback_parameters, + } => (code, extra_callback_parameters), }; let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client); @@ -266,6 +272,10 @@ pub(crate) async fn handler( let env = { let mut env = environment(); env.add_global("user", minijinja::Value::from_serialize(&id_token)); + env.add_global( + "extra_callback_parameters", + minijinja::Value::from_serialize(&extra_callback_parameters), + ); env }; @@ -299,7 +309,13 @@ pub(crate) async fn handler( let session = repo .upstream_oauth_session() - .complete_with_link(&clock, session, &link, response.id_token) + .complete_with_link( + &clock, + session, + &link, + response.id_token, + extra_callback_parameters, + ) .await?; let cookie_jar = sessions_cookie diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 77dc43af1..854d65444 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -340,6 +340,10 @@ pub(crate) async fn get( let env = { let mut e = environment(); e.add_global("user", payload); + e.add_global( + "extra_callback_parameters", + minijinja::Value::from_serialize(upstream_session.extra_callback_parameters()), + ); e }; @@ -582,6 +586,10 @@ pub(crate) async fn post( let env = { let mut e = environment(); e.add_global("user", payload); + e.add_global( + "extra_callback_parameters", + minijinja::Value::from_serialize(upstream_session.extra_callback_parameters()), + ); e }; @@ -945,7 +953,13 @@ mod tests { let session = repo .upstream_oauth_session() - .complete_with_link(&state.clock, session, &link, Some(id_token.into_string())) + .complete_with_link( + &state.clock, + session, + &link, + Some(id_token.into_string()), + None, + ) .await .unwrap(); diff --git a/crates/handlers/src/upstream_oauth2/template.rs b/crates/handlers/src/upstream_oauth2/template.rs index 740953f89..c23e175c9 100644 --- a/crates/handlers/src/upstream_oauth2/template.rs +++ b/crates/handlers/src/upstream_oauth2/template.rs @@ -68,6 +68,18 @@ fn string(value: &Value) -> String { value.to_string() } +fn from_json(value: &str) -> Result { + let value: serde_json::Value = serde_json::from_str(value).map_err(|e| { + minijinja::Error::new( + minijinja::ErrorKind::InvalidOperation, + "Failed to decode JSON", + ) + .with_source(e) + })?; + + Ok(Value::from_serialize(value)) +} + pub fn environment() -> Environment<'static> { let mut env = Environment::new(); @@ -77,6 +89,7 @@ pub fn environment() -> Environment<'static> { env.add_filter("b64encode", b64encode); env.add_filter("tlvdecode", tlvdecode); env.add_filter("string", string); + env.add_filter("from_json", from_json); env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback); diff --git a/crates/storage-pg/.sqlx/query-b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64.json b/crates/storage-pg/.sqlx/query-5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d.json similarity index 62% rename from crates/storage-pg/.sqlx/query-b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64.json rename to crates/storage-pg/.sqlx/query-5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d.json index 3a4483604..96ad3513a 100644 --- a/crates/storage-pg/.sqlx/query-b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64.json +++ b/crates/storage-pg/.sqlx/query-5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n ", + "query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3,\n extra_callback_parameters = $4\n WHERE upstream_oauth_authorization_session_id = $5\n ", "describe": { "columns": [], "parameters": { @@ -8,10 +8,11 @@ "Uuid", "Timestamptz", "Text", + "Jsonb", "Uuid" ] }, "nullable": [] }, - "hash": "b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64" + "hash": "5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d" } diff --git a/crates/storage-pg/.sqlx/query-67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c.json b/crates/storage-pg/.sqlx/query-7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30.json similarity index 76% rename from crates/storage-pg/.sqlx/query-67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c.json rename to crates/storage-pg/.sqlx/query-7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30.json index 0b378d382..57fc34e9b 100644 --- a/crates/storage-pg/.sqlx/query-67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c.json +++ b/crates/storage-pg/.sqlx/query-7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30.json @@ -1,6 +1,6 @@ { "db_name": "PostgreSQL", - "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ", + "query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n extra_callback_parameters,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ", "describe": { "columns": [ { @@ -40,16 +40,21 @@ }, { "ordinal": 7, + "name": "extra_callback_parameters", + "type_info": "Jsonb" + }, + { + "ordinal": 8, "name": "created_at", "type_info": "Timestamptz" }, { - "ordinal": 8, + "ordinal": 9, "name": "completed_at", "type_info": "Timestamptz" }, { - "ordinal": 9, + "ordinal": 10, "name": "consumed_at", "type_info": "Timestamptz" } @@ -67,10 +72,11 @@ true, false, true, + true, false, true, true ] }, - "hash": "67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c" + "hash": "7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30" } diff --git a/crates/storage-pg/migrations/20241118115314_upstream_oauth2_extra_query_params.sql b/crates/storage-pg/migrations/20241118115314_upstream_oauth2_extra_query_params.sql new file mode 100644 index 000000000..0e900e0af --- /dev/null +++ b/crates/storage-pg/migrations/20241118115314_upstream_oauth2_extra_query_params.sql @@ -0,0 +1,9 @@ +-- Copyright 2024 New Vector Ltd. +-- +-- SPDX-License-Identifier: AGPL-3.0-only +-- Please see LICENSE in the repository root for full details. + +-- Add a column to the upstream_oauth_authorization_sessions table to store +-- extra query parameters +ALTER TABLE "upstream_oauth_authorization_sessions" + ADD COLUMN "extra_callback_parameters" JSONB; diff --git a/crates/storage-pg/src/upstream_oauth2/mod.rs b/crates/storage-pg/src/upstream_oauth2/mod.rs index 8c929636a..a544c9cd3 100644 --- a/crates/storage-pg/src/upstream_oauth2/mod.rs +++ b/crates/storage-pg/src/upstream_oauth2/mod.rs @@ -145,7 +145,7 @@ mod tests { let session = repo .upstream_oauth_session() - .complete_with_link(&clock, session, &link, None) + .complete_with_link(&clock, session, &link, None, None) .await .unwrap(); // Reload the session diff --git a/crates/storage-pg/src/upstream_oauth2/session.rs b/crates/storage-pg/src/upstream_oauth2/session.rs index fb27da5f8..e3f28b5f4 100644 --- a/crates/storage-pg/src/upstream_oauth2/session.rs +++ b/crates/storage-pg/src/upstream_oauth2/session.rs @@ -43,6 +43,7 @@ struct SessionLookup { created_at: DateTime, completed_at: Option>, consumed_at: Option>, + extra_callback_parameters: Option, } impl TryFrom for UpstreamOAuthAuthorizationSession { @@ -53,25 +54,32 @@ impl TryFrom for UpstreamOAuthAuthorizationSession { let state = match ( value.upstream_oauth_link_id, value.id_token, + value.extra_callback_parameters, value.completed_at, value.consumed_at, ) { - (None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, - (Some(link_id), id_token, Some(completed_at), None) => { + (None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending, + (Some(link_id), id_token, extra_callback_parameters, Some(completed_at), None) => { UpstreamOAuthAuthorizationSessionState::Completed { completed_at, link_id: link_id.into(), id_token, + extra_callback_parameters, } } - (Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => { - UpstreamOAuthAuthorizationSessionState::Consumed { - completed_at, - link_id: link_id.into(), - id_token, - consumed_at, - } - } + ( + Some(link_id), + id_token, + extra_callback_parameters, + Some(completed_at), + Some(consumed_at), + ) => UpstreamOAuthAuthorizationSessionState::Consumed { + completed_at, + link_id: link_id.into(), + id_token, + extra_callback_parameters, + consumed_at, + }, _ => { return Err( DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id), @@ -119,6 +127,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> code_challenge_verifier, nonce, id_token, + extra_callback_parameters, created_at, completed_at, consumed_at @@ -216,6 +225,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, + extra_callback_parameters: Option, ) -> Result { let completed_at = clock.now(); @@ -224,12 +234,14 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> UPDATE upstream_oauth_authorization_sessions SET upstream_oauth_link_id = $1, completed_at = $2, - id_token = $3 - WHERE upstream_oauth_authorization_session_id = $4 + id_token = $3, + extra_callback_parameters = $4 + WHERE upstream_oauth_authorization_session_id = $5 "#, Uuid::from(upstream_oauth_link.id), completed_at, id_token, + extra_callback_parameters, Uuid::from(upstream_oauth_authorization_session.id), ) .traced() @@ -237,7 +249,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c> .await?; let upstream_oauth_authorization_session = upstream_oauth_authorization_session - .complete(completed_at, upstream_oauth_link, id_token) + .complete( + completed_at, + upstream_oauth_link, + id_token, + extra_callback_parameters, + ) .map_err(DatabaseError::to_invalid_operation)?; Ok(upstream_oauth_authorization_session) diff --git a/crates/storage/src/upstream_oauth2/session.rs b/crates/storage/src/upstream_oauth2/session.rs index 80da6135a..a9a438a3a 100644 --- a/crates/storage/src/upstream_oauth2/session.rs +++ b/crates/storage/src/upstream_oauth2/session.rs @@ -74,6 +74,8 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { /// * `upstream_oauth_link`: the link to associate with the session /// * `id_token`: the ID token returned by the upstream OAuth provider, if /// present + /// * `extra_callback_parameters`: the extra query parameters returned in + /// the callback, if any /// /// # Errors /// @@ -84,6 +86,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync { upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, + extra_callback_parameters: Option, ) -> Result; /// Mark a session as consumed @@ -127,6 +130,7 @@ repository_impl!(UpstreamOAuthSessionRepository: upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession, upstream_oauth_link: &UpstreamOAuthLink, id_token: Option, + extra_callback_parameters: Option, ) -> Result; async fn consume(