Add fetch_userinfo to upstream SSO provider (#3363)

This commit is contained in:
Mathieu Velten
2024-11-26 15:01:03 +00:00
committed by GitHub
parent 49f237796c
commit f832666a86
27 changed files with 414 additions and 142 deletions

View File

@@ -284,10 +284,12 @@ pub async fn config_sync(
encrypted_client_secret,
claims_imports: map_claims_imports(&provider.claims_imports),
token_endpoint_override: provider.token_endpoint,
userinfo_endpoint_override: provider.userinfo_endpoint,
authorization_endpoint_override: provider.authorization_endpoint,
jwks_uri_override: provider.jwks_uri,
discovery_mode,
pkce_mode,
fetch_userinfo: provider.fetch_userinfo,
response_mode,
additional_authorization_parameters: provider
.additional_authorization_parameters

View File

@@ -465,12 +465,26 @@ pub struct Provider {
#[serde(default, skip_serializing_if = "PkceMethod::is_default")]
pub pkce_method: PkceMethod,
/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the `id_token` from the
/// `token_endpoint`.
///
/// Defaults to `false`.
#[serde(default)]
pub fetch_userinfo: bool,
/// The URL to use for the provider's authorization endpoint
///
/// Defaults to the `authorization_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization_endpoint: Option<Url>,
/// The URL to use for the provider's userinfo endpoint
///
/// Defaults to the `userinfo_endpoint` provided through discovery
#[serde(skip_serializing_if = "Option::is_none")]
pub userinfo_endpoint: Option<Url>,
/// The URL to use for the provider's token endpoint
///
/// Defaults to the `token_endpoint` provided through discovery

View File

@@ -228,6 +228,8 @@ pub struct UpstreamOAuthProvider {
pub authorization_endpoint_override: Option<Url>,
pub scope: Scope,
pub token_endpoint_override: Option<Url>,
pub userinfo_endpoint_override: Option<Url>,
pub fetch_userinfo: bool,
pub client_id: String,
pub encrypted_client_secret: Option<String>,
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,

View File

@@ -20,6 +20,7 @@ pub enum UpstreamOAuthAuthorizationSessionState {
link_id: Ulid,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
},
Consumed {
completed_at: DateTime<Utc>,
@@ -27,6 +28,7 @@ pub enum UpstreamOAuthAuthorizationSessionState {
link_id: Ulid,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
},
}
@@ -45,6 +47,7 @@ impl UpstreamOAuthAuthorizationSessionState {
link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
) -> Result<Self, InvalidTransitionError> {
match self {
Self::Pending => Ok(Self::Completed {
@@ -52,6 +55,7 @@ impl UpstreamOAuthAuthorizationSessionState {
link_id: link.id,
id_token,
extra_callback_parameters,
userinfo,
}),
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
}
@@ -72,12 +76,14 @@ impl UpstreamOAuthAuthorizationSessionState {
link_id,
id_token,
extra_callback_parameters,
userinfo,
} => Ok(Self::Consumed {
completed_at,
link_id,
consumed_at,
id_token,
extra_callback_parameters,
userinfo,
}),
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
}
@@ -151,6 +157,14 @@ impl UpstreamOAuthAuthorizationSessionState {
}
}
#[must_use]
pub fn userinfo(&self) -> Option<&serde_json::Value> {
match self {
Self::Pending => None,
Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => userinfo.as_ref(),
}
}
/// Get the time at which the upstream OAuth 2.0 authorization session was
/// consumed.
///
@@ -229,10 +243,15 @@ impl UpstreamOAuthAuthorizationSession {
link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
) -> Result<Self, InvalidTransitionError> {
self.state =
self.state
.complete(completed_at, link, id_token, extra_callback_parameters)?;
self.state = self.state.complete(
completed_at,
link,
id_token,
extra_callback_parameters,
userinfo,
)?;
Ok(self)
}

View File

@@ -108,6 +108,18 @@ impl<'a> LazyProviderInfos<'a> {
Ok(self.load().await?.token_endpoint())
}
/// Get the userinfo endpoint for the provider.
///
/// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set,
/// otherwise uses the one from discovery.
pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
return Ok(userinfo_endpoint);
}
Ok(self.load().await?.userinfo_endpoint())
}
/// Get the PKCE methods supported by the provider.
///
/// If the mode is set to auto, it will use the ones from discovery,
@@ -387,9 +399,11 @@ mod tests {
brand_name: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
fetch_userinfo: false,
jwks_uri_override: None,
authorization_endpoint_override: None,
scope: Scope::from_iter([OPENID]),
userinfo_endpoint_override: None,
token_endpoint_override: None,
client_id: "client_id".to_owned(),
encrypted_client_secret: None,

View File

@@ -29,6 +29,7 @@ use mas_storage::{
use mas_templates::{FormPostContext, Templates};
use oauth2_types::errors::ClientErrorCode;
use serde::{Deserialize, Serialize};
use serde_json::json;
use thiserror::Error;
use ulid::Ulid;
@@ -117,7 +118,7 @@ pub(crate) enum RouteError {
},
#[error(transparent)]
Internal(Box<dyn std::error::Error>),
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
}
impl_from_error_for_route!(mas_templates::TemplateError);
@@ -125,6 +126,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError);
impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
impl_from_error_for_route!(super::ProviderCredentialsError);
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
@@ -274,7 +276,7 @@ pub(crate) async fn handler(
redirect_uri,
};
let id_token_verification_data = JwtVerificationData {
let verification_data = JwtVerificationData {
issuer: &provider.issuer,
jwks: &jwks,
// TODO: make that configurable
@@ -282,25 +284,48 @@ pub(crate) async fn handler(
client_id: &provider.client_id,
};
let (response, id_token) =
let (response, id_token_map) =
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
&client,
client_credentials,
lazy_metadata.token_endpoint().await?,
code,
validation_data,
Some(id_token_verification_data),
Some(verification_data),
clock.now(),
&mut rng,
)
.await?;
let (_header, id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts();
let (_header, id_token) = id_token_map
.clone()
.ok_or(RouteError::MissingIDToken)?
.into_parts();
let mut context = AttributeMappingContext::new().with_id_token_claims(id_token);
if let Some(extra_callback_parameters) = extra_callback_parameters.clone() {
context = context.with_extra_callback_parameters(extra_callback_parameters);
}
let userinfo = if provider.fetch_userinfo {
Some(json!(
mas_oidc_client::requests::userinfo::fetch_userinfo(
&client,
lazy_metadata.userinfo_endpoint().await?,
response.access_token.as_str(),
Some(verification_data),
&id_token_map.ok_or(RouteError::MissingIDToken)?,
)
.await?
))
} else {
None
};
if let Some(userinfo) = userinfo.clone() {
context = context.with_userinfo_claims(userinfo);
}
let context = context.build();
let env = environment();
@@ -341,6 +366,7 @@ pub(crate) async fn handler(
&link,
response.id_token,
extra_callback_parameters,
userinfo,
)
.await?;

View File

@@ -344,6 +344,9 @@ pub(crate) async fn get(
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
}
if let Some(userinfo) = upstream_session.userinfo() {
context = context.with_userinfo_claims(userinfo.clone());
}
let context = context.build();
let ctx = if provider.claims_imports.displayname.ignore() {
@@ -582,6 +585,9 @@ pub(crate) async fn post(
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
}
if let Some(userinfo) = upstream_session.userinfo() {
context = context.with_userinfo_claims(userinfo.clone());
}
let context = context.build();
// Is the email verified according to the upstream provider?
@@ -921,6 +927,8 @@ mod tests {
claims_imports,
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: false,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
@@ -958,6 +966,7 @@ mod tests {
&link,
Some(id_token.into_string()),
None,
None,
)
.await
.unwrap();

View File

@@ -23,6 +23,7 @@ use minijinja::{
pub(crate) struct AttributeMappingContext {
id_token_claims: Option<HashMap<String, serde_json::Value>>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo_claims: Option<serde_json::Value>,
}
impl AttributeMappingContext {
@@ -46,6 +47,11 @@ impl AttributeMappingContext {
self
}
pub fn with_userinfo_claims(mut self, userinfo_claims: serde_json::Value) -> Self {
self.userinfo_claims = Some(userinfo_claims);
self
}
pub fn build(self) -> Value {
Value::from_object(self)
}
@@ -54,7 +60,25 @@ impl AttributeMappingContext {
impl Object for AttributeMappingContext {
fn get_value(self: &Arc<Self>, name: &Value) -> Option<Value> {
match name.as_str()? {
"user" | "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
"user" => {
if self.id_token_claims.is_none() && self.userinfo_claims.is_none() {
return None;
}
let mut merged_user: HashMap<String, serde_json::Value> = HashMap::new();
if let serde_json::Value::Object(userinfo) = self
.userinfo_claims
.clone()
.unwrap_or(serde_json::Value::Null)
{
merged_user.extend(userinfo);
}
if let Some(id_token) = self.id_token_claims.clone() {
merged_user.extend(id_token);
}
Some(Value::from_serialize(merged_user))
}
"id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
"userinfo_claims" => self.userinfo_claims.as_ref().map(Value::from_serialize),
"extra_callback_parameters" => self
.extra_callback_parameters
.as_ref()
@@ -64,17 +88,20 @@ impl Object for AttributeMappingContext {
}
fn enumerate(self: &Arc<Self>) -> Enumerator {
match (
self.id_token_claims.is_some(),
self.extra_callback_parameters.is_some(),
) {
(true, true) => {
Enumerator::Str(&["user", "id_token_claims", "extra_callback_parameters"])
}
(true, false) => Enumerator::Str(&["user", "id_token_claims"]),
(false, true) => Enumerator::Str(&["extra_callback_parameters"]),
(false, false) => Enumerator::Str(&["user"]),
let mut attrs = Vec::new();
if self.id_token_claims.is_some() || self.userinfo_claims.is_none() {
attrs.push(minijinja::Value::from("user"));
}
if self.id_token_claims.is_some() {
attrs.push(minijinja::Value::from("id_token_claims"));
}
if self.userinfo_claims.is_some() {
attrs.push(minijinja::Value::from("userinfo_claims"));
}
if self.extra_callback_parameters.is_some() {
attrs.push(minijinja::Value::from("extra_callback_parameters"));
}
Enumerator::Values(attrs)
}
}

View File

@@ -403,11 +403,13 @@ mod test {
scope: [OPENID].into_iter().collect(),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
fetch_userinfo: false,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
@@ -439,11 +441,13 @@ mod test {
scope: [OPENID].into_iter().collect(),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
token_endpoint_signing_alg: None,
fetch_userinfo: false,
client_id: "client".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,

View File

@@ -950,6 +950,15 @@ impl VerifiedProviderMetadata {
}
}
/// URL of the authorization server's userinfo endpoint.
#[must_use]
pub fn userinfo_endpoint(&self) -> &Url {
match &self.userinfo_endpoint {
Some(u) => u,
None => unreachable!(),
}
}
/// URL of the authorization server's token endpoint.
#[must_use]
pub fn token_endpoint(&self) -> &Url {

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n fetch_userinfo,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ",
"describe": {
"columns": [
{
@@ -50,51 +50,61 @@
},
{
"ordinal": 9,
"name": "fetch_userinfo",
"type_info": "Bool"
},
{
"ordinal": 10,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 10,
"ordinal": 11,
"name": "disabled_at",
"type_info": "Timestamptz"
},
{
"ordinal": 11,
"ordinal": 12,
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
"type_info": "Jsonb"
},
{
"ordinal": 12,
"ordinal": 13,
"name": "jwks_uri_override",
"type_info": "Text"
},
{
"ordinal": 13,
"ordinal": 14,
"name": "authorization_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 14,
"ordinal": 15,
"name": "token_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 15,
"name": "discovery_mode",
"type_info": "Text"
},
{
"ordinal": 16,
"name": "pkce_mode",
"name": "userinfo_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 17,
"name": "response_mode",
"name": "discovery_mode",
"type_info": "Text"
},
{
"ordinal": 18,
"name": "pkce_mode",
"type_info": "Text"
},
{
"ordinal": 19,
"name": "response_mode",
"type_info": "Text"
},
{
"ordinal": 20,
"name": "additional_parameters: Json<Vec<(String, String)>>",
"type_info": "Jsonb"
}
@@ -113,16 +123,18 @@
true,
false,
false,
false,
true,
false,
true,
true,
true,
true,
false,
false,
false,
true
]
},
"hash": "6b133c3c6bfc3c80a21f6f72d0a6468f748ed59e88d8c904bb0a4bbfee43a67f"
"hash": "39657c8064532745c8a8a944b73f650b468a4677eddf671c69c329d361edf00e"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3,\n extra_callback_parameters = $4\n WHERE upstream_oauth_authorization_session_id = $5\n ",
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3,\n extra_callback_parameters = $4,\n userinfo = $5\n WHERE upstream_oauth_authorization_session_id = $6\n ",
"describe": {
"columns": [],
"parameters": {
@@ -9,10 +9,11 @@
"Timestamptz",
"Text",
"Jsonb",
"Jsonb",
"Uuid"
]
},
"nullable": []
},
"hash": "5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d"
"hash": "5f5245ace61b896f92be78ab4fef701b37c9e3c2f4a332f418b9fb2625a0fe3f"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n fetch_userinfo,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
"describe": {
"columns": [
{
@@ -50,51 +50,61 @@
},
{
"ordinal": 9,
"name": "fetch_userinfo",
"type_info": "Bool"
},
{
"ordinal": 10,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 10,
"ordinal": 11,
"name": "disabled_at",
"type_info": "Timestamptz"
},
{
"ordinal": 11,
"ordinal": 12,
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
"type_info": "Jsonb"
},
{
"ordinal": 12,
"ordinal": 13,
"name": "jwks_uri_override",
"type_info": "Text"
},
{
"ordinal": 13,
"ordinal": 14,
"name": "authorization_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 14,
"ordinal": 15,
"name": "token_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 15,
"name": "discovery_mode",
"type_info": "Text"
},
{
"ordinal": 16,
"name": "pkce_mode",
"name": "userinfo_endpoint_override",
"type_info": "Text"
},
{
"ordinal": 17,
"name": "response_mode",
"name": "discovery_mode",
"type_info": "Text"
},
{
"ordinal": 18,
"name": "pkce_mode",
"type_info": "Text"
},
{
"ordinal": 19,
"name": "response_mode",
"type_info": "Text"
},
{
"ordinal": 20,
"name": "additional_parameters: Json<Vec<(String, String)>>",
"type_info": "Jsonb"
}
@@ -115,16 +125,18 @@
true,
false,
false,
false,
true,
false,
true,
true,
true,
true,
false,
false,
false,
true
]
},
"hash": "73f4e5a724a432f1328c6112185cdc9c7a1ae1de45a2a8c02e7a2b8020b41160"
"hash": "887bd597132831c5caab2356f2d935c00a32274161ec5265da91d1c75ad0bb2b"
}

View File

@@ -0,0 +1,32 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n fetch_userinfo,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,\n $11, $12, $13, $14, $15, $16, $17, $18, $19)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Bool",
"Text",
"Text",
"Text",
"Jsonb",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Timestamptz"
]
},
"nullable": []
},
"hash": "8e1c0760c0b652cf62e47779f9d0aef89463cc60eeae2088d0fedf0aeb75718b"
}

View File

@@ -0,0 +1,41 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n fetch_userinfo,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,\n $12, $13, $14, $15, $16, $17, $18, $19, $20)\n ON CONFLICT (upstream_oauth_provider_id)\n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n fetch_userinfo = EXCLUDED.fetch_userinfo,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n response_mode = EXCLUDED.response_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Bool",
"Text",
"Text",
"Text",
"Jsonb",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Jsonb",
"Timestamptz"
]
},
"nullable": [
false
]
},
"hash": "bf7747552fe6f5489dec3c91fe1cb13a737644b94871c28334a29c88977dd84c"
}

View File

@@ -1,39 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17, $18)\n ON CONFLICT (upstream_oauth_provider_id)\n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n response_mode = EXCLUDED.response_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Jsonb",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Jsonb",
"Timestamptz"
]
},
"nullable": [
false
]
},
"hash": "e36ed76d0176edf8c4a029f017b8f368a529b2d32a54c52f6a28b9e615716f4d"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n extra_callback_parameters,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ",
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n extra_callback_parameters,\n userinfo,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ",
"describe": {
"columns": [
{
@@ -45,16 +45,21 @@
},
{
"ordinal": 8,
"name": "userinfo",
"type_info": "Jsonb"
},
{
"ordinal": 9,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 9,
"ordinal": 10,
"name": "completed_at",
"type_info": "Timestamptz"
},
{
"ordinal": 10,
"ordinal": 11,
"name": "consumed_at",
"type_info": "Timestamptz"
}
@@ -73,10 +78,11 @@
false,
true,
true,
true,
false,
true,
true
]
},
"hash": "7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30"
"hash": "ea30b3809fd7c1d4e9983909c0219f343953a89f2a43f6b8c4ab4fbea7645ccc"
}

