Add fetch_userinfo to upstream SSO provider (#3363)
This commit is contained in:
@@ -284,10 +284,12 @@ pub async fn config_sync(
|
||||
encrypted_client_secret,
|
||||
claims_imports: map_claims_imports(&provider.claims_imports),
|
||||
token_endpoint_override: provider.token_endpoint,
|
||||
userinfo_endpoint_override: provider.userinfo_endpoint,
|
||||
authorization_endpoint_override: provider.authorization_endpoint,
|
||||
jwks_uri_override: provider.jwks_uri,
|
||||
discovery_mode,
|
||||
pkce_mode,
|
||||
fetch_userinfo: provider.fetch_userinfo,
|
||||
response_mode,
|
||||
additional_authorization_parameters: provider
|
||||
.additional_authorization_parameters
|
||||
|
||||
@@ -465,12 +465,26 @@ pub struct Provider {
|
||||
#[serde(default, skip_serializing_if = "PkceMethod::is_default")]
|
||||
pub pkce_method: PkceMethod,
|
||||
|
||||
/// Whether to fetch the user profile from the userinfo endpoint,
|
||||
/// or to rely on the data returned in the `id_token` from the
|
||||
/// `token_endpoint`.
|
||||
///
|
||||
/// Defaults to `false`.
|
||||
#[serde(default)]
|
||||
pub fetch_userinfo: bool,
|
||||
|
||||
/// The URL to use for the provider's authorization endpoint
|
||||
///
|
||||
/// Defaults to the `authorization_endpoint` provided through discovery
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub authorization_endpoint: Option<Url>,
|
||||
|
||||
/// The URL to use for the provider's userinfo endpoint
|
||||
///
|
||||
/// Defaults to the `userinfo_endpoint` provided through discovery
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub userinfo_endpoint: Option<Url>,
|
||||
|
||||
/// The URL to use for the provider's token endpoint
|
||||
///
|
||||
/// Defaults to the `token_endpoint` provided through discovery
|
||||
|
||||
@@ -228,6 +228,8 @@ pub struct UpstreamOAuthProvider {
|
||||
pub authorization_endpoint_override: Option<Url>,
|
||||
pub scope: Scope,
|
||||
pub token_endpoint_override: Option<Url>,
|
||||
pub userinfo_endpoint_override: Option<Url>,
|
||||
pub fetch_userinfo: bool,
|
||||
pub client_id: String,
|
||||
pub encrypted_client_secret: Option<String>,
|
||||
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
|
||||
@@ -20,6 +20,7 @@ pub enum UpstreamOAuthAuthorizationSessionState {
|
||||
link_id: Ulid,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
userinfo: Option<serde_json::Value>,
|
||||
},
|
||||
Consumed {
|
||||
completed_at: DateTime<Utc>,
|
||||
@@ -27,6 +28,7 @@ pub enum UpstreamOAuthAuthorizationSessionState {
|
||||
link_id: Ulid,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
userinfo: Option<serde_json::Value>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -45,6 +47,7 @@ impl UpstreamOAuthAuthorizationSessionState {
|
||||
link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
userinfo: Option<serde_json::Value>,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
match self {
|
||||
Self::Pending => Ok(Self::Completed {
|
||||
@@ -52,6 +55,7 @@ impl UpstreamOAuthAuthorizationSessionState {
|
||||
link_id: link.id,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
}),
|
||||
Self::Completed { .. } | Self::Consumed { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
@@ -72,12 +76,14 @@ impl UpstreamOAuthAuthorizationSessionState {
|
||||
link_id,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
} => Ok(Self::Consumed {
|
||||
completed_at,
|
||||
link_id,
|
||||
consumed_at,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
}),
|
||||
Self::Pending | Self::Consumed { .. } => Err(InvalidTransitionError),
|
||||
}
|
||||
@@ -151,6 +157,14 @@ impl UpstreamOAuthAuthorizationSessionState {
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn userinfo(&self) -> Option<&serde_json::Value> {
|
||||
match self {
|
||||
Self::Pending => None,
|
||||
Self::Completed { userinfo, .. } | Self::Consumed { userinfo, .. } => userinfo.as_ref(),
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the time at which the upstream OAuth 2.0 authorization session was
|
||||
/// consumed.
|
||||
///
|
||||
@@ -229,10 +243,15 @@ impl UpstreamOAuthAuthorizationSession {
|
||||
link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
userinfo: Option<serde_json::Value>,
|
||||
) -> Result<Self, InvalidTransitionError> {
|
||||
self.state =
|
||||
self.state
|
||||
.complete(completed_at, link, id_token, extra_callback_parameters)?;
|
||||
self.state = self.state.complete(
|
||||
completed_at,
|
||||
link,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
)?;
|
||||
Ok(self)
|
||||
}
|
||||
|
||||
|
||||
@@ -108,6 +108,18 @@ impl<'a> LazyProviderInfos<'a> {
|
||||
Ok(self.load().await?.token_endpoint())
|
||||
}
|
||||
|
||||
/// Get the userinfo endpoint for the provider.
|
||||
///
|
||||
/// Uses [`UpstreamOAuthProvider.userinfo_endpoint_override`] if set,
|
||||
/// otherwise uses the one from discovery.
|
||||
pub async fn userinfo_endpoint(&mut self) -> Result<&Url, DiscoveryError> {
|
||||
if let Some(userinfo_endpoint) = &self.provider.userinfo_endpoint_override {
|
||||
return Ok(userinfo_endpoint);
|
||||
}
|
||||
|
||||
Ok(self.load().await?.userinfo_endpoint())
|
||||
}
|
||||
|
||||
/// Get the PKCE methods supported by the provider.
|
||||
///
|
||||
/// If the mode is set to auto, it will use the ones from discovery,
|
||||
@@ -387,9 +399,11 @@ mod tests {
|
||||
brand_name: None,
|
||||
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,
|
||||
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
|
||||
fetch_userinfo: false,
|
||||
jwks_uri_override: None,
|
||||
authorization_endpoint_override: None,
|
||||
scope: Scope::from_iter([OPENID]),
|
||||
userinfo_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
client_id: "client_id".to_owned(),
|
||||
encrypted_client_secret: None,
|
||||
|
||||
@@ -29,6 +29,7 @@ use mas_storage::{
|
||||
use mas_templates::{FormPostContext, Templates};
|
||||
use oauth2_types::errors::ClientErrorCode;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_json::json;
|
||||
use thiserror::Error;
|
||||
use ulid::Ulid;
|
||||
|
||||
@@ -117,7 +118,7 @@ pub(crate) enum RouteError {
|
||||
},
|
||||
|
||||
#[error(transparent)]
|
||||
Internal(Box<dyn std::error::Error>),
|
||||
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
|
||||
}
|
||||
|
||||
impl_from_error_for_route!(mas_templates::TemplateError);
|
||||
@@ -125,6 +126,7 @@ impl_from_error_for_route!(mas_storage::RepositoryError);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::DiscoveryError);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::JwksError);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::TokenAuthorizationCodeError);
|
||||
impl_from_error_for_route!(mas_oidc_client::error::UserInfoError);
|
||||
impl_from_error_for_route!(super::ProviderCredentialsError);
|
||||
impl_from_error_for_route!(super::cookie::UpstreamSessionNotFound);
|
||||
|
||||
@@ -274,7 +276,7 @@ pub(crate) async fn handler(
|
||||
redirect_uri,
|
||||
};
|
||||
|
||||
let id_token_verification_data = JwtVerificationData {
|
||||
let verification_data = JwtVerificationData {
|
||||
issuer: &provider.issuer,
|
||||
jwks: &jwks,
|
||||
// TODO: make that configurable
|
||||
@@ -282,25 +284,48 @@ pub(crate) async fn handler(
|
||||
client_id: &provider.client_id,
|
||||
};
|
||||
|
||||
let (response, id_token) =
|
||||
let (response, id_token_map) =
|
||||
mas_oidc_client::requests::authorization_code::access_token_with_authorization_code(
|
||||
&client,
|
||||
client_credentials,
|
||||
lazy_metadata.token_endpoint().await?,
|
||||
code,
|
||||
validation_data,
|
||||
Some(id_token_verification_data),
|
||||
Some(verification_data),
|
||||
clock.now(),
|
||||
&mut rng,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let (_header, id_token) = id_token.ok_or(RouteError::MissingIDToken)?.into_parts();
|
||||
let (_header, id_token) = id_token_map
|
||||
.clone()
|
||||
.ok_or(RouteError::MissingIDToken)?
|
||||
.into_parts();
|
||||
|
||||
let mut context = AttributeMappingContext::new().with_id_token_claims(id_token);
|
||||
if let Some(extra_callback_parameters) = extra_callback_parameters.clone() {
|
||||
context = context.with_extra_callback_parameters(extra_callback_parameters);
|
||||
}
|
||||
|
||||
let userinfo = if provider.fetch_userinfo {
|
||||
Some(json!(
|
||||
mas_oidc_client::requests::userinfo::fetch_userinfo(
|
||||
&client,
|
||||
lazy_metadata.userinfo_endpoint().await?,
|
||||
response.access_token.as_str(),
|
||||
Some(verification_data),
|
||||
&id_token_map.ok_or(RouteError::MissingIDToken)?,
|
||||
)
|
||||
.await?
|
||||
))
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
if let Some(userinfo) = userinfo.clone() {
|
||||
context = context.with_userinfo_claims(userinfo);
|
||||
}
|
||||
|
||||
let context = context.build();
|
||||
|
||||
let env = environment();
|
||||
@@ -341,6 +366,7 @@ pub(crate) async fn handler(
|
||||
&link,
|
||||
response.id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
)
|
||||
.await?;
|
||||
|
||||
|
||||
@@ -344,6 +344,9 @@ pub(crate) async fn get(
|
||||
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
|
||||
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
|
||||
}
|
||||
if let Some(userinfo) = upstream_session.userinfo() {
|
||||
context = context.with_userinfo_claims(userinfo.clone());
|
||||
}
|
||||
let context = context.build();
|
||||
|
||||
let ctx = if provider.claims_imports.displayname.ignore() {
|
||||
@@ -582,6 +585,9 @@ pub(crate) async fn post(
|
||||
if let Some(extra_callback_parameters) = upstream_session.extra_callback_parameters() {
|
||||
context = context.with_extra_callback_parameters(extra_callback_parameters.clone());
|
||||
}
|
||||
if let Some(userinfo) = upstream_session.userinfo() {
|
||||
context = context.with_userinfo_claims(userinfo.clone());
|
||||
}
|
||||
let context = context.build();
|
||||
|
||||
// Is the email verified according to the upstream provider?
|
||||
@@ -921,6 +927,8 @@ mod tests {
|
||||
claims_imports,
|
||||
authorization_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
userinfo_endpoint_override: None,
|
||||
fetch_userinfo: false,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
|
||||
@@ -958,6 +966,7 @@ mod tests {
|
||||
&link,
|
||||
Some(id_token.into_string()),
|
||||
None,
|
||||
None,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
@@ -23,6 +23,7 @@ use minijinja::{
|
||||
pub(crate) struct AttributeMappingContext {
|
||||
id_token_claims: Option<HashMap<String, serde_json::Value>>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
userinfo_claims: Option<serde_json::Value>,
|
||||
}
|
||||
|
||||
impl AttributeMappingContext {
|
||||
@@ -46,6 +47,11 @@ impl AttributeMappingContext {
|
||||
self
|
||||
}
|
||||
|
||||
pub fn with_userinfo_claims(mut self, userinfo_claims: serde_json::Value) -> Self {
|
||||
self.userinfo_claims = Some(userinfo_claims);
|
||||
self
|
||||
}
|
||||
|
||||
pub fn build(self) -> Value {
|
||||
Value::from_object(self)
|
||||
}
|
||||
@@ -54,7 +60,25 @@ impl AttributeMappingContext {
|
||||
impl Object for AttributeMappingContext {
|
||||
fn get_value(self: &Arc<Self>, name: &Value) -> Option<Value> {
|
||||
match name.as_str()? {
|
||||
"user" | "id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
|
||||
"user" => {
|
||||
if self.id_token_claims.is_none() && self.userinfo_claims.is_none() {
|
||||
return None;
|
||||
}
|
||||
let mut merged_user: HashMap<String, serde_json::Value> = HashMap::new();
|
||||
if let serde_json::Value::Object(userinfo) = self
|
||||
.userinfo_claims
|
||||
.clone()
|
||||
.unwrap_or(serde_json::Value::Null)
|
||||
{
|
||||
merged_user.extend(userinfo);
|
||||
}
|
||||
if let Some(id_token) = self.id_token_claims.clone() {
|
||||
merged_user.extend(id_token);
|
||||
}
|
||||
Some(Value::from_serialize(merged_user))
|
||||
}
|
||||
"id_token_claims" => self.id_token_claims.as_ref().map(Value::from_serialize),
|
||||
"userinfo_claims" => self.userinfo_claims.as_ref().map(Value::from_serialize),
|
||||
"extra_callback_parameters" => self
|
||||
.extra_callback_parameters
|
||||
.as_ref()
|
||||
@@ -64,17 +88,20 @@ impl Object for AttributeMappingContext {
|
||||
}
|
||||
|
||||
fn enumerate(self: &Arc<Self>) -> Enumerator {
|
||||
match (
|
||||
self.id_token_claims.is_some(),
|
||||
self.extra_callback_parameters.is_some(),
|
||||
) {
|
||||
(true, true) => {
|
||||
Enumerator::Str(&["user", "id_token_claims", "extra_callback_parameters"])
|
||||
}
|
||||
(true, false) => Enumerator::Str(&["user", "id_token_claims"]),
|
||||
(false, true) => Enumerator::Str(&["extra_callback_parameters"]),
|
||||
(false, false) => Enumerator::Str(&["user"]),
|
||||
let mut attrs = Vec::new();
|
||||
if self.id_token_claims.is_some() || self.userinfo_claims.is_none() {
|
||||
attrs.push(minijinja::Value::from("user"));
|
||||
}
|
||||
if self.id_token_claims.is_some() {
|
||||
attrs.push(minijinja::Value::from("id_token_claims"));
|
||||
}
|
||||
if self.userinfo_claims.is_some() {
|
||||
attrs.push(minijinja::Value::from("userinfo_claims"));
|
||||
}
|
||||
if self.extra_callback_parameters.is_some() {
|
||||
attrs.push(minijinja::Value::from("extra_callback_parameters"));
|
||||
}
|
||||
Enumerator::Values(attrs)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -403,11 +403,13 @@ mod test {
|
||||
scope: [OPENID].into_iter().collect(),
|
||||
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
|
||||
token_endpoint_signing_alg: None,
|
||||
fetch_userinfo: false,
|
||||
client_id: "client".to_owned(),
|
||||
encrypted_client_secret: None,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||
authorization_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
userinfo_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
|
||||
@@ -439,11 +441,13 @@ mod test {
|
||||
scope: [OPENID].into_iter().collect(),
|
||||
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
|
||||
token_endpoint_signing_alg: None,
|
||||
fetch_userinfo: false,
|
||||
client_id: "client".to_owned(),
|
||||
encrypted_client_secret: None,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||
authorization_endpoint_override: None,
|
||||
token_endpoint_override: None,
|
||||
userinfo_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
|
||||
|
||||
@@ -950,6 +950,15 @@ impl VerifiedProviderMetadata {
|
||||
}
|
||||
}
|
||||
|
||||
/// URL of the authorization server's userinfo endpoint.
|
||||
#[must_use]
|
||||
pub fn userinfo_endpoint(&self) -> &Url {
|
||||
match &self.userinfo_endpoint {
|
||||
Some(u) => u,
|
||||
None => unreachable!(),
|
||||
}
|
||||
}
|
||||
|
||||
/// URL of the authorization server's token endpoint.
|
||||
#[must_use]
|
||||
pub fn token_endpoint(&self) -> &Url {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n fetch_userinfo,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -50,51 +50,61 @@
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "fetch_userinfo",
|
||||
"type_info": "Bool"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"ordinal": 11,
|
||||
"name": "disabled_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"ordinal": 12,
|
||||
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 12,
|
||||
"ordinal": 13,
|
||||
"name": "jwks_uri_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 13,
|
||||
"ordinal": 14,
|
||||
"name": "authorization_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 14,
|
||||
"ordinal": 15,
|
||||
"name": "token_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 15,
|
||||
"name": "discovery_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 16,
|
||||
"name": "pkce_mode",
|
||||
"name": "userinfo_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 17,
|
||||
"name": "response_mode",
|
||||
"name": "discovery_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 18,
|
||||
"name": "pkce_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 19,
|
||||
"name": "response_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 20,
|
||||
"name": "additional_parameters: Json<Vec<(String, String)>>",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
@@ -113,16 +123,18 @@
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "6b133c3c6bfc3c80a21f6f72d0a6468f748ed59e88d8c904bb0a4bbfee43a67f"
|
||||
"hash": "39657c8064532745c8a8a944b73f650b468a4677eddf671c69c329d361edf00e"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3,\n extra_callback_parameters = $4\n WHERE upstream_oauth_authorization_session_id = $5\n ",
|
||||
"query": "\n UPDATE upstream_oauth_authorization_sessions\n SET upstream_oauth_link_id = $1,\n completed_at = $2,\n id_token = $3,\n extra_callback_parameters = $4,\n userinfo = $5\n WHERE upstream_oauth_authorization_session_id = $6\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
@@ -9,10 +9,11 @@
|
||||
"Timestamptz",
|
||||
"Text",
|
||||
"Jsonb",
|
||||
"Jsonb",
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "5516235e0983fb64d18e82dbe3e34f966ed71a0ed59be0d48ec66fedf64e707d"
|
||||
"hash": "5f5245ace61b896f92be78ab4fef701b37c9e3c2f4a332f418b9fb2625a0fe3f"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
|
||||
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n fetch_userinfo,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE upstream_oauth_provider_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -50,51 +50,61 @@
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "fetch_userinfo",
|
||||
"type_info": "Bool"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"ordinal": 11,
|
||||
"name": "disabled_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"ordinal": 12,
|
||||
"name": "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 12,
|
||||
"ordinal": 13,
|
||||
"name": "jwks_uri_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 13,
|
||||
"ordinal": 14,
|
||||
"name": "authorization_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 14,
|
||||
"ordinal": 15,
|
||||
"name": "token_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 15,
|
||||
"name": "discovery_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 16,
|
||||
"name": "pkce_mode",
|
||||
"name": "userinfo_endpoint_override",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 17,
|
||||
"name": "response_mode",
|
||||
"name": "discovery_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 18,
|
||||
"name": "pkce_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 19,
|
||||
"name": "response_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 20,
|
||||
"name": "additional_parameters: Json<Vec<(String, String)>>",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
@@ -115,16 +125,18 @@
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "73f4e5a724a432f1328c6112185cdc9c7a1ae1de45a2a8c02e7a2b8020b41160"
|
||||
"hash": "887bd597132831c5caab2356f2d935c00a32274161ec5265da91d1c75ad0bb2b"
|
||||
}
|
||||
32
crates/storage-pg/.sqlx/query-8e1c0760c0b652cf62e47779f9d0aef89463cc60eeae2088d0fedf0aeb75718b.json
generated
Normal file
32
crates/storage-pg/.sqlx/query-8e1c0760c0b652cf62e47779f9d0aef89463cc60eeae2088d0fedf0aeb75718b.json
generated
Normal file
@@ -0,0 +1,32 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n fetch_userinfo,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,\n $11, $12, $13, $14, $15, $16, $17, $18, $19)\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Bool",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Jsonb",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "8e1c0760c0b652cf62e47779f9d0aef89463cc60eeae2088d0fedf0aeb75718b"
|
||||
}
|
||||
41
crates/storage-pg/.sqlx/query-bf7747552fe6f5489dec3c91fe1cb13a737644b94871c28334a29c88977dd84c.json
generated
Normal file
41
crates/storage-pg/.sqlx/query-bf7747552fe6f5489dec3c91fe1cb13a737644b94871c28334a29c88977dd84c.json
generated
Normal file
@@ -0,0 +1,41 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n fetch_userinfo,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,\n $12, $13, $14, $15, $16, $17, $18, $19, $20)\n ON CONFLICT (upstream_oauth_provider_id)\n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n fetch_userinfo = EXCLUDED.fetch_userinfo,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n response_mode = EXCLUDED.response_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Bool",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Jsonb",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Jsonb",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "bf7747552fe6f5489dec3c91fe1cb13a737644b94871c28334a29c88977dd84c"
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17, $18)\n ON CONFLICT (upstream_oauth_provider_id)\n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n response_mode = EXCLUDED.response_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Jsonb",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Jsonb",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "e36ed76d0176edf8c4a029f017b8f368a529b2d32a54c52f6a28b9e615716f4d"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n extra_callback_parameters,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ",
|
||||
"query": "\n SELECT\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n upstream_oauth_link_id,\n state,\n code_challenge_verifier,\n nonce,\n id_token,\n extra_callback_parameters,\n userinfo,\n created_at,\n completed_at,\n consumed_at\n FROM upstream_oauth_authorization_sessions\n WHERE upstream_oauth_authorization_session_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
@@ -45,16 +45,21 @@
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "userinfo",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"ordinal": 10,
|
||||
"name": "completed_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"ordinal": 11,
|
||||
"name": "consumed_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
@@ -73,10 +78,11 @@
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
true,
|
||||
true
|
||||
]
|
||||
},
|
||||
"hash": "7d329e0c57f36b9ffe2aa7ddf4a21e293522c00009cca0222524b0c73f6eee30"
|
||||
"hash": "ea30b3809fd7c1d4e9983909c0219f343953a89f2a43f6b8c4ab4fbea7645ccc"
|
||||
}
|
||||
@@ -1,30 +0,0 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,\n $10, $11, $12, $13, $14, $15, $16, $17)\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Uuid",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Jsonb",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Timestamptz"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "ebb1a78003293376a52de830f89f6f526ad1c5c823328463a6525d3c3d0d95c6"
|
||||
}
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)\n ",
|
||||
"query": "\n INSERT INTO upstream_oauth_authorization_sessions (\n upstream_oauth_authorization_session_id,\n upstream_oauth_provider_id,\n state,\n code_challenge_verifier,\n nonce,\n created_at,\n completed_at,\n consumed_at,\n id_token,\n userinfo\n ) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
@@ -15,5 +15,5 @@
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "64e6ea47c2e877c1ebe4338d64d9ad8a6c1c777d1daea024b8ca2e7f0dd75b0f"
|
||||
"hash": "f5c2ec9b7038d7ed36091e670f9bf34f8aa9ea8ed50929731845e32dc3176e39"
|
||||
}
|
||||
@@ -0,0 +1,13 @@
|
||||
-- Copyright 2024 New Vector Ltd.
|
||||
--
|
||||
-- SPDX-License-Identifier: AGPL-3.0-only
|
||||
-- Please see LICENSE in the repository root for full details.
|
||||
|
||||
-- Add columms to upstream_oauth_providers and upstream_oauth_authorization_sessions
|
||||
-- table to handle userinfo endpoint.
|
||||
ALTER TABLE "upstream_oauth_providers"
|
||||
ADD COLUMN "fetch_userinfo" BOOLEAN NOT NULL DEFAULT FALSE,
|
||||
ADD COLUMN "userinfo_endpoint_override" TEXT;
|
||||
|
||||
ALTER TABLE "upstream_oauth_authorization_sessions"
|
||||
ADD COLUMN "userinfo" JSONB;
|
||||
@@ -98,6 +98,7 @@ pub enum UpstreamOAuthProviders {
|
||||
EncryptedClientSecret,
|
||||
TokenEndpointSigningAlg,
|
||||
TokenEndpointAuthMethod,
|
||||
FetchUserinfo,
|
||||
CreatedAt,
|
||||
DisabledAt,
|
||||
ClaimsImports,
|
||||
@@ -108,6 +109,7 @@ pub enum UpstreamOAuthProviders {
|
||||
JwksUriOverride,
|
||||
TokenEndpointOverride,
|
||||
AuthorizationEndpointOverride,
|
||||
UserinfoEndpointOverride,
|
||||
}
|
||||
|
||||
#[derive(sea_query::Iden)]
|
||||
|
||||
@@ -60,12 +60,14 @@ mod tests {
|
||||
brand_name: None,
|
||||
scope: Scope::from_iter([OPENID]),
|
||||
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
|
||||
fetch_userinfo: false,
|
||||
token_endpoint_signing_alg: None,
|
||||
client_id: "client-id".to_owned(),
|
||||
encrypted_client_secret: None,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||
token_endpoint_override: None,
|
||||
authorization_endpoint_override: None,
|
||||
userinfo_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
|
||||
@@ -145,7 +147,7 @@ mod tests {
|
||||
|
||||
let session = repo
|
||||
.upstream_oauth_session()
|
||||
.complete_with_link(&clock, session, &link, None, None)
|
||||
.complete_with_link(&clock, session, &link, None, None, None)
|
||||
.await
|
||||
.unwrap();
|
||||
// Reload the session
|
||||
@@ -302,12 +304,14 @@ mod tests {
|
||||
brand_name: None,
|
||||
scope: scope.clone(),
|
||||
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::None,
|
||||
fetch_userinfo: false,
|
||||
token_endpoint_signing_alg: None,
|
||||
client_id,
|
||||
encrypted_client_secret: None,
|
||||
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
|
||||
token_endpoint_override: None,
|
||||
authorization_endpoint_override: None,
|
||||
userinfo_endpoint_override: None,
|
||||
jwks_uri_override: None,
|
||||
discovery_mode: mas_data_model::UpstreamOAuthProviderDiscoveryMode::Oidc,
|
||||
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
|
||||
|
||||
@@ -56,12 +56,14 @@ struct ProviderLookup {
|
||||
encrypted_client_secret: Option<String>,
|
||||
token_endpoint_signing_alg: Option<String>,
|
||||
token_endpoint_auth_method: String,
|
||||
fetch_userinfo: bool,
|
||||
created_at: DateTime<Utc>,
|
||||
disabled_at: Option<DateTime<Utc>>,
|
||||
claims_imports: Json<UpstreamOAuthProviderClaimsImports>,
|
||||
jwks_uri_override: Option<String>,
|
||||
authorization_endpoint_override: Option<String>,
|
||||
token_endpoint_override: Option<String>,
|
||||
userinfo_endpoint_override: Option<String>,
|
||||
discovery_mode: String,
|
||||
pkce_mode: String,
|
||||
response_mode: String,
|
||||
@@ -70,6 +72,8 @@ struct ProviderLookup {
|
||||
|
||||
impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
#[allow(clippy::too_many_lines)]
|
||||
fn try_from(value: ProviderLookup) -> Result<Self, Self::Error> {
|
||||
let id = value.upstream_oauth_provider_id.into();
|
||||
let scope = value.scope.parse().map_err(|e| {
|
||||
@@ -117,6 +121,17 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let userinfo_endpoint_override = value
|
||||
.userinfo_endpoint_override
|
||||
.map(|x| x.parse())
|
||||
.transpose()
|
||||
.map_err(|e| {
|
||||
DatabaseInconsistencyError::on("upstream_oauth_providers")
|
||||
.column("userinfo_endpoint_override")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
let jwks_uri_override = value
|
||||
.jwks_uri_override
|
||||
.map(|x| x.parse())
|
||||
@@ -163,12 +178,14 @@ impl TryFrom<ProviderLookup> for UpstreamOAuthProvider {
|
||||
client_id: value.client_id,
|
||||
encrypted_client_secret: value.encrypted_client_secret,
|
||||
token_endpoint_auth_method,
|
||||
fetch_userinfo: value.fetch_userinfo,
|
||||
token_endpoint_signing_alg,
|
||||
created_at: value.created_at,
|
||||
disabled_at: value.disabled_at,
|
||||
claims_imports: value.claims_imports.0,
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
userinfo_endpoint_override,
|
||||
jwks_uri_override,
|
||||
discovery_mode,
|
||||
pkce_mode,
|
||||
@@ -218,12 +235,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
fetch_userinfo,
|
||||
created_at,
|
||||
disabled_at,
|
||||
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
jwks_uri_override,
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
userinfo_endpoint_override,
|
||||
discovery_mode,
|
||||
pkce_mode,
|
||||
response_mode,
|
||||
@@ -275,19 +294,21 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
brand_name,
|
||||
scope,
|
||||
token_endpoint_auth_method,
|
||||
fetch_userinfo,
|
||||
token_endpoint_signing_alg,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
claims_imports,
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
userinfo_endpoint_override,
|
||||
jwks_uri_override,
|
||||
discovery_mode,
|
||||
pkce_mode,
|
||||
response_mode,
|
||||
created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,
|
||||
$10, $11, $12, $13, $14, $15, $16, $17)
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10,
|
||||
$11, $12, $13, $14, $15, $16, $17, $18, $19)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
¶ms.issuer,
|
||||
@@ -295,6 +316,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
params.brand_name.as_deref(),
|
||||
params.scope.to_string(),
|
||||
params.token_endpoint_auth_method.to_string(),
|
||||
params.fetch_userinfo,
|
||||
params
|
||||
.token_endpoint_signing_alg
|
||||
.as_ref()
|
||||
@@ -310,6 +332,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
.token_endpoint_override
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
params
|
||||
.userinfo_endpoint_override
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
params.jwks_uri_override.as_ref().map(ToString::to_string),
|
||||
params.discovery_mode.as_str(),
|
||||
params.pkce_mode.as_str(),
|
||||
@@ -330,11 +356,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
encrypted_client_secret: params.encrypted_client_secret,
|
||||
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method: params.token_endpoint_auth_method,
|
||||
fetch_userinfo: params.fetch_userinfo,
|
||||
created_at,
|
||||
disabled_at: None,
|
||||
claims_imports: params.claims_imports,
|
||||
authorization_endpoint_override: params.authorization_endpoint_override,
|
||||
token_endpoint_override: params.token_endpoint_override,
|
||||
userinfo_endpoint_override: params.userinfo_endpoint_override,
|
||||
jwks_uri_override: params.jwks_uri_override,
|
||||
discovery_mode: params.discovery_mode,
|
||||
pkce_mode: params.pkce_mode,
|
||||
@@ -437,20 +465,22 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
brand_name,
|
||||
scope,
|
||||
token_endpoint_auth_method,
|
||||
fetch_userinfo,
|
||||
token_endpoint_signing_alg,
|
||||
client_id,
|
||||
encrypted_client_secret,
|
||||
claims_imports,
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
userinfo_endpoint_override,
|
||||
jwks_uri_override,
|
||||
discovery_mode,
|
||||
pkce_mode,
|
||||
response_mode,
|
||||
additional_parameters,
|
||||
created_at
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9,
|
||||
$10, $11, $12, $13, $14, $15, $16, $17, $18)
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
|
||||
$12, $13, $14, $15, $16, $17, $18, $19, $20)
|
||||
ON CONFLICT (upstream_oauth_provider_id)
|
||||
DO UPDATE
|
||||
SET
|
||||
@@ -459,6 +489,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
brand_name = EXCLUDED.brand_name,
|
||||
scope = EXCLUDED.scope,
|
||||
token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,
|
||||
fetch_userinfo = EXCLUDED.fetch_userinfo,
|
||||
token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,
|
||||
disabled_at = NULL,
|
||||
client_id = EXCLUDED.client_id,
|
||||
@@ -466,6 +497,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
claims_imports = EXCLUDED.claims_imports,
|
||||
authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,
|
||||
token_endpoint_override = EXCLUDED.token_endpoint_override,
|
||||
userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,
|
||||
jwks_uri_override = EXCLUDED.jwks_uri_override,
|
||||
discovery_mode = EXCLUDED.discovery_mode,
|
||||
pkce_mode = EXCLUDED.pkce_mode,
|
||||
@@ -479,6 +511,7 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
params.brand_name.as_deref(),
|
||||
params.scope.to_string(),
|
||||
params.token_endpoint_auth_method.to_string(),
|
||||
params.fetch_userinfo,
|
||||
params
|
||||
.token_endpoint_signing_alg
|
||||
.as_ref()
|
||||
@@ -494,6 +527,10 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
.token_endpoint_override
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
params
|
||||
.userinfo_endpoint_override
|
||||
.as_ref()
|
||||
.map(ToString::to_string),
|
||||
params.jwks_uri_override.as_ref().map(ToString::to_string),
|
||||
params.discovery_mode.as_str(),
|
||||
params.pkce_mode.as_str(),
|
||||
@@ -515,11 +552,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
encrypted_client_secret: params.encrypted_client_secret,
|
||||
token_endpoint_signing_alg: params.token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method: params.token_endpoint_auth_method,
|
||||
fetch_userinfo: params.fetch_userinfo,
|
||||
created_at,
|
||||
disabled_at: None,
|
||||
claims_imports: params.claims_imports,
|
||||
authorization_endpoint_override: params.authorization_endpoint_override,
|
||||
token_endpoint_override: params.token_endpoint_override,
|
||||
userinfo_endpoint_override: params.userinfo_endpoint_override,
|
||||
jwks_uri_override: params.jwks_uri_override,
|
||||
discovery_mode: params.discovery_mode,
|
||||
pkce_mode: params.pkce_mode,
|
||||
@@ -644,6 +683,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
)),
|
||||
ProviderLookupIden::CreatedAt,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::FetchUserinfo,
|
||||
)),
|
||||
ProviderLookupIden::FetchUserinfo,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
@@ -679,6 +725,13 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
)),
|
||||
ProviderLookupIden::AuthorizationEndpointOverride,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
UpstreamOAuthProviders::UserinfoEndpointOverride,
|
||||
)),
|
||||
ProviderLookupIden::UserinfoEndpointOverride,
|
||||
)
|
||||
.expr_as(
|
||||
Expr::col((
|
||||
UpstreamOAuthProviders::Table,
|
||||
@@ -786,12 +839,14 @@ impl<'c> UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'
|
||||
encrypted_client_secret,
|
||||
token_endpoint_signing_alg,
|
||||
token_endpoint_auth_method,
|
||||
fetch_userinfo,
|
||||
created_at,
|
||||
disabled_at,
|
||||
claims_imports as "claims_imports: Json<UpstreamOAuthProviderClaimsImports>",
|
||||
jwks_uri_override,
|
||||
authorization_endpoint_override,
|
||||
token_endpoint_override,
|
||||
userinfo_endpoint_override,
|
||||
discovery_mode,
|
||||
pkce_mode,
|
||||
response_mode,
|
||||
|
||||
@@ -40,6 +40,7 @@ struct SessionLookup {
|
||||
code_challenge_verifier: Option<String>,
|
||||
nonce: String,
|
||||
id_token: Option<String>,
|
||||
userinfo: Option<serde_json::Value>,
|
||||
created_at: DateTime<Utc>,
|
||||
completed_at: Option<DateTime<Utc>>,
|
||||
consumed_at: Option<DateTime<Utc>>,
|
||||
@@ -55,22 +56,30 @@ impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
|
||||
value.upstream_oauth_link_id,
|
||||
value.id_token,
|
||||
value.extra_callback_parameters,
|
||||
value.userinfo,
|
||||
value.completed_at,
|
||||
value.consumed_at,
|
||||
) {
|
||||
(None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
|
||||
(Some(link_id), id_token, extra_callback_parameters, Some(completed_at), None) => {
|
||||
UpstreamOAuthAuthorizationSessionState::Completed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
}
|
||||
}
|
||||
(None, None, None, None, None, None) => UpstreamOAuthAuthorizationSessionState::Pending,
|
||||
(
|
||||
Some(link_id),
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
Some(completed_at),
|
||||
None,
|
||||
) => UpstreamOAuthAuthorizationSessionState::Completed {
|
||||
completed_at,
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
},
|
||||
(
|
||||
Some(link_id),
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
Some(completed_at),
|
||||
Some(consumed_at),
|
||||
) => UpstreamOAuthAuthorizationSessionState::Consumed {
|
||||
@@ -78,6 +87,7 @@ impl TryFrom<SessionLookup> for UpstreamOAuthAuthorizationSession {
|
||||
link_id: link_id.into(),
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
consumed_at,
|
||||
},
|
||||
_ => {
|
||||
@@ -128,6 +138,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
nonce,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
created_at,
|
||||
completed_at,
|
||||
consumed_at
|
||||
@@ -184,8 +195,9 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
created_at,
|
||||
completed_at,
|
||||
consumed_at,
|
||||
id_token
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL)
|
||||
id_token,
|
||||
userinfo
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, NULL, NULL, NULL, NULL)
|
||||
"#,
|
||||
Uuid::from(id),
|
||||
Uuid::from(upstream_oauth_provider.id),
|
||||
@@ -226,6 +238,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
userinfo: Option<serde_json::Value>,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error> {
|
||||
let completed_at = clock.now();
|
||||
|
||||
@@ -235,13 +248,15 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
SET upstream_oauth_link_id = $1,
|
||||
completed_at = $2,
|
||||
id_token = $3,
|
||||
extra_callback_parameters = $4
|
||||
WHERE upstream_oauth_authorization_session_id = $5
|
||||
extra_callback_parameters = $4,
|
||||
userinfo = $5
|
||||
WHERE upstream_oauth_authorization_session_id = $6
|
||||
"#,
|
||||
Uuid::from(upstream_oauth_link.id),
|
||||
completed_at,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
Uuid::from(upstream_oauth_authorization_session.id),
|
||||
)
|
||||
.traced()
|
||||
@@ -254,6 +269,7 @@ impl<'c> UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'c>
|
||||
upstream_oauth_link,
|
||||
id_token,
|
||||
extra_callback_parameters,
|
||||
userinfo,
|
||||
)
|
||||
.map_err(DatabaseError::to_invalid_operation)?;
|
||||
|
||||
|
||||
@@ -42,6 +42,11 @@ pub struct UpstreamOAuthProviderParams {
|
||||
/// `private_key_jwt` authentication methods are used
|
||||
pub token_endpoint_signing_alg: Option<JsonWebSignatureAlg>,
|
||||
|
||||
/// Whether to fetch the user profile from the userinfo endpoint,
|
||||
/// or to rely on the data returned in the `id_token` from the
|
||||
/// `token_endpoint`.
|
||||
pub fetch_userinfo: bool,
|
||||
|
||||
/// The client ID to use when authenticating to the upstream
|
||||
pub client_id: String,
|
||||
|
||||
@@ -59,6 +64,10 @@ pub struct UpstreamOAuthProviderParams {
|
||||
/// discovered
|
||||
pub token_endpoint_override: Option<Url>,
|
||||
|
||||
/// The URL to use as the userinfo endpoint. If `None`, the URL will be
|
||||
/// discovered
|
||||
pub userinfo_endpoint_override: Option<Url>,
|
||||
|
||||
/// The URL to use when fetching JWKS. If `None`, the URL will be discovered
|
||||
pub jwks_uri_override: Option<Url>,
|
||||
|
||||
|
||||
@@ -87,6 +87,7 @@ pub trait UpstreamOAuthSessionRepository: Send + Sync {
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
userinfo: Option<serde_json::Value>,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
|
||||
|
||||
/// Mark a session as consumed
|
||||
@@ -131,6 +132,7 @@ repository_impl!(UpstreamOAuthSessionRepository:
|
||||
upstream_oauth_link: &UpstreamOAuthLink,
|
||||
id_token: Option<String>,
|
||||
extra_callback_parameters: Option<serde_json::Value>,
|
||||
userinfo: Option<serde_json::Value>,
|
||||
) -> Result<UpstreamOAuthAuthorizationSession, Self::Error>;
|
||||
|
||||
async fn consume(
|
||||
|
||||
@@ -1895,11 +1895,21 @@
|
||||
}
|
||||
]
|
||||
},
|
||||
"fetch_userinfo": {
|
||||
"description": "Whether to fetch the user profile from the userinfo endpoint, or to rely on the data returned in the `id_token` from the `token_endpoint`.\n\nDefaults to `false`.",
|
||||
"default": false,
|
||||
"type": "boolean"
|
||||
},
|
||||
"authorization_endpoint": {
|
||||
"description": "The URL to use for the provider's authorization endpoint\n\nDefaults to the `authorization_endpoint` provided through discovery",
|
||||
"type": "string",
|
||||
"format": "uri"
|
||||
},
|
||||
"userinfo_endpoint": {
|
||||
"description": "The URL to use for the provider's userinfo endpoint\n\nDefaults to the `userinfo_endpoint` provided through discovery",
|
||||
"type": "string",
|
||||
"format": "uri"
|
||||
},
|
||||
"token_endpoint": {
|
||||
"description": "The URL to use for the provider's token endpoint\n\nDefaults to the `token_endpoint` provided through discovery",
|
||||
"type": "string",
|
||||
|
||||
Reference in New Issue
Block a user