Record extra query parameters during upstream callback

And make them available in the templates.
This is useful to get the user display name for Sign-in with Apple
This commit is contained in:
Quentin Gliech
2024-11-18 13:28:12 +01:00
parent c6d8ab7b7d
commit 05e2572258
12 changed files with 137 additions and 25 deletions

1
Cargo.lock generated
View File

@@ -3311,6 +3311,7 @@ dependencies = [
"regex",
"ruma-common",
"serde",
"serde_json",
"thiserror",
"ulid",
"url",

View File

@@ -15,6 +15,7 @@ workspace = true
chrono.workspace = true
thiserror.workspace = true
serde.workspace = true
serde_json.workspace = true
url.workspace = true
crc = "3.2.1"
ulid.workspace = true

View File

@@ -19,12 +19,14 @@ pub enum UpstreamOAuthAuthorizationSessionState {
completed_at: DateTime<Utc>,
link_id: Ulid,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
},
Consumed {
completed_at: DateTime<Utc>,
consumed_at: DateTime<Utc>,
link_id: Ulid,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
},
}
@@ -42,12 +44,14 @@ impl UpstreamOAuthAuthorizationSessionState {
completed_at: DateTime<Utc>,
link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
) -> Result<Self, InvalidTransitionError> {
match self {
Self::Pending => Ok(Self::Completed {
completed_at,
link_id: link.id,
id_token,
extra_callback_parameters,
}),
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
}
@@ -67,11 +71,13 @@ impl UpstreamOAuthAuthorizationSessionState {
completed_at,
link_id,
id_token,
extra_callback_parameters,
} => Ok(Self::Consumed {
completed_at,
link_id,
consumed_at,
id_token,
extra_callback_parameters,
}),
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
}
@@ -124,6 +130,27 @@ impl UpstreamOAuthAuthorizationSessionState {
}
}
/// Get the extra query parameters that were sent to the upstream provider.
///
/// Returns `None` if the upstream OAuth 2.0 authorization session state is
/// not [`Pending`].
///
/// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
#[must_use]
pub fn extra_callback_parameters(&self) -> Option<&serde_json::Value> {
match self {
Self::Pending => None,
Self::Completed {
extra_callback_parameters,
..
}
| Self::Consumed {
extra_callback_parameters,
..
} => extra_callback_parameters.as_ref(),
}
}
/// Get the time at which the upstream OAuth 2.0 authorization session was
/// consumed.
///
@@ -201,8 +228,11 @@ impl UpstreamOAuthAuthorizationSession {
completed_at: DateTime<Utc>,
link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
) -> Result<Self, InvalidTransitionError> {
self.state = self.state.complete(completed_at, link, id_token)?;
self.state =
self.state
.complete(completed_at, link, id_token, extra_callback_parameters)?;
Ok(self)
}

View File

