Make the issue optional on upstream OAuth 2.0 providers

This commit is contained in:
Quentin Gliech
2024-12-13 19:35:06 +01:00
parent 75ee9a1e58
commit f563daf822
26 changed files with 85 additions and 58 deletions

View File

@@ -764,8 +764,10 @@ impl std::fmt::Display for HumanReadable<&UpstreamOAuthProvider> {
let provider = self.0;
if let Some(human_name) = &provider.human_name {
write!(f, "{} ({})", human_name, provider.id)
} else if let Some(issuer) = &provider.issuer {
write!(f, "{} ({})", issuer, provider.id)
} else {
write!(f, "{} ({})", provider.issuer, provider.id)
write!(f, "{}", provider.id)
}
}
}

View File

@@ -47,6 +47,14 @@ impl ConfigurationSection for UpstreamOAuth2Config {
Err(error)
};
if !matches!(provider.discovery_mode, DiscoveryMode::Disabled)
&& provider.issuer.is_none()
{
return annotate(figment::Error::custom(
"The `issuer` field is required when discovery is enabled",
));
}
match provider.token_endpoint_auth_method {
TokenAuthMethod::None
| TokenAuthMethod::PrivateKeyJwt
@@ -438,7 +446,10 @@ pub struct Provider {
pub id: Ulid,
/// The OIDC issuer URL
pub issuer: String,
///
/// This is required if OIDC discovery is enabled (which is the default)
#[serde(skip_serializing_if = "Option::is_none")]
pub issuer: Option<String>,
/// A human-readable name for the provider, that will be shown to users
#[serde(skip_serializing_if = "Option::is_none")]

View File

@@ -219,7 +219,7 @@ pub struct InvalidUpstreamOAuth2TokenAuthMethod(String);
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct UpstreamOAuthProvider {
pub id: Ulid,
pub issuer: String,
pub issuer: Option<String>,
pub human_name: Option<String>,
pub brand_name: Option<String>,
pub discovery_mode: DiscoveryMode,

View File

@@ -37,8 +37,8 @@ impl UpstreamOAuth2Provider {
}
/// OpenID Connect issuer URL.
pub async fn issuer(&self) -> &str {
&self.provider.issuer
pub async fn issuer(&self) -> Option<&str> {
self.provider.issuer.as_deref()
}
/// Client ID used for this provider.

View File

@@ -61,10 +61,11 @@ impl<'a> LazyProviderInfos<'a> {
}
};
let metadata = self
.cache
.get(self.client, &self.provider.issuer, verify)
.await?;
let Some(issuer) = &self.provider.issuer else {
return Err(DiscoveryError::MissingIssuer);
};
let metadata = self.cache.get(self.client, issuer, verify).await?;
self.loaded_metadata = Some(metadata);
}
@@ -179,8 +180,13 @@ impl MetadataCache {
UpstreamOAuthProviderDiscoveryMode::Disabled => continue,
};
if let Err(e) = self.fetch(client, &provider.issuer, verify).await {
tracing::error!(issuer = %provider.issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
let Some(issuer) = &provider.issuer else {
tracing::error!(%provider.id, "Provider doesn't have an issuer set, but discovery is enabled!");
continue;
};
if let Err(e) = self.fetch(client, issuer, verify).await {
tracing::error!(%issuer, error = &e as &dyn std::error::Error, "Failed to fetch provider metadata");
}
}
@@ -395,7 +401,7 @@ mod tests {
let clock = MockClock::default();
let provider = UpstreamOAuthProvider {
id: Ulid::nil(),
issuer: mock_server.uri(),
issuer: Some(mock_server.uri()),
human_name: Some("Example Ltd.".to_owned()),
brand_name: None,
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Insecure,

View File

@@ -284,7 +284,7 @@ pub(crate) async fn handler(
);
let id_token_verification_data = JwtVerificationData {
issuer: &provider.issuer,
issuer: provider.issuer.as_deref(),
jwks: jwks.as_ref().unwrap(),
signing_algorithm: &provider.id_token_signed_response_alg,
client_id: &provider.client_id,
@@ -350,7 +350,7 @@ pub(crate) async fn handler(
lazy_metadata.userinfo_endpoint().await?,
token_response.access_token.as_str(),
Some(JwtVerificationData {
issuer: &provider.issuer,
issuer: provider.issuer.as_deref(),
jwks: &jwks,
signing_algorithm,
client_id: &provider.client_id,

View File

@@ -916,7 +916,7 @@ mod tests {
&mut rng,
&state.clock,
UpstreamOAuthProviderParams {
issuer: "https://example.com/".to_owned(),
issuer: Some("https://example.com/".to_owned()),
human_name: Some("Example Ltd.".to_owned()),
brand_name: None,
scope: Scope::from_iter([OPENID]),

View File

@@ -131,7 +131,6 @@ fn client_credentials_for_provider(
ClientCredentials::SignInWithApple {
client_id,
audience: provider.issuer.clone(),
key,
key_id: params.key_id,
team_id: params.team_id,

View File

@@ -398,7 +398,7 @@ mod test {
&mut rng,
&state.clock,
UpstreamOAuthProviderParams {
issuer: "https://first.com/".to_owned(),
issuer: Some("https://first.com/".to_owned()),
human_name: Some("First Ltd.".to_owned()),
brand_name: None,
scope: [OPENID].into_iter().collect(),
@@ -438,7 +438,7 @@ mod test {
&mut rng,
&state.clock,
UpstreamOAuthProviderParams {
issuer: "https://second.com/".to_owned(),
issuer: Some("https://second.com/".to_owned()),
human_name: None,
brand_name: None,
scope: [OPENID].into_iter().collect(),

View File

@@ -55,6 +55,11 @@ pub enum DiscoveryError {
/// An error occurred validating the metadata.
Validation(#[from] ProviderMetadataVerificationError),
/// The provider doesn't have an issuer set, which is required if discovery
/// is enabled.
#[error("Provider doesn't have an issuer set")]
MissingIssuer,
/// Discovery is disabled for this provider.
#[error("Discovery is disabled for this provider")]
Disabled,

View File

@@ -57,7 +57,7 @@ pub async fn fetch_jwks(
#[derive(Clone, Copy)]
pub struct JwtVerificationData<'a> {
/// The URL of the issuer that generated the ID Token.
pub issuer: &'a str,
pub issuer: Option<&'a str>,
/// The issuer's JWKS.
pub jwks: &'a PublicJsonWebKeySet,
@@ -76,7 +76,7 @@ pub struct JwtVerificationData<'a> {
///
/// * The signature is verified with the given JWKS.
///
/// * The `iss` claim must be present and match the issuer.
/// * The `iss` claim must be present and match the issuer, if present
///
/// * The `aud` claim must be present and match the client ID.
///
@@ -117,8 +117,10 @@ pub fn verify_signed_jwt<'a>(
let (header, mut claims) = jwt.clone().into_parts();
// Must have the proper issuer.
claims::ISS.extract_required_with_options(&mut claims, issuer)?;
if let Some(issuer) = issuer {
// Must have the proper issuer.
claims::ISS.extract_required_with_options(&mut claims, issuer)?;
}
// Must have the proper audience.
claims::AUD.extract_required_with_options(&mut claims, client_id)?;

View File

@@ -103,9 +103,6 @@ pub enum ClientCredentials {
/// The unique ID for the client.
client_id: String,
/// The audience to use. Usually `https://appleid.apple.com`
audience: String,
/// The ECDSA key used to sign
key: elliptic_curve::SecretKey<p256::NistP256>,
@@ -240,7 +237,6 @@ impl ClientCredentials {
ClientCredentials::SignInWithApple {
client_id,
audience,
key,
key_id,
team_id,
@@ -253,7 +249,7 @@ impl ClientCredentials {
claims::ISS.insert(&mut claims, team_id)?;
claims::SUB.insert(&mut claims, client_id)?;
claims::AUD.insert(&mut claims, audience.clone())?;
claims::AUD.insert(&mut claims, "https://appleid.apple.com".to_owned())?;
claims::IAT.insert(&mut claims, now)?;
claims::EXP.insert(&mut claims, now + Duration::microseconds(60 * 1000 * 1000))?;

View File

@@ -193,7 +193,7 @@ async fn pass_access_token_with_authorization_code() {
let (id_token, jwks) = id_token(issuer.as_str());
let id_token_verification_data = JwtVerificationData {
issuer: issuer.as_str(),
issuer: Some(issuer.as_str()),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -251,7 +251,7 @@ async fn fail_access_token_with_authorization_code_wrong_nonce() {
let (id_token, jwks) = id_token(issuer.as_str());
let id_token_verification_data = JwtVerificationData {
issuer: issuer.as_str(),
issuer: Some(issuer.as_str()),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -312,7 +312,7 @@ async fn fail_access_token_with_authorization_code_no_id_token() {
};
let id_token_verification_data = JwtVerificationData {
issuer: issuer.as_str(),
issuer: Some(issuer.as_str()),
jwks: &PublicJsonWebKeySet::default(),
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,

View File

@@ -88,7 +88,7 @@ async fn pass_verify_id_token() {
let (id_token, jwks) = id_token(issuer, None, Some(now));
let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -111,7 +111,7 @@ async fn fail_verify_id_token_wrong_issuer() {
let now = now();
let verification_data = JwtVerificationData {
issuer: wrong_issuer,
issuer: Some(wrong_issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -135,7 +135,7 @@ async fn fail_verify_id_token_wrong_audience() {
let now = now();
let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &"wrong_client_id".to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -159,7 +159,7 @@ async fn fail_verify_id_token_wrong_signing_algorithm() {
let now = now();
let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &JsonWebSignatureAlg::Unknown("wrong_algorithm".to_owned()),
@@ -180,7 +180,7 @@ async fn fail_verify_id_token_wrong_expiration() {
let now = now();
let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -199,7 +199,7 @@ async fn fail_verify_id_token_wrong_subject() {
let (id_token, jwks) = id_token(issuer, Some(IdTokenFlag::WrongSubject), None);
let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,
@@ -224,7 +224,7 @@ async fn fail_verify_id_token_wrong_auth_time() {
let (id_token, jwks) = id_token(issuer, None, Some(now + Duration::try_hours(1).unwrap()));
let verification_data = JwtVerificationData {
issuer,
issuer: Some(issuer),
jwks: &jwks,
client_id: &CLIENT_ID.to_owned(),
signing_algorithm: &ID_TOKEN_SIGNING_ALG,

View File

@@ -126,7 +126,7 @@
},
"nullable": [
false,
false,
true,
true,
true,
false,

View File

@@ -124,7 +124,7 @@
},
"nullable": [
false,
false,
true,
true,
true,
false,

View File

@@ -0,0 +1,8 @@
-- Copyright 2024 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.
-- Make the issuer field in the upstream_oauth_providers table optional
ALTER TABLE "upstream_oauth_providers"
ALTER COLUMN "issuer" DROP NOT NULL;

View File

@@ -148,7 +148,7 @@ impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
db.query.text,
upstream_oauth_link.subject = subject,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,
@@ -192,7 +192,7 @@ impl UpstreamOAuthLinkRepository for PgUpstreamOAuthLinkRepository<'_> {
upstream_oauth_link.subject = subject,
upstream_oauth_link.human_account_name = human_account_name,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
),
err,

View File

@@ -56,7 +56,7 @@ mod tests {
&mut rng,
&clock,
UpstreamOAuthProviderParams {
issuer: "https://example.com/".to_owned(),
issuer: Some("https://example.com/".to_owned()),
human_name: None,
brand_name: None,
scope: Scope::from_iter([OPENID]),
@@ -88,13 +88,13 @@ mod tests {
.await
.unwrap()
.expect("provider to be found in the database");
assert_eq!(provider.issuer, "https://example.com/");
assert_eq!(provider.issuer.as_deref(), Some("https://example.com/"));
assert_eq!(provider.client_id, "client-id");
// It should be in the list of all providers
let providers = repo.upstream_oauth_provider().all_enabled().await.unwrap();
assert_eq!(providers.len(), 1);
assert_eq!(providers[0].issuer, "https://example.com/");
assert_eq!(providers[0].issuer.as_deref(), Some("https://example.com/"));
assert_eq!(providers[0].client_id, "client-id");
// Start a session
@@ -277,7 +277,6 @@ mod tests {
/// provider repository
#[sqlx::test(migrator = "crate::MIGRATOR")]
async fn test_provider_repository_pagination(pool: PgPool) {
const ISSUER: &str = "https://example.com/";
let scope = Scope::from_iter([OPENID]);
let mut rng = rand_chacha::ChaChaRng::seed_from_u64(42);
@@ -302,7 +301,7 @@ mod tests {
&mut rng,
&clock,
UpstreamOAuthProviderParams {
issuer: ISSUER.to_owned(),
issuer: None,
human_name: None,
brand_name: None,
scope: scope.clone(),

View File

@@ -48,7 +48,7 @@ impl<'c> PgUpstreamOAuthProviderRepository<'c> {
#[enum_def]
struct ProviderLookup {
upstream_oauth_provider_id: Uuid,
issuer: String,
issuer: Option<String>,
human_name: Option<String>,
brand_name: Option<String>,
scope: String,
@@ -294,7 +294,7 @@ impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
fields(
db.query.text,
upstream_oauth_provider.id,
upstream_oauth_provider.issuer = %params.issuer,
upstream_oauth_provider.issuer = params.issuer,
upstream_oauth_provider.client_id = %params.client_id,
),
err,
@@ -337,7 +337,7 @@ impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
$11, $12, $13, $14, $15, $16, $17, $18, $19, $20, $21)
"#,
Uuid::from(id),
&params.issuer,
params.issuer.as_deref(),
params.human_name.as_deref(),
params.brand_name.as_deref(),
params.scope.to_string(),
@@ -476,7 +476,7 @@ impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
fields(
db.query.text,
upstream_oauth_provider.id = %id,
upstream_oauth_provider.issuer = %params.issuer,
upstream_oauth_provider.issuer = params.issuer,
upstream_oauth_provider.client_id = %params.client_id,
),
err,
@@ -543,7 +543,7 @@ impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
RETURNING created_at
"#,
Uuid::from(id),
&params.issuer,
params.issuer.as_deref(),
params.human_name.as_deref(),
params.brand_name.as_deref(),
params.scope.to_string(),

View File

@@ -162,7 +162,7 @@ impl UpstreamOAuthSessionRepository for PgUpstreamOAuthSessionRepository<'_> {
fields(
db.query.text,
%upstream_oauth_provider.id,
%upstream_oauth_provider.issuer,
upstream_oauth_provider.issuer = upstream_oauth_provider.issuer,
%upstream_oauth_provider.client_id,
upstream_oauth_authorization_session.id,
),

View File

@@ -24,7 +24,7 @@ use crate::{pagination::Page, repository_impl, Clock, Pagination};
/// OAuth 2.0 provider
pub struct UpstreamOAuthProviderParams {
/// The OIDC issuer of the provider
pub issuer: String,
pub issuer: Option<String>,
/// A human-readable name for the provider
pub human_name: Option<String>,

View File

@@ -1390,7 +1390,7 @@ impl TemplateContext for UpstreamRegister {
},
UpstreamOAuthProvider {
id: Ulid::nil(),
issuer: "https://example.com/".to_owned(),
issuer: Some("https://example.com/".to_owned()),
human_name: Some("Example Ltd.".to_owned()),
brand_name: None,
scope: Scope::from_iter([OPENID]),

View File

@@ -1817,7 +1817,6 @@
"required": [
"client_id",
"id",
"issuer",
"scope",
"token_endpoint_auth_method"
],
@@ -1832,7 +1831,7 @@
"pattern": "^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$"
},
"issuer": {
"description": "The OIDC issuer URL",
"description": "The OIDC issuer URL\n\nThis is required if OIDC discovery is enabled (which is the default)",
"type": "string"
},
"human_name": {

View File

@@ -1589,7 +1589,7 @@ type UpstreamOAuth2Provider implements Node & CreationEvent {
"""
OpenID Connect issuer URL.
"""
issuer: String!
issuer: String
"""
Client ID used for this provider.
"""

View File

@@ -1156,7 +1156,7 @@ export type UpstreamOAuth2Provider = CreationEvent & Node & {
/** ID of the object. */
id: Scalars['ID']['output'];
/** OpenID Connect issuer URL. */
issuer: Scalars['String']['output'];
issuer?: Maybe<Scalars['String']['output']>;
};
export type UpstreamOAuth2ProviderConnection = {