Record extra query parameters during upstream callback
And make them available in the templates. This is useful to get the user display name for Sign-in with Apple
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -3311,6 +3311,7 @@ dependencies = [
|
||||
"regex",
|
||||
"ruma-common",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"thiserror",
|
||||
"ulid",
|
||||
"url",
|
||||
|
||||
@@ -15,6 +15,7 @@ workspace = true
|
||||
chrono.workspace = true
|
||||
thiserror.workspace = true
|
||||
serde.workspace = true
|
||||
serde_json.workspace = true
|
||||
url.workspace = true
|
||||
crc = "3.2.1"
|
||||
ulid.workspace = true
|
||||
|
||||
@@ -19,12 +19,14 @@ pub enum UpstreamOAuthAuthorizationSessionState {
|
||||
completed_at: DateTime<Utc>,
|
||||
link_id: Ulid,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
},
|
||||
Consumed {
|
||||
completed_at: DateTime<Utc>,
|
||||
consumed_at: DateTime<Utc>,
|
||||
link_id: Ulid,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -42,12 +44,14 @@ impl UpstreamOAuthAuthorizationSessionState {
|
||||
completed_at: DateTime<Utc>,
|
||||
link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Pending => Ok(Self::Completed {
|
||||
completed_at,
|
||||
link_id: link.id,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
}),
|
||||
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
@@ -67,11 +71,13 @@ impl UpstreamOAuthAuthorizationSessionState {
|
||||
completed_at,
|
||||
link_id,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
} => Ok(Self::Consumed {
|
||||
completed_at,
|
||||
link_id,
|
||||
consumed_at,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
}),
|
||||
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
@@ -124,6 +130,27 @@ impl UpstreamOAuthAuthorizationSessionState {
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the extra query parameters that were sent to the upstream provider.
|
||||
///
|
||||
/// Returns `None` if the upstream OAuth 2.0 authorization session state is
|
||||
/// not [`Pending`].
|
||||
///
|
||||
/// [`Pending`]: UpstreamOAuthAuthorizationSessionState::Pending
|
||||
#[must_use]
|
||||
pub fn extra_callback_parameters(&self) -> Option<&serde_json::Value> {
|
||||
match self {
|
||||
Self::Pending => None,
|
||||
Self::Completed {
|
||||
extra_callback_parameters,
|
||||
..
|
||||
}
|
||||
| Self::Consumed {
|
||||
extra_callback_parameters,
|
||||
..
|
||||
} => extra_callback_parameters.as_ref(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the time at which the upstream OAuth 2.0 authorization session was
|
||||
/// consumed.
|
||||
///
|
||||
@@ -201,8 +228,11 @@ impl UpstreamOAuthAuthorizationSession {
|
||||
completed_at: DateTime<Utc>,
|
||||
link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
self.state = self.state.complete(completed_at, link, id_token)?;
|
||||
self.state =
|
||||
self.state
|
||||
.complete(completed_at, link, id_token, extra_callback_parameters)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
|
||||
@@ -48,6 +48,9 @@ pub struct Params {
|
||||
enum CodeOrError {
|
||||
Code {
|
||||
code: String,
|
||||
|
||||
#[serde(flatten)]
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
},
|
||||
Error {
|
||||
error: ClientErrorCode,
|
||||
@@ -201,7 +204,7 @@ pub(crate) async fn handler(
|
||||
}
|
||||
|
||||
// Let's extract the code from the params, and return if there was an error
|
||||
let code = match params.code_or_error {
|
||||
let (code, extra_callback_parameters) = match params.code_or_error {
|
||||
CodeOrError::Error {
|
||||
error,
|
||||
error_description,
|
||||
@@ -212,7 +215,10 @@ pub(crate) async fn handler(
|
||||
error_description,
|
||||
})
|
||||
}
|
||||
CodeOrError::Code { code } => code,
|
||||
CodeOrError::Code {
|
||||
code,
|
||||
extra_callback_parameters,
|
||||
} => (code, extra_callback_parameters),
|
||||
};
|
||||
|
||||
let mut lazy_metadata = LazyProviderInfos::new(&metadata_cache, &provider, &client);
|
||||
@@ -266,6 +272,10 @@ pub(crate) async fn handler(
|
||||
let env = {
|
||||
let mut env = environment();
|
||||
env.add_global("user", minijinja::Value::from_serialize(&id_token));
|
||||
env.add_global(
|
||||
"extra_callback_parameters",
|
||||
minijinja::Value::from_serialize(&extra_callback_parameters),
|
||||
);
|
||||
env
|
||||
};
|
||||
|
||||
@@ -299,7 +309,13 @@ pub(crate) async fn handler(
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(&clock, session, &link, response.id_token)
|
||||
.complete_with_link(
|
||||
&clock,
|
||||
session,
|
||||
&link,
|
||||
response.id_token,
|
||||
extra_callback_parameters,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let cookie_jar = sessions_cookie
|
||||
|
||||
@@ -340,6 +340,10 @@ pub(crate) async fn get(
|
||||
let env = {
|
||||
let mut e = environment();
|
||||
e.add_global("user", payload);
|
||||
e.add_global(
|
||||
"extra_callback_parameters",
|
||||
minijinja::Value::from_serialize(upstream_session.extra_callback_parameters()),
|
||||
);
|
||||
e
|
||||
};
|
||||
|
||||
@@ -582,6 +586,10 @@ pub(crate) async fn post(
|
||||
let env = {
|
||||
let mut e = environment();
|
||||
e.add_global("user", payload);
|
||||
e.add_global(
|
||||
"extra_callback_parameters",
|
||||
minijinja::Value::from_serialize(upstream_session.extra_callback_parameters()),
|
||||
);
|
||||
e
|
||||
};
|
||||
|
||||
@@ -945,7 +953,13 @@ mod tests {
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(&state.clock, session, &link, Some(id_token.into_string()))
|
||||
.complete_with_link(
|
||||
&state.clock,
|
||||
session,
|
||||
&link,
|
||||
Some(id_token.into_string()),
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
|
||||
@@ -68,6 +68,18 @@ fn string(value: &Value) -> String {
|
||||
value.to_string()
|
||||
}
|
||||
|
||||
fn from_json(value: &str) -> Result<Value, minijinja::Error> {
|
||||
let value: serde_json::Value = serde_json::from_str(value).map_err(|e| {
|
||||
minijinja::Error::new(
|
||||
minijinja::ErrorKind::InvalidOperation,
|
||||
"Failed to decode JSON",
|
||||
)
|
||||
.with_source(e)
|
||||
})?;
|
||||
|
||||
Ok(Value::from_serialize(value))
|
||||
}
|
||||
|
||||
pub fn environment() -> Environment<'static> {
|
||||
let mut env = Environment::new();
|
||||
|
||||
@@ -77,6 +89,7 @@ pub fn environment() -> Environment<'static> {
|
||||
env.add_filter("b64encode", b64encode);
|
||||
env.add_filter("tlvdecode", tlvdecode);
|
||||
env.add_filter("string", string);
|
||||
env.add_filter("from_json", from_json);
|
||||
|
||||
env.set_unknown_method_callback(minijinja_contrib::pycompat::unknown_method_callback);
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3\n WHERE upstream_oauth_authorization_session_id = $4\n ",
|
||||
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3,\n extra_callback_parameters = $4\n WHERE upstream_oauth_authorization_session_id = $5\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
@@ -8,10 +8,11 @@
|
||||
"Uuid",
|
||||
"Timestamptz",
|
||||
"Text",
|
||||
"Jsonb",
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "b9875a270f7e753e48075ccae233df6e24a91775ceb877735508c1d5b2300d64"
|
||||
"hash": "5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ",
|
||||
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n extra_callback_parameters,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -40,16 +40,21 @@
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "extra_callback_parameters",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"ordinal": 9,
|
||||
"name": "completed_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"ordinal": 10,
|
||||
"name": "consumed_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
@@ -67,10 +72,11 @@
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "67ab838035946ddc15b43dd2f79d10b233d07e863b3a5c776c5db97cff263c8c"
|
||||
"hash": "7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30"
|
||||
}
|
||||
@@ -0,0 +1,9 @@
|
||||
-- Copyright 2024 New Vector Ltd.
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only
|
||||
-- Please see LICENSE in the repository root for full details.
|
||||
|
||||
-- Add a column to the upstream_oauth_authorization_sessions table to store
|
||||
-- extra query parameters
|
||||
ALTER TABLE "upstream_oauth_authorization_sessions"
|
||||
ADD COLUMN "extra_callback_parameters" JSONB;
|
||||
@@ -145,7 +145,7 @@ mod tests {
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(&clock, session, &link, None)
|
||||
.complete_with_link(&clock, session, &link, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Reload the session
|
||||
|
||||
@@ -43,6 +43,7 @@ struct SessionLookup {
|
||||
created_at: DateTime<Utc>,
|
||||
completed_at: Option<DateTime<Utc>>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
|
||||
@@ -53,25 +54,32 @@ impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
|
||||
let state = match (
|
||||
value.upstream_oauth_link_id,
|
||||
value.id_token,
|
||||
value.extra_callback_parameters,
|
||||
value.completed_at,
|
||||
value.consumed_at,
|
||||
) {
|
||||
(None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
|
||||
(Some(link_id), id_token, Some(completed_at), None) => {
|
||||
(None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
|
||||
(Some(link_id), id_token, extra_callback_parameters, Some(completed_at), None) => {
|
||||
UpstreamOAuthAuthorizationSessionState::Completed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
}
|
||||
}
|
||||
(Some(link_id), id_token, Some(completed_at), Some(consumed_at)) => {
|
||||
UpstreamOAuthAuthorizationSessionState::Consumed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
consumed_at,
|
||||
}
|
||||
}
|
||||
(
|
||||
Some(link_id),
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
Some(completed_at),
|
||||
Some(consumed_at),
|
||||
) => UpstreamOAuthAuthorizationSessionState::Consumed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
consumed_at,
|
||||
},
|
||||
_ => {
|
||||
return Err(
|
||||
DatabaseInconsistencyError::on("upstream_oauth_authorization_sessions").row(id),
|
||||
@@ -119,6 +127,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
code_challenge_verifier,
|
||||
nonce,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
created_at,
|
||||
completed_at,
|
||||
consumed_at
|
||||
@@ -216,6 +225,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
let completed_at = clock.now();
|
||||
|
||||
@@ -224,12 +234,14 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
UPDATE upstream_oauth_authorization_sessions
|
||||
SET upstream_oauth_link_id = $1,
|
||||
completed_at = $2,
|
||||
id_token = $3
|
||||
WHERE upstream_oauth_authorization_session_id = $4
|
||||
id_token = $3,
|
||||
extra_callback_parameters = $4
|
||||
WHERE upstream_oauth_authorization_session_id = $5
|
||||
"#,
|
||||
Uuid::from(upstream_oauth_link.id),
|
||||
completed_at,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
Uuid::from(upstream_oauth_authorization_session.id),
|
||||
)
|
||||
.traced()
|
||||
@@ -237,7 +249,12 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
.await?;
|
||||
|
||||
let upstream_oauth_authorization_session = upstream_oauth_authorization_session
|
||||
.complete(completed_at, upstream_oauth_link, id_token)
|
||||
.complete(
|
||||
completed_at,
|
||||
upstream_oauth_link,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
Ok(upstream_oauth_authorization_session)
|
||||
|
||||
@@ -74,6 +74,8 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
|
||||
/// * `upstream_oauth_link`: the link to associate with the session
|
||||
/// * `id_token`: the ID token returned by the upstream OAuth provider, if
|
||||
/// present
|
||||
/// * `extra_callback_parameters`: the extra query parameters returned in
|
||||
/// the callback, if any
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
@@ -84,6 +86,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
|
||||
|
||||
/// Mark a session as consumed
|
||||
@@ -127,6 +130,7 @@ repository_impl!(UpstreamOAuthSessionRepository:
|
||||
upstream_oauth_authorization_session: UpstreamOAuthAuthorizationSession,
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
|
||||
|
||||
async fn consume(
|
||||
|
||||
Reference in New Issue
Block a user