@@ -48,6 +48,9 @@ pub struct Params {
enum CodeOrError {
Code {
code: String,
#[serde(flatten)]
extra_callback_parameters: Option<serde_json::Value>,
},
Error {
error: ClientErrorCode,
@@ -201,7 +204,7 @@ pub(crate) async fn handler(
}
// Let's extract the code from the params, and return if there was an error
let code = match params.code_or_error {
let (code, extra_callback_parameters) = match params.code_or_error {
CodeOrError::Error {
error,
error_description,
@@ -212,7 +215,10 @@ pub(crate) async fn handler(
error_description,
})
}
CodeOrError::Code { code } => code,
CodeOrError::Code {
code,
extra_callback_parameters,
} => (code, extra_callback_parameters),
};
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
@@ -266,6 +272,10 @@ pub(crate) async fn handler(
let env = {
let mut env = environment();
env.add_global("user", minijinja::Value::from_serialize(&id_token));
env.add_global(
"extra_callback_parameters",
minijinja::Value::from_serialize(&extra_callback_parameters),
);
env
};
@@ -299,7 +309,13 @@ pub(crate) async fn handler(
let session = repo
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, response.id_token)
.complete_with_link(
&clock,
session,
&link,
response.id_token,
extra_callback_parameters,
)
.await?;
let cookie_jar = sessions_cookie

View File

@@ -340,6 +340,10 @@ pub(crate) async fn get(
let env = {
let mut e = environment();
e.add_global("user", payload);
e.add_global(
"extra_callback_parameters",
minijinja::Value::from_serialize(upstream_session.extra_callback_parameters()),
);
e
};
@@ -582,6 +586,10 @@ pub(crate) async fn post(
let env = {
let mut e = environment();
e.add_global("user", payload);
e.add_global(
"extra_callback_parameters",
minijinja::Value::from_serialize(upstream_session.extra_callback_parameters()),
);
e
};
@@ -945,7 +953,13 @@ mod tests {
let session = repo
.upstream_oauth_session()
.complete_with_link(&state.clock, session, &link, Some(id_token.into_string()))
.complete_with_link(
&state.clock,
session,
&link,
Some(id_token.into_string()),
None,
)
.await
.unwrap();

View File

@@ -68,6 +68,18 @@ fn string(value: &Value) -> String {
value.to_string()
}
fn from_json(value: &str) -> Result<Value, minijinja::Error> {
let value: serde_json::Value = serde_json::from_str(value).map_err(|e| {
minijinja::Error::new(
minijinja::ErrorKind::InvalidOperation,
"Failed to decode JSON",
)
.with_source(e)
})?;
Ok(Value::from_serialize(value))
}
pub fn environment() -> Environment<'static> {
let mut env = Environment::new();
@@ -77,6 +89,7 @@ pub fn environment() -> Environment<'static> {
env.add_filter("b64encode", b64encode);
env.add_filter("tlvdecode", tlvdecode);
env.add_filter("string", string);
env.add_filter("from_json", from_json);
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n ",
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3,\n extra_callback_parameters = $4\n WHERE upstream_oauth_authorization_session_id = $5\n ",
"describe": {
"columns": [],
"parameters": {
@@ -8,10 +8,11 @@
"Uuid",
"Timestamptz",
"Text",
"Jsonb",
"Uuid"
]
},
"nullable": []
},
"hash": "b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64"
"hash": "5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ",
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n extra_callback_parameters,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ",
"describe": {
"columns": [
{
@@ -40,16 +40,21 @@
},
{
"ordinal": 7,
"name": "extra_callback_parameters",
"type_info": "Jsonb"
},
{
"ordinal": 8,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 8,
"ordinal": 9,
"name": "completed_at",
"type_info": "Timestamptz"
},
{
"ordinal": 9,
"ordinal": 10,
"name": "consumed_at",
"type_info": "Timestamptz"
}
@@ -67,10 +72,11 @@
true,
false,
true,
true,
false,
true,
true
]
},
"hash": "67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c"
"hash": "7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30"
}

View File

@@ -0,0 +1,9 @@
-- Copyright 2024 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.
-- Add a column to the upstream_oauth_authorization_sessions table to store
-- extra query parameters
ALTER TABLE "upstream_oauth_authorization_sessions"
ADD COLUMN "extra_callback_parameters" JSONB;

View File

@@ -145,7 +145,7 @@ mod tests {
let session = repo
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, None)
.complete_with_link(&clock, session, &link, None, None)
.await
.unwrap();
// Reload the session

View File

@@ -43,6 +43,7 @@ struct SessionLookup {
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
consumed_at: Option<DateTime<Utc>>,
extra_callback_parameters: Option<serde_json::Value>,
}
impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
@@ -53,25 +54,32 @@ impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
let state = match (
value.upstream_oauth_link_id,
value.id_token,
value.extra_callback_parameters,
value.completed_at,
value.consumed_at,
) {
(None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
(Some(link_id), id_token, Some(completed_at), None) => {
(None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
(Some(link_id), id_token, extra_callback_parameters, Some(completed_at), None) => {
UpstreamOAuthAuthorizationSessionState::Completed {
completed_at,
link_id: link_id.into(),
id_token,
extra_callback_parameters,
}
}
(Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => {
UpstreamOAuthAuthorizationSessionState::Consumed {
completed_at,
link_id: link_id.into(),
id_token,
consumed_at,
}
}
(
Some(link_id),
id_token,
extra_callback_parameters,
Some(completed_at),
Some(consumed_at),
) => UpstreamOAuthAuthorizationSessionState::Consumed {
completed_at,
link_id: link_id.into(),
id_token,
extra_callback_parameters,
consumed_at,
},
_ => {
return Err(
DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id),
@@ -119,6 +127,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
code_challenge_verifier,
nonce,
id_token,
extra_callback_parameters,
created_at,
completed_at,
consumed_at
@@ -216,6 +225,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let completed_at = clock.now();
@@ -224,12 +234,14 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
UPDATE upstream_oauth_authorization_sessions
SET upstream_oauth_link_id = $1,
completed_at = $2,
id_token = $3
WHERE upstream_oauth_authorization_session_id = $4
id_token = $3,
extra_callback_parameters = $4
WHERE upstream_oauth_authorization_session_id = $5
"#,
Uuid::from(upstream_oauth_link.id),
completed_at,
id_token,
extra_callback_parameters,
Uuid::from(upstream_oauth_authorization_session.id),
)
.traced()
@@ -237,7 +249,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
.await?;
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
.complete(completed_at, upstream_oauth_link, id_token)
.complete(
completed_at,
upstream_oauth_link,
id_token,
extra_callback_parameters,
)
.map_err(DatabaseError::to_invalid_operation)?;
Ok(upstream_oauth_authorization_session)

View File

@@ -74,6 +74,8 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
/// * `upstream_oauth_link`: the link to associate with the session
/// * `id_token`: the ID token returned by the upstream OAuth provider, if
/// present
/// * `extra_callback_parameters`: the extra query parameters returned in
/// the callback, if any
///
/// # Errors
///
@@ -84,6 +86,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
/// Mark a session as consumed
@@ -127,6 +130,7 @@ repository_impl!(UpstreamOAuthSessionRepository:
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
async fn consume(