View File

@@ -1,30 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Jsonb",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Timestamptz"
]
},
"nullable": []
},
"hash": "ebb1a78003293376a52de830f89f6f526ad1c5c823328463a6525d3c3d0d95c6"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n ",
"query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token,\n userinfo\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)\n ",
"describe": {
"columns": [],
"parameters": {
@@ -15,5 +15,5 @@
},
"nullable": []
},
"hash": "64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f"
"hash": "f5c2ec9b7038d7ed36091e670f9bf34f8aa9ea8ed50929731845e32dc3176e39"
}

View File

@@ -0,0 +1,13 @@
-- Copyright 2024 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.
-- Add columms to upstream_oauth_providers and upstream_oauth_authorization_sessions
-- table to handle userinfo endpoint.
ALTER TABLE "upstream_oauth_providers"
ADD COLUMN "fetch_userinfo" BOOLEAN NOT NULL DEFAULT FALSE,
ADD COLUMN "userinfo_endpoint_override" TEXT;
ALTER TABLE "upstream_oauth_authorization_sessions"
ADD COLUMN "userinfo" JSONB;

View File

@@ -98,6 +98,7 @@ pub enum UpstreamOAuthProviders {
EncryptedClientSecret,
TokenEndpointSigningAlg,
TokenEndpointAuthMethod,
FetchUserinfo,
CreatedAt,
DisabledAt,
ClaimsImports,
@@ -108,6 +109,7 @@ pub enum UpstreamOAuthProviders {
JwksUriOverride,
TokenEndpointOverride,
AuthorizationEndpointOverride,
UserinfoEndpointOverride,
}
#[derive(sea_query::Iden)]

View File

@@ -60,12 +60,14 @@ mod tests {
brand_name: None,
scope: Scope::from_iter([OPENID]),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
fetch_userinfo: false,
token_endpoint_signing_alg: None,
client_id: "client-id".to_owned(),
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
token_endpoint_override: None,
authorization_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
@@ -145,7 +147,7 @@ mod tests {
let session = repo
.upstream_oauth_session()
.complete_with_link(&clock, session, &link, None, None)
.complete_with_link(&clock, session, &link, None, None, None)
.await
.unwrap();
// Reload the session
@@ -302,12 +304,14 @@ mod tests {
brand_name: None,
scope: scope.clone(),
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
fetch_userinfo: false,
token_endpoint_signing_alg: None,
client_id,
encrypted_client_secret: None,
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
token_endpoint_override: None,
authorization_endpoint_override: None,
userinfo_endpoint_override: None,
jwks_uri_override: None,
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,

View File

@@ -56,12 +56,14 @@ struct ProviderLookup {
encrypted_client_secret: Option<String>,
token_endpoint_signing_alg: Option<String>,
token_endpoint_auth_method: String,
fetch_userinfo: bool,
created_at: DateTime<Utc>,
disabled_at: Option<DateTime<Utc>>,
claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
jwks_uri_override: Option<String>,
authorization_endpoint_override: Option<String>,
token_endpoint_override: Option<String>,
userinfo_endpoint_override: Option<String>,
discovery_mode: String,
pkce_mode: String,
response_mode: String,
@@ -70,6 +72,8 @@ struct ProviderLookup {
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
type Error = DatabaseInconsistencyError;
#[allow(clippy::too_many_lines)]
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
let id = value.upstream_oauth_provider_id.into();
let scope = value.scope.parse().map_err(|e| {
@@ -117,6 +121,17 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
.source(e)
})?;
let userinfo_endpoint_override = value
.userinfo_endpoint_override
.map(|x| x.parse())
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("upstream_oauth_providers")
.column("userinfo_endpoint_override")
.row(id)
.source(e)
})?;
let jwks_uri_override = value
.jwks_uri_override
.map(|x| x.parse())
@@ -163,12 +178,14 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
client_id: value.client_id,
encrypted_client_secret: value.encrypted_client_secret,
token_endpoint_auth_method,
fetch_userinfo: value.fetch_userinfo,
token_endpoint_signing_alg,
created_at: value.created_at,
disabled_at: value.disabled_at,
claims_imports: value.claims_imports.0,
authorization_endpoint_override,
token_endpoint_override,
userinfo_endpoint_override,
jwks_uri_override,
discovery_mode,
pkce_mode,
@@ -218,12 +235,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
fetch_userinfo,
created_at,
disabled_at,
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
jwks_uri_override,
authorization_endpoint_override,
token_endpoint_override,
userinfo_endpoint_override,
discovery_mode,
pkce_mode,
response_mode,
@@ -275,19 +294,21 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
brand_name,
scope,
token_endpoint_auth_method,
fetch_userinfo,
token_endpoint_signing_alg,
client_id,
encrypted_client_secret,
claims_imports,
authorization_endpoint_override,
token_endpoint_override,
userinfo_endpoint_override,
jwks_uri_override,
discovery_mode,
pkce_mode,
response_mode,
created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,
$10, $11, $12, $13, $14, $15, $16, $17)
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
$11, $12, $13, $14, $15, $16, $17, $18, $19)
"#,
Uuid::from(id),
&params.issuer,
@@ -295,6 +316,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
params.brand_name.as_deref(),
params.scope.to_string(),
params.token_endpoint_auth_method.to_string(),
params.fetch_userinfo,
params
.token_endpoint_signing_alg
.as_ref()
@@ -310,6 +332,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
.token_endpoint_override
.as_ref()
.map(ToString::to_string),
params
.userinfo_endpoint_override
.as_ref()
.map(ToString::to_string),
params.jwks_uri_override.as_ref().map(ToString::to_string),
params.discovery_mode.as_str(),
params.pkce_mode.as_str(),
@@ -330,11 +356,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
encrypted_client_secret: params.encrypted_client_secret,
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
token_endpoint_auth_method: params.token_endpoint_auth_method,
fetch_userinfo: params.fetch_userinfo,
created_at,
disabled_at: None,
claims_imports: params.claims_imports,
authorization_endpoint_override: params.authorization_endpoint_override,
token_endpoint_override: params.token_endpoint_override,
userinfo_endpoint_override: params.userinfo_endpoint_override,
jwks_uri_override: params.jwks_uri_override,
discovery_mode: params.discovery_mode,
pkce_mode: params.pkce_mode,
@@ -437,20 +465,22 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
brand_name,
scope,
token_endpoint_auth_method,
fetch_userinfo,
token_endpoint_signing_alg,
client_id,
encrypted_client_secret,
claims_imports,
authorization_endpoint_override,
token_endpoint_override,
userinfo_endpoint_override,
jwks_uri_override,
discovery_mode,
pkce_mode,
response_mode,
additional_parameters,
created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,
$10, $11, $12, $13, $14, $15, $16, $17, $18)
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
$12, $13, $14, $15, $16, $17, $18, $19, $20)
ON CONFLICT (upstream_oauth_provider_id)
DO UPDATE
SET
@@ -459,6 +489,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
brand_name = EXCLUDED.brand_name,
scope = EXCLUDED.scope,
token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
fetch_userinfo = EXCLUDED.fetch_userinfo,
token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
disabled_at = NULL,
client_id = EXCLUDED.client_id,
@@ -466,6 +497,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
claims_imports = EXCLUDED.claims_imports,
authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
token_endpoint_override = EXCLUDED.token_endpoint_override,
userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
jwks_uri_override = EXCLUDED.jwks_uri_override,
discovery_mode = EXCLUDED.discovery_mode,
pkce_mode = EXCLUDED.pkce_mode,
@@ -479,6 +511,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
params.brand_name.as_deref(),
params.scope.to_string(),
params.token_endpoint_auth_method.to_string(),
params.fetch_userinfo,
params
.token_endpoint_signing_alg
.as_ref()
@@ -494,6 +527,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
.token_endpoint_override
.as_ref()
.map(ToString::to_string),
params
.userinfo_endpoint_override
.as_ref()
.map(ToString::to_string),
params.jwks_uri_override.as_ref().map(ToString::to_string),
params.discovery_mode.as_str(),
params.pkce_mode.as_str(),
@@ -515,11 +552,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
encrypted_client_secret: params.encrypted_client_secret,
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
token_endpoint_auth_method: params.token_endpoint_auth_method,
fetch_userinfo: params.fetch_userinfo,
created_at,
disabled_at: None,
claims_imports: params.claims_imports,
authorization_endpoint_override: params.authorization_endpoint_override,
token_endpoint_override: params.token_endpoint_override,
userinfo_endpoint_override: params.userinfo_endpoint_override,
jwks_uri_override: params.jwks_uri_override,
discovery_mode: params.discovery_mode,
pkce_mode: params.pkce_mode,
@@ -644,6 +683,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
)),
ProviderLookupIden::CreatedAt,
)
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::FetchUserinfo,
)),
ProviderLookupIden::FetchUserinfo,
)
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
@@ -679,6 +725,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
)),
ProviderLookupIden::AuthorizationEndpointOverride,
)
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
UpstreamOAuthProviders::UserinfoEndpointOverride,
)),
ProviderLookupIden::UserinfoEndpointOverride,
)
.expr_as(
Expr::col((
UpstreamOAuthProviders::Table,
@@ -786,12 +839,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
encrypted_client_secret,
token_endpoint_signing_alg,
token_endpoint_auth_method,
fetch_userinfo,
created_at,
disabled_at,
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
jwks_uri_override,
authorization_endpoint_override,
token_endpoint_override,
userinfo_endpoint_override,
discovery_mode,
pkce_mode,
response_mode,

View File

@@ -40,6 +40,7 @@ struct SessionLookup {
code_challenge_verifier: Option<String>,
nonce: String,
id_token: Option<String>,
userinfo: Option<serde_json::Value>,
created_at: DateTime<Utc>,
completed_at: Option<DateTime<Utc>>,
consumed_at: Option<DateTime<Utc>>,
@@ -55,22 +56,30 @@ impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
value.upstream_oauth_link_id,
value.id_token,
value.extra_callback_parameters,
value.userinfo,
value.completed_at,
value.consumed_at,
) {
(None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
(Some(link_id), id_token, extra_callback_parameters, Some(completed_at), None) => {
UpstreamOAuthAuthorizationSessionState::Completed {
completed_at,
link_id: link_id.into(),
id_token,
extra_callback_parameters,
}
}
(None, None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
(
Some(link_id),
id_token,
extra_callback_parameters,
userinfo,
Some(completed_at),
None,
) => UpstreamOAuthAuthorizationSessionState::Completed {
completed_at,
link_id: link_id.into(),
id_token,
extra_callback_parameters,
userinfo,
},
(
Some(link_id),
id_token,
extra_callback_parameters,
userinfo,
Some(completed_at),
Some(consumed_at),
) => UpstreamOAuthAuthorizationSessionState::Consumed {
@@ -78,6 +87,7 @@ impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
link_id: link_id.into(),
id_token,
extra_callback_parameters,
userinfo,
consumed_at,
},
_ => {
@@ -128,6 +138,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
nonce,
id_token,
extra_callback_parameters,
userinfo,
created_at,
completed_at,
consumed_at
@@ -184,8 +195,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
created_at,
completed_at,
consumed_at,
id_token
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
id_token,
userinfo
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
"#,
Uuid::from(id),
Uuid::from(upstream_oauth_provider.id),
@@ -226,6 +238,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
let completed_at = clock.now();
@@ -235,13 +248,15 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
SET upstream_oauth_link_id = $1,
completed_at = $2,
id_token = $3,
extra_callback_parameters = $4
WHERE upstream_oauth_authorization_session_id = $5
extra_callback_parameters = $4,
userinfo = $5
WHERE upstream_oauth_authorization_session_id = $6
"#,
Uuid::from(upstream_oauth_link.id),
completed_at,
id_token,
extra_callback_parameters,
userinfo,
Uuid::from(upstream_oauth_authorization_session.id),
)
.traced()
@@ -254,6 +269,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
upstream_oauth_link,
id_token,
extra_callback_parameters,
userinfo,
)
.map_err(DatabaseError::to_invalid_operation)?;

View File

@@ -42,6 +42,11 @@ pub struct UpstreamOAuthProviderParams {
/// `private_key_jwt` authentication methods are used
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
/// Whether to fetch the user profile from the userinfo endpoint,
/// or to rely on the data returned in the `id_token` from the
/// `token_endpoint`.
pub fetch_userinfo: bool,
/// The client ID to use when authenticating to the upstream
pub client_id: String,
@@ -59,6 +64,10 @@ pub struct UpstreamOAuthProviderParams {
/// discovered
pub token_endpoint_override: Option<Url>,
/// The URL to use as the userinfo endpoint. If `None`, the URL will be
/// discovered
pub userinfo_endpoint_override: Option<Url>,
/// The URL to use when fetching JWKS. If `None`, the URL will be discovered
pub jwks_uri_override: Option<Url>,

View File

@@ -87,6 +87,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
/// Mark a session as consumed
@@ -131,6 +132,7 @@ repository_impl!(UpstreamOAuthSessionRepository:
upstream_oauth_link: &UpstreamOAuthLink,
id_token: Option<String>,
extra_callback_parameters: Option<serde_json::Value>,
userinfo: Option<serde_json::Value>,
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
async fn consume(

View File

@@ -1895,11 +1895,21 @@
}
]
},
"fetch_userinfo": {
"description": "Whether to fetch the user profile from the userinfo endpoint, or to rely on the data returned in the `id_token` from the `token_endpoint`.\n\nDefaults to `false`.",
"default": false,
"type": "boolean"
},
"authorization_endpoint": {
"description": "The URL to use for the provider's authorization endpoint\n\nDefaults to the `authorization_endpoint` provided through discovery",
"type": "string",
"format": "uri"
},
"userinfo_endpoint": {
"description": "The URL to use for the provider's userinfo endpoint\n\nDefaults to the `userinfo_endpoint` provided through discovery",
"type": "string",
"format": "uri"
},
"token_endpoint": {
"description": "The URL to use for the provider's token endpoint\n\nDefaults to the `token_endpoint` provided through discovery",
"type": "string",