diff --git a/crates/data-model/src/upstream_oauth2/provider.rs b/crates/data-model/src/upstream_oauth2/provider.rs index 4b6501950..328f9c308 100644 --- a/crates/data-model/src/upstream_oauth2/provider.rs +++ b/crates/data-model/src/upstream_oauth2/provider.rs @@ -388,7 +388,7 @@ pub enum ImportAction { impl ImportAction { #[must_use] - pub fn is_forced(&self) -> bool { + pub fn is_forced_or_required(&self) -> bool { matches!(self, Self::Force | Self::Require) } diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index 929f17539..56ae663bf 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -420,8 +420,10 @@ pub(crate) async fn get( &context, provider.claims_imports.displayname.is_required(), )? { - Some(value) => ctx - .with_display_name(value, provider.claims_imports.displayname.is_forced()), + Some(value) => ctx.with_display_name( + value, + provider.claims_imports.displayname.is_forced_or_required(), + ), None => ctx, } }; @@ -442,7 +444,9 @@ pub(crate) async fn get( &context, provider.claims_imports.email.is_required(), )? { - Some(value) => ctx.with_email(value, provider.claims_imports.email.is_forced()), + Some(value) => { + ctx.with_email(value, provider.claims_imports.email.is_forced_or_required()) + } None => ctx, } }; @@ -541,9 +545,9 @@ pub(crate) async fn get( // The username passes the policy check, add it to the context ctx.with_localpart( localpart, - provider.claims_imports.localpart.is_forced(), + provider.claims_imports.localpart.is_forced_or_required(), ) - } else if provider.claims_imports.localpart.is_forced() { + } else if provider.claims_imports.localpart.is_forced_or_required() { // If the username claim is 'forced' but doesn't pass the policy check, // we display an error message. // TODO: translate @@ -678,7 +682,7 @@ pub(crate) async fn post( let context = context.build(); // `is_forced` checks both if is it `force` or `require` - if !provider.claims_imports.localpart.is_forced() { + if !provider.claims_imports.localpart.is_forced_or_required() { //Claims import for `localpart` should be `require` or `force` at this stage return Err(RouteError::InvalidFormAction); } @@ -793,7 +797,7 @@ pub(crate) async fn post( let ctx = if let Some(ref display_name) = display_name { ctx.with_display_name( display_name.clone(), - provider.claims_imports.email.is_forced(), + provider.claims_imports.email.is_forced_or_required(), ) } else { ctx @@ -818,12 +822,15 @@ pub(crate) async fn post( }; let ctx = if let Some(ref email) = email { - ctx.with_email(email.clone(), provider.claims_imports.email.is_forced()) + ctx.with_email( + email.clone(), + provider.claims_imports.email.is_forced_or_required(), + ) } else { ctx }; - let username = if provider.claims_imports.localpart.is_forced() { + let username = if provider.claims_imports.localpart.is_forced_or_required() { let template = provider .claims_imports .localpart @@ -840,7 +847,7 @@ pub(crate) async fn post( let ctx = ctx.with_localpart( username.clone(), - provider.claims_imports.localpart.is_forced(), + provider.claims_imports.localpart.is_forced_or_required(), ); // Validate the form