Replace HTTP client in oidc-client with reqwest

This commit is contained in:
Quentin Gliech
2024-10-24 16:41:14 +02:00
parent 5b879bd4f4
commit 248e03ac93
41 changed files with 220 additions and 1958 deletions

View File

@@ -51,5 +51,3 @@ tokio.workspace = true
wiremock = "0.6.2"
http-body-util.workspace = true
rustls.workspace = true
mas-http = { workspace = true, features = ["client"] }

View File

@@ -6,11 +6,9 @@
//! The error types used in this crate.
use std::{str::Utf8Error, sync::Arc};
use headers::authorization::InvalidBearerToken;
use http::{header::ToStrError, StatusCode};
use mas_http::{catch_http_codes, form_urlencoded_request, json_request, json_response};
use http::StatusCode;
use mas_http::{catch_http_codes, form_urlencoded_request, json_response};
use mas_jose::{
claims::ClaimError,
jwa::InvalidAlgorithm,
@@ -33,9 +31,6 @@ pub enum Error {
/// An error occurred fetching the provider JWKS.
Jwks(#[from] JwksError),
/// An error occurred during client registration.
Registration(#[from] RegistrationError),
/// An error occurred building the authorization URL.
Authorization(#[from] AuthorizationError),
@@ -48,17 +43,11 @@ pub enum Error {
/// An error occurred refreshing an access token.
TokenRefresh(#[from] TokenRefreshError),
/// An error occurred revoking a token.
TokenRevoke(#[from] TokenRevokeError),
/// An error occurred requesting user info.
UserInfo(#[from] UserInfoError),
/// An error occurred introspecting a token.
Introspection(#[from] IntrospectionError),
/// An error occurred building the account management URL.
AccountManagement(#[from] AccountManagementError),
}
/// All possible errors when fetching provider metadata.
@@ -81,135 +70,6 @@ pub enum DiscoveryError {
Disabled,
}
/// All possible errors when registering the client.
#[derive(Debug, Error)]
pub enum RegistrationError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred serializing the request or deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// No client secret was received although one was expected because of the
/// authentication method.
#[error("missing client secret in response")]
MissingClientSecret,
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<json_request::Error<S>> for RegistrationError
where
S: Into<RegistrationError>,
{
fn from(err: json_request::Error<S>) -> Self {
match err {
json_request::Error::Serialize { inner } => inner.into(),
json_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for RegistrationError
where
S: Into<RegistrationError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for RegistrationError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when making a pushed authorization request.
#[derive(Debug, Error)]
pub enum PushedAuthorizationError {
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred adding the client credentials to the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for PushedAuthorizationError
where
S: Into<PushedAuthorizationError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for PushedAuthorizationError
where
S: Into<PushedAuthorizationError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for PushedAuthorizationError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when authorizing the client.
#[derive(Debug, Error)]
pub enum AuthorizationError {
@@ -220,76 +80,18 @@ pub enum AuthorizationError {
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred making the PAR request.
#[error(transparent)]
PushedAuthorization(#[from] PushedAuthorizationError),
}
/// All possible errors when requesting an access token.
#[derive(Debug, Error)]
pub enum TokenRequestError {
/// An error occurred building the request.
/// The HTTP client returned an error.
#[error(transparent)]
IntoHttp(#[from] http::Error),
Http(#[from] reqwest::Error),
/// An error occurred adding the client credentials to the request.
/// Error while injecting the client credentials into the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred deserializing the response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for TokenRequestError
where
S: Into<TokenRequestError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<json_response::Error<S>> for TokenRequestError
where
S: Into<TokenRequestError>,
{
fn from(err: json_response::Error<S>) -> Self {
match err {
json_response::Error::Deserialize { inner } => inner.into(),
json_response::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for TokenRequestError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when exchanging a code for an access token.
@@ -316,95 +118,13 @@ pub enum TokenRefreshError {
IdToken(#[from] IdTokenError),
}
/// All possible errors when revoking a token.
#[derive(Debug, Error)]
pub enum TokenRevokeError {
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// An error occurred adding the client credentials to the request.
#[error(transparent)]
Credentials(#[from] CredentialsError),
/// An error occurred serializing the request.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
/// An error occurred deserializing the error response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<form_urlencoded_request::Error<S>> for TokenRevokeError
where
S: Into<TokenRevokeError>,
{
fn from(err: form_urlencoded_request::Error<S>) -> Self {
match err {
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
form_urlencoded_request::Error::Service { inner } => inner.into(),
}
}
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for TokenRevokeError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
}
/// All possible errors when requesting user info.
#[derive(Debug, Error)]
pub enum UserInfoError {
/// An error occurred getting the provider metadata.
#[error(transparent)]
Discovery(#[from] Arc<DiscoveryError>),
/// The provider doesn't support requesting user info.
#[error("missing UserInfo support")]
MissingUserInfoSupport,
/// No token is available to get info from.
#[error("missing token")]
MissingToken,
/// No client metadata is available.
#[error("missing client metadata")]
MissingClientMetadata,
/// The access token is invalid.
#[error(transparent)]
Token(#[from] InvalidBearerToken),
/// An error occurred building the request.
#[error(transparent)]
IntoHttp(#[from] http::Error),
/// The content-type header is missing from the response.
#[error("missing response content-type")]
MissingResponseContentType,
/// The content-type header could not be decoded.
#[error("could not decoded response content-type: {0}")]
DecodeResponseContentType(#[from] ToStrError),
/// The content-type is not valid.
#[error("invalid response content-type")]
InvalidResponseContentTypeValue,
@@ -418,39 +138,13 @@ pub enum UserInfoError {
got: String,
},
/// An error occurred reading the response.
#[error(transparent)]
FromUtf8(#[from] Utf8Error),
/// An error occurred deserializing the JSON or error response.
#[error(transparent)]
Json(#[from] serde_json::Error),
/// An error occurred verifying the Id Token.
#[error(transparent)]
IdToken(#[from] IdTokenError),
/// The server returned an HTTP error status code.
#[error(transparent)]
Http(#[from] HttpError),
/// An error occurred sending the request.
#[error(transparent)]
Service(BoxError),
}
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for UserInfoError
where
S: Into<BoxError>,
{
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
match err {
catch_http_codes::Error::HttpError { status_code, inner } => {
HttpError::new(status_code, inner).into()
}
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
}
}
Http(#[from] reqwest::Error),
}
/// All possible errors when introspecting a token.
@@ -644,11 +338,3 @@ pub enum CredentialsError {
#[error(transparent)]
Custom(BoxError),
}
/// All errors that can occur when building the account management URL.
#[derive(Debug, Error)]
pub enum AccountManagementError {
/// An error occurred serializing the parameters.
#[error(transparent)]
UrlEncoded(#[from] serde_urlencoded::ser::Error),
}

View File

@@ -52,7 +52,6 @@ pub mod error;
pub mod http_service;
pub mod requests;
pub mod types;
mod utils;
use std::fmt;

View File

@@ -1,121 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
//! Methods related to the account management URL.
//!
//! This is a Matrix extension introduced in [MSC2965](https://github.com/matrix-org/matrix-spec-proposals/pull/2965).
use serde::Serialize;
use serde_with::skip_serializing_none;
use url::Url;
use crate::error::AccountManagementError;
/// An account management action that a user can take, including a device ID for
/// the actions that support it.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
#[serde(tag = "action")]
#[non_exhaustive]
pub enum AccountManagementActionFull {
/// `org.matrix.profile`
///
/// The user wishes to view their profile (name, avatar, contact details).
#[serde(rename = "org.matrix.profile")]
Profile,
/// `org.matrix.sessions_list`
///
/// The user wishes to view a list of their sessions.
#[serde(rename = "org.matrix.sessions_list")]
SessionsList,
/// `org.matrix.session_view`
///
/// The user wishes to view the details of a specific session.
#[serde(rename = "org.matrix.session_view")]
SessionView {
/// The ID of the session to view the details of.
device_id: String,
},
/// `org.matrix.session_end`
///
/// The user wishes to end/log out of a specific session.
#[serde(rename = "org.matrix.session_end")]
SessionEnd {
/// The ID of the session to end.
device_id: String,
},
/// `org.matrix.account_deactivate`
///
/// The user wishes to deactivate their account.
#[serde(rename = "org.matrix.account_deactivate")]
AccountDeactivate,
/// `org.matrix.cross_signing_reset`
///
/// The user wishes to reset their cross-signing keys.
#[serde(rename = "org.matrix.cross_signing_reset")]
CrossSigningReset,
}
#[skip_serializing_none]
#[derive(Debug, Clone, Serialize)]
struct AccountManagementData {
#[serde(flatten)]
action: Option<AccountManagementActionFull>,
id_token_hint: Option<String>,
}
/// Build the URL for accessing the account management capabilities.
///
/// # Arguments
///
/// * `account_management_uri` - The URL to access the issuer's account
/// management capabilities.
///
/// * `action` - The action that the user wishes to take.
///
/// * `id_token_hint` - An ID Token that was previously issued to the client,
/// used as a hint for which user is requesting to manage their account.
///
/// # Returns
///
/// A URL to be opened in a web browser where the end-user will be able to
/// access the account management capabilities of the issuer.
///
/// # Errors
///
/// Returns an error if serializing the URL fails.
pub fn build_account_management_url(
mut account_management_uri: Url,
action: Option<AccountManagementActionFull>,
id_token_hint: Option<String>,
) -> Result<Url, AccountManagementError> {
let data = AccountManagementData {
action,
id_token_hint,
};
let extra_query = serde_urlencoded::to_string(data)?;
if !extra_query.is_empty() {
// Add our parameters to the query, because the URL might already have one.
let mut full_query = account_management_uri
.query()
.map(ToOwned::to_owned)
.unwrap_or_default();
if !full_query.is_empty() {
full_query.push('&');
}
full_query.push_str(&extra_query);
account_management_uri.set_query(Some(&full_query));
}
Ok(account_management_uri)
}

View File

@@ -12,9 +12,7 @@ use std::{collections::HashSet, num::NonZeroU32};
use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Utc};
use http::header::CONTENT_TYPE;
use language_tags::LanguageTag;
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod};
use mas_jose::claims::{self, TokenHash};
use oauth2_types::{
@@ -22,7 +20,7 @@ use oauth2_types::{
prelude::CodeChallengeMethodExt,
requests::{
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest,
Display, Prompt, PushedAuthorizationResponse,
Display, Prompt,
},
scope::Scope,
};
@@ -32,22 +30,17 @@ use rand::{
};
use serde::Serialize;
use serde_with::skip_serializing_none;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use super::jose::JwtVerificationData;
use crate::{
error::{
AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError,
},
http_service::HttpService,
error::{AuthorizationError, IdTokenError, TokenAuthorizationCodeError},
requests::{jose::verify_id_token, token::request_access_token},
types::{
client_credentials::ClientCredentials,
scope::{ScopeExt, ScopeToken},
IdToken,
},
utils::{http_all_error_status_codes, http_error_mapper},
};
/// The data necessary to build an authorization request.
@@ -320,7 +313,6 @@ fn build_authorization_request(
///
/// [`VerifiedClientMetadata`]: oauth2_types::registration::VerifiedClientMetadata
/// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode
#[allow(clippy::too_many_lines)]
pub fn build_authorization_url(
authorization_endpoint: Url,
authorization_data: AuthorizationRequestData,
@@ -353,115 +345,6 @@ pub fn build_authorization_url(
Ok((authorization_url, validation_data))
}
/// Make a [Pushed Authorization Request] and build the URL for authenticating
/// at the Authorization endpoint.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `par_endpoint` - The URL of the issuer's Pushed Authorization Request
/// endpoint.
///
/// * `authorization_endpoint` - The URL of the issuer's Authorization endpoint.
///
/// * `authorization_data` - The data necessary to build the authorization
/// request.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Returns
///
/// A URL to be opened in a web browser where the end-user will be able to
/// authorize the given scope, and the [`AuthorizationValidationData`] to
/// validate this request.
///
/// The redirect URI will receive parameters in its query:
///
/// * A successful response will receive a `code` and a `state`.
///
/// * If the authorization fails, it should receive an `error` parameter with a
/// [`ClientErrorCode`] and optionally an `error_description`.
///
/// # Errors
///
/// Returns an error if the request fails, the response is invalid or building
/// the URL fails.
///
/// [Pushed Authorization Request]: https://oauth.net/2/pushed-authorization-requests/
/// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode
#[tracing::instrument(skip_all, fields(par_endpoint))]
pub async fn build_par_authorization_url(
http_service: &HttpService,
client_credentials: ClientCredentials,
par_endpoint: &Url,
authorization_endpoint: Url,
authorization_data: AuthorizationRequestData,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
tracing::debug!(
scope = ?authorization_data.scope,
"Authorizing with a PAR..."
);
let client_id = client_credentials.client_id().to_owned();
let (authorization_request, validation_data) =
build_authorization_request(authorization_data, rng)?;
let par_request = http::Request::post(par_endpoint.as_str())
.header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())
.body(authorization_request)
.map_err(PushedAuthorizationError::from)?;
let par_request = client_credentials
.apply_to_request(par_request, now, rng)
.map_err(PushedAuthorizationError::from)?;
let service = (
FormUrlencodedRequestLayer::default(),
JsonResponseLayer::<PushedAuthorizationResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let par_response = service
.ready_oneshot()
.await
.map_err(PushedAuthorizationError::from)?
.call(par_request)
.await
.map_err(PushedAuthorizationError::from)?
.into_body();
let authorization_query = serde_urlencoded::to_string([
("request_uri", par_response.request_uri.as_str()),
("client_id", &client_id),
])?;
let mut authorization_url = authorization_endpoint;
// Add our parameters to the query, because the URL might already have one.
let mut full_query = authorization_url
.query()
.map(ToOwned::to_owned)
.unwrap_or_default();
if !full_query.is_empty() {
full_query.push('&');
}
full_query.push_str(&authorization_query);
authorization_url.set_query(Some(&full_query));
Ok((authorization_url, validation_data))
}
/// Exchange an authorization code for an access token.
///
/// This should be used as the first step for logging in, and to request a

View File

@@ -17,7 +17,7 @@ use rand::Rng;
use url::Url;
use crate::{
error::TokenRequestError, http_service::HttpService, requests::token::request_access_token,
error::TokenRequestError, requests::token::request_access_token,
types::client_credentials::ClientCredentials,
};
@@ -46,7 +46,7 @@ use crate::{
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(token_endpoint))]
pub async fn access_token_with_client_credentials(
http_service: &HttpService,
http_client: &reqwest::Client,
client_credentials: ClientCredentials,
token_endpoint: &Url,
scope: Option<Scope>,
@@ -56,7 +56,7 @@ pub async fn access_token_with_client_credentials(
tracing::debug!("Requesting access token with client credentials...");
request_access_token(
http_service,
http_client,
client_credentials,
token_endpoint,
AccessTokenRequest::ClientCredentials(ClientCredentialsGrant { scope }),

View File

@@ -8,6 +8,7 @@
//!
//! [Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html
use mas_http::RequestBuilderExt;
use oauth2_types::oidc::{ProviderMetadata, VerifiedProviderMetadata};
use url::Url;
@@ -34,7 +35,7 @@ async fn discover_inner(
let response = client
.get(config_url.as_str())
.send()
.send_traced()
.await?
.error_for_status()?
.json()

View File

@@ -1,145 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
//! Requests for [Token Introspection].
//!
//! [Token Introspection]: https://www.rfc-editor.org/rfc/rfc7662
use chrono::{DateTime, Utc};
use headers::{Authorization, HeaderMapExt};
use http::Request;
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
use mas_iana::oauth::OAuthTokenTypeHint;
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
use rand::Rng;
use serde::Serialize;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::IntrospectionError,
http_service::HttpService,
types::client_credentials::{ClientCredentials, RequestWithClientCredentials},
utils::{http_all_error_status_codes, http_error_mapper},
};
/// The method used to authenticate at the introspection endpoint.
pub enum IntrospectionAuthentication<'a> {
/// Using client authentication.
Credentials(ClientCredentials),
/// Using a bearer token.
BearerToken(&'a str),
}
impl<'a> IntrospectionAuthentication<'a> {
/// Constructs an `IntrospectionAuthentication` from the given client
/// credentials.
#[must_use]
pub fn with_client_credentials(credentials: ClientCredentials) -> Self {
Self::Credentials(credentials)
}
/// Constructs an `IntrospectionAuthentication` from the given bearer token.
#[must_use]
pub fn with_bearer_token(token: &'a str) -> Self {
Self::BearerToken(token)
}
fn apply_to_request<T: Serialize>(
self,
request: Request<T>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<Request<RequestWithClientCredentials<T>>, IntrospectionError> {
let res = match self {
IntrospectionAuthentication::Credentials(client_credentials) => {
client_credentials.apply_to_request(request, now, rng)?
}
IntrospectionAuthentication::BearerToken(access_token) => {
let (mut parts, body) = request.into_parts();
parts
.headers
.typed_insert(Authorization::bearer(access_token)?);
let body = RequestWithClientCredentials {
body,
credentials: None,
};
http::Request::from_parts(parts, body)
}
};
Ok(res)
}
}
impl<'a> From<ClientCredentials> for IntrospectionAuthentication<'a> {
fn from(credentials: ClientCredentials) -> Self {
Self::with_client_credentials(credentials)
}
}
/// Obtain information about a token.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `authentication` - The method used to authenticate the request.
///
/// * `revocation_endpoint` - The URL of the issuer's Revocation endpoint.
///
/// * `token` - The token to introspect.
///
/// * `token_type_hint` - Hint about the type of the token.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(introspection_endpoint))]
pub async fn introspect_token(
http_service: &HttpService,
authentication: IntrospectionAuthentication<'_>,
introspection_endpoint: &Url,
token: String,
token_type_hint: Option<OAuthTokenTypeHint>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<IntrospectionResponse, IntrospectionError> {
tracing::debug!("Introspecting token…");
let introspection_request = IntrospectionRequest {
token,
token_type_hint,
};
let introspection_request =
http::Request::post(introspection_endpoint.as_str()).body(introspection_request)?;
let introspection_request = authentication.apply_to_request(introspection_request, now, rng)?;
let service = (
FormUrlencodedRequestLayer::default(),
JsonResponseLayer::<IntrospectionResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let introspection_response = service
.ready_oneshot()
.await?
.call(introspection_request)
.await?
.into_body();
Ok(introspection_response)
}

View File

@@ -9,6 +9,7 @@
use std::collections::HashMap;
use chrono::{DateTime, Utc};
use mas_http::RequestBuilderExt;
use mas_iana::jose::JsonWebSignatureAlg;
use mas_jose::{
claims::{self, TimeOptions},
@@ -43,7 +44,7 @@ pub async fn fetch_jwks(
let response: PublicJsonWebKeySet = client
.get(jwks_uri.as_str())
.send()
.send_traced()
.await?
.error_for_status()?
.json()

View File

@@ -6,15 +6,11 @@
//! Methods to interact with OpenID Connect and OAuth2.0 endpoints.
pub mod account_management;
pub mod authorization_code;
pub mod client_credentials;
pub mod discovery;
pub mod introspection;
pub mod jose;
pub mod refresh_token;
pub mod registration;
pub mod revocation;
pub mod rp_initiated_logout;
pub mod token;
pub mod userinfo;

View File

@@ -20,7 +20,6 @@ use url::Url;
use super::jose::JwtVerificationData;
use crate::{
error::{IdTokenError, TokenRefreshError},
http_service::HttpService,
requests::{jose::verify_id_token, token::request_access_token},
types::{client_credentials::ClientCredentials, IdToken},
};
@@ -68,7 +67,7 @@ use crate::{
#[allow(clippy::too_many_arguments)]
#[tracing::instrument(skip_all, fields(token_endpoint))]
pub async fn refresh_access_token(
http_service: &HttpService,
http_client: &reqwest::Client,
client_credentials: ClientCredentials,
token_endpoint: &Url,
refresh_token: String,
@@ -81,7 +80,7 @@ pub async fn refresh_access_token(
tracing::debug!("Refreshing access token…");
let token_response = request_access_token(
http_service,
http_client,
client_credentials,
token_endpoint,
AccessTokenRequest::RefreshToken(RefreshTokenGrant {

View File

@@ -1,91 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
//! Requests for [Dynamic Registration].
//!
//! [Dynamic Registration]: https://openid.net/specs/openid-connect-registration-1_0.html
use mas_http::{CatchHttpCodesLayer, JsonRequestLayer, JsonResponseLayer};
use mas_iana::oauth::OAuthClientAuthenticationMethod;
use oauth2_types::registration::{ClientRegistrationResponse, VerifiedClientMetadata};
use serde::Serialize;
use serde_with::skip_serializing_none;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::RegistrationError,
http_service::HttpService,
utils::{http_all_error_status_codes, http_error_mapper},
};
#[skip_serializing_none]
#[derive(Serialize)]
struct RegistrationRequest {
#[serde(flatten)]
client_metadata: VerifiedClientMetadata,
software_statement: Option<String>,
}
/// Register a client with an OpenID Provider.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `registration_endpoint` - The URL of the issuer's Registration endpoint.
///
/// * `client_metadata` - The metadata to register with the issuer.
///
/// * `software_statement` - A JWT that asserts metadata values about the client
/// software that should be signed.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(registration_endpoint))]
pub async fn register_client(
http_service: &HttpService,
registration_endpoint: &Url,
client_metadata: VerifiedClientMetadata,
software_statement: Option<String>,
) -> Result<ClientRegistrationResponse, RegistrationError> {
tracing::debug!("Registering client...");
let should_receive_secret = matches!(
client_metadata.token_endpoint_auth_method(),
OAuthClientAuthenticationMethod::ClientSecretPost
| OAuthClientAuthenticationMethod::ClientSecretBasic
| OAuthClientAuthenticationMethod::ClientSecretJwt
);
let body = RegistrationRequest {
client_metadata,
software_statement,
};
let registration_req = http::Request::post(registration_endpoint.as_str()).body(body)?;
let service = (
JsonRequestLayer::default(),
JsonResponseLayer::<ClientRegistrationResponse>::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
let response = service
.ready_oneshot()
.await?
.call(registration_req)
.await?
.into_body();
if should_receive_secret && response.client_secret.is_none() {
return Err(RegistrationError::MissingClientSecret);
}
Ok(response)
}

View File

@@ -1,82 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
//! Requests for [Token Revocation].
//!
//! [Token Revocation]: https://www.rfc-editor.org/rfc/rfc7009.html
use chrono::{DateTime, Utc};
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer};
use mas_iana::oauth::OAuthTokenTypeHint;
use oauth2_types::requests::IntrospectionRequest;
use rand::Rng;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::TokenRevokeError,
http_service::HttpService,
types::client_credentials::ClientCredentials,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Revoke a token.
///
/// # Arguments
///
/// * `http_service` - The service to use for making HTTP requests.
///
/// * `client_credentials` - The credentials obtained when registering the
/// client.
///
/// * `revocation_endpoint` - The URL of the issuer's Revocation endpoint.
///
/// * `token` - The token to revoke.
///
/// * `token_type_hint` - Hint about the type of the token.
///
/// * `now` - The current time.
///
/// * `rng` - A random number generator.
///
/// # Errors
///
/// Returns an error if the request fails or the response is invalid.
#[tracing::instrument(skip_all, fields(revocation_endpoint))]
pub async fn revoke_token(
http_service: &HttpService,
client_credentials: ClientCredentials,
revocation_endpoint: &Url,
token: String,
token_type_hint: Option<OAuthTokenTypeHint>,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<(), TokenRevokeError> {
tracing::debug!("Revoking token…");
let request = IntrospectionRequest {
token,
token_type_hint,
};
let revocation_request = http::Request::post(revocation_endpoint.as_str()).body(request)?;
let revocation_request = client_credentials.apply_to_request(revocation_request, now, rng)?;
let service = (
FormUrlencodedRequestLayer::default(),
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
)
.layer(http_service.clone());
service
.ready_oneshot()
.await?
.call(revocation_request)
.await?;
Ok(())
}

View File

@@ -7,18 +7,12 @@
//! Requests for the Token endpoint.
use chrono::{DateTime, Utc};
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
use mas_http::RequestBuilderExt;
use oauth2_types::requests::{AccessTokenRequest, AccessTokenResponse};
use rand::Rng;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use crate::{
error::TokenRequestError,
http_service::HttpService,
types::client_credentials::ClientCredentials,
utils::{http_all_error_status_codes, http_error_mapper},
};
use crate::{error::TokenRequestError, types::client_credentials::ClientCredentials};
/// Request an access token.
///
@@ -51,13 +45,15 @@ pub async fn request_access_token(
) -> Result<AccessTokenResponse, TokenRequestError> {
tracing::debug!(?request, "Requesting access token...");
let token_request = http_client.post(token_endpoint.as_str()).form(&request);
let token_request = http_client.post(token_endpoint.as_str());
let token_request = client_credentials.apply_to_request(token_request, now, rng)?;
let res = service.ready_oneshot().await?.call(token_request).await?;
let token_response = res.into_body();
let token_response = client_credentials
.authenticated_form(token_request, &request, now, rng)?
.send_traced()
.await?
.error_for_status()?
.json()
.await?;
Ok(token_response)
}

View File

@@ -10,23 +10,19 @@
use std::collections::HashMap;
use bytes::Bytes;
use headers::{Authorization, ContentType, HeaderMapExt, HeaderValue};
use headers::{ContentType, HeaderMapExt, HeaderValue};
use http::header::ACCEPT;
use mas_http::CatchHttpCodesLayer;
use mas_http::RequestBuilderExt;
use mas_jose::claims;
use mime::Mime;
use serde_json::Value;
use tower::{Layer, Service, ServiceExt};
use url::Url;
use super::jose::JwtVerificationData;
use crate::{
error::{IdTokenError, UserInfoError},
http_service::HttpService,
requests::jose::verify_signed_jwt,
types::IdToken,
utils::{http_all_error_status_codes, http_error_mapper},
};
/// Obtain information about an authenticated end-user.
@@ -59,7 +55,7 @@ use crate::{
/// [`Claim`]: mas_jose::claims::Claim
#[tracing::instrument(skip_all, fields(userinfo_endpoint))]
pub async fn fetch_userinfo(
http_service: &HttpService,
http_client: &reqwest::Client,
userinfo_endpoint: &Url,
access_token: &str,
jwt_verification_data: Option<JwtVerificationData<'_>>,
@@ -67,29 +63,18 @@ pub async fn fetch_userinfo(
) -> Result<HashMap<String, Value>, UserInfoError> {
tracing::debug!("Obtaining user info…");
let mut userinfo_request = http::Request::get(userinfo_endpoint.as_str());
let expected_content_type = if jwt_verification_data.is_some() {
"application/jwt"
} else {
mime::APPLICATION_JSON.as_ref()
};
if let Some(headers) = userinfo_request.headers_mut() {
headers.typed_insert(Authorization::bearer(access_token)?);
headers.insert(ACCEPT, HeaderValue::from_static(expected_content_type));
}
let userinfo_request = http_client
.get(userinfo_endpoint.as_str())
.bearer_auth(access_token)
.header(ACCEPT, HeaderValue::from_static(expected_content_type));
let userinfo_request = userinfo_request.body(Bytes::new())?;
let service = CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper)
.layer(http_service.clone());
let userinfo_response = service
.ready_oneshot()
.await?
.call(userinfo_request)
.await?;
let userinfo_response = userinfo_request.send_traced().await?.error_for_status()?;
let content_type: Mime = userinfo_response
.headers()
@@ -105,15 +90,14 @@ pub async fn fetch_userinfo(
});
}
let response_body = std::str::from_utf8(userinfo_response.body())?;
let mut claims = if let Some(verification_data) = jwt_verification_data {
verify_signed_jwt(response_body, verification_data)
let response_body = userinfo_response.text().await?;
verify_signed_jwt(&response_body, verification_data)
.map_err(IdTokenError::from)?
.into_parts()
.1
} else {
serde_json::from_str(response_body)?
userinfo_response.json().await?
};
let mut auth_claims = auth_id_token.payload().clone();

View File

@@ -10,8 +10,6 @@ use std::{collections::HashMap, fmt, sync::Arc};
use base64ct::{Base64UrlUnpadded, Encoding};
use chrono::{DateTime, Duration, Utc};
use headers::{Authorization, HeaderMapExt};
use http::Request;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
#[cfg(feature = "keystore")]
use mas_jose::constraints::Constrainable;
@@ -25,7 +23,6 @@ use mas_keystore::Keystore;
use rand::Rng;
use serde::Serialize;
use serde_json::Value;
use serde_with::skip_serializing_none;
use tower::BoxError;
use url::Url;
@@ -175,45 +172,113 @@ impl ClientCredentials {
}
}
/// Apply these `ClientCredentials` to the given request.
pub(crate) fn apply_to_request<T: Serialize>(
self,
/// Apply these [`ClientCredentials`] to the given request with the given
/// form.
pub(crate) fn authenticated_form<T: Serialize>(
&self,
request: reqwest::RequestBuilder,
form: &T,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<reqwest::RequestBuilder, CredentialsError> {
// TODO: get the form in params, augment it and serialize
let credentials = RequestClientCredentials::try_from_credentials(self, now, rng)?;
let request = match self {
ClientCredentials::None { client_id } => request.form(&RequestWithClientCredentials {
body: form,
client_id,
client_secret: None,
client_assertion: None,
client_assertion_type: None,
}),
let (parts, body) = request.into_parts();
let mut body = RequestWithClientCredentials {
body,
credentials: None,
};
let request = match credentials {
RequestClientCredentials::Body(credentials) => {
body.credentials = Some(credentials);
Request::from_parts(parts, body)
}
RequestClientCredentials::Header(credentials) => {
let HeaderClientCredentials {
ClientCredentials::ClientSecretBasic {
client_id,
client_secret,
} => request.basic_auth(client_id, Some(client_secret)).form(
&RequestWithClientCredentials {
body: form,
client_id,
client_secret,
} = credentials;
client_secret: None,
client_assertion: None,
client_assertion_type: None,
},
),
let mut request = Request::from_parts(parts, body);
ClientCredentials::ClientSecretPost {
client_id,
client_secret,
} => request.form(&RequestWithClientCredentials {
body: form,
client_id,
client_secret: Some(client_secret),
client_assertion: None,
client_assertion_type: None,
}),
// Encode the values with `application/x-www-form-urlencoded`.
let client_id =
form_urlencoded::byte_serialize(client_id.as_bytes()).collect::<String>();
let client_secret =
form_urlencoded::byte_serialize(client_secret.as_bytes()).collect::<String>();
ClientCredentials::ClientSecretJwt {
client_id,
client_secret,
signing_algorithm,
token_endpoint,
} => {
let claims =
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
let key = SymmetricKey::new_for_alg(
client_secret.as_bytes().to_vec(),
signing_algorithm,
)?;
let header = JsonWebSignatureHeader::new(signing_algorithm.clone());
let auth = Authorization::basic(&client_id, &client_secret);
request.headers_mut().typed_insert(auth);
let jwt = Jwt::sign(header, claims, &key)?;
request
request.form(&RequestWithClientCredentials {
body: form,
client_id,
client_secret: None,
client_assertion: Some(jwt.as_str()),
client_assertion_type: Some(JwtBearerClientAssertionType),
})
}
ClientCredentials::PrivateKeyJwt {
client_id,
jwt_signing_method,
signing_algorithm,
token_endpoint,
} => {
let claims =
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
let client_assertion = match jwt_signing_method {
#[cfg(feature = "keystore")]
JwtSigningMethod::Keystore(keystore) => {
let key = keystore
.signing_key_for_algorithm(signing_algorithm)
.ok_or(CredentialsError::NoPrivateKeyFound)?;
let signer = key
.params()
.signing_key_for_alg(signing_algorithm)
.map_err(|_| CredentialsError::JwtWrongAlgorithm)?;
let mut header = JsonWebSignatureHeader::new(signing_algorithm.clone());
if let Some(kid) = key.kid() {
header = header.with_kid(kid);
}
Jwt::sign(header, claims, &signer)?.to_string()
}
JwtSigningMethod::Custom(jwt_signing_fn) => {
jwt_signing_fn(claims, signing_algorithm.clone())
.map_err(CredentialsError::Custom)?
}
};
request.form(&RequestWithClientCredentials {
body: form,
client_id,
client_secret: None,
client_assertion: Some(&client_assertion),
client_assertion_type: Some(JwtBearerClientAssertionType),
})
}
};
@@ -264,123 +329,7 @@ impl fmt::Debug for ClientCredentials {
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
pub(crate) struct JwtBearerClientAssertionType;
enum RequestClientCredentials {
Body(BodyClientCredentials),
Header(HeaderClientCredentials),
}
impl RequestClientCredentials {
fn try_from_credentials(
credentials: ClientCredentials,
now: DateTime<Utc>,
rng: &mut impl Rng,
) -> Result<Self, CredentialsError> {
let res = match credentials {
ClientCredentials::None { client_id } => Self::Body(BodyClientCredentials {
client_id,
client_secret: None,
client_assertion: None,
client_assertion_type: None,
}),
ClientCredentials::ClientSecretBasic {
client_id,
client_secret,
} => Self::Header(HeaderClientCredentials {
client_id,
client_secret,
}),
ClientCredentials::ClientSecretPost {
client_id,
client_secret,
} => Self::Body(BodyClientCredentials {
client_id,
client_secret: Some(client_secret),
client_assertion: None,
client_assertion_type: None,
}),
ClientCredentials::ClientSecretJwt {
client_id,
client_secret,
signing_algorithm,
token_endpoint,
} => {
let claims =
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
let key = SymmetricKey::new_for_alg(client_secret.into(), &signing_algorithm)?;
let header = JsonWebSignatureHeader::new(signing_algorithm);
let jwt = Jwt::sign(header, claims, &key)?;
Self::Body(BodyClientCredentials {
client_id,
client_secret: None,
client_assertion: Some(jwt.to_string()),
client_assertion_type: Some(JwtBearerClientAssertionType),
})
}
ClientCredentials::PrivateKeyJwt {
client_id,
jwt_signing_method,
signing_algorithm,
token_endpoint,
} => {
let claims =
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
let client_assertion = match jwt_signing_method {
#[cfg(feature = "keystore")]
JwtSigningMethod::Keystore(keystore) => {
let key = keystore
.signing_key_for_algorithm(&signing_algorithm)
.ok_or(CredentialsError::NoPrivateKeyFound)?;
let signer = key
.params()
.signing_key_for_alg(&signing_algorithm)
.map_err(|_| CredentialsError::JwtWrongAlgorithm)?;
let mut header = JsonWebSignatureHeader::new(signing_algorithm);
if let Some(kid) = key.kid() {
header = header.with_kid(kid);
}
Jwt::sign(header, claims, &signer)?.to_string()
}
JwtSigningMethod::Custom(jwt_signing_fn) => {
jwt_signing_fn(claims, signing_algorithm)
.map_err(CredentialsError::Custom)?
}
};
Self::Body(BodyClientCredentials {
client_id,
client_secret: None,
client_assertion: Some(client_assertion),
client_assertion_type: Some(JwtBearerClientAssertionType),
})
}
};
Ok(res)
}
}
#[allow(clippy::struct_field_names)] // All the fields start with `client_`
#[skip_serializing_none]
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub(crate) struct BodyClientCredentials {
client_id: String,
client_secret: Option<String>,
client_assertion: Option<String>,
client_assertion_type: Option<JwtBearerClientAssertionType>,
}
#[derive(Debug, Clone)]
struct HeaderClientCredentials {
client_id: String,
client_secret: String,
}
struct JwtBearerClientAssertionType;
fn prepare_claims(
iss: String,
@@ -409,14 +358,20 @@ fn prepare_claims(
/// A request with client credentials added to it.
#[derive(Clone, Serialize)]
#[skip_serializing_none]
pub struct RequestWithClientCredentials<T> {
struct RequestWithClientCredentials<'a, T> {
#[serde(flatten)]
pub(crate) body: T,
#[serde(flatten)]
pub(crate) credentials: Option<BodyClientCredentials>,
body: T,
client_id: &'a str,
#[serde(skip_serializing_if = "Option::is_none")]
client_secret: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
client_assertion: Option<&'a str>,
#[serde(skip_serializing_if = "Option::is_none")]
client_assertion_type: Option<JwtBearerClientAssertionType>,
}
/*
#[cfg(test)]
mod test {
use assert_matches::assert_matches;
@@ -442,91 +397,6 @@ mod test {
Utc::now()
}
#[test]
fn serialize_credentials() {
assert_eq!(
serde_urlencoded::to_string(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: None,
client_assertion_type: None,
})
.unwrap(),
"client_id=abcd%24%2B%2B"
);
assert_eq!(
serde_urlencoded::to_string(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_assertion: None,
client_assertion_type: None,
})
.unwrap(),
"client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F"
);
assert_eq!(
serde_urlencoded::to_string(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: Some(CLIENT_SECRET.to_owned()),
client_assertion_type: Some(JwtBearerClientAssertionType)
})
.unwrap(),
"client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"
);
}
#[test]
fn serialize_request_with_credentials() {
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: None,
};
assert_eq!(serde_urlencoded::to_string(req).unwrap(), "body=some_body");
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: Some(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: None,
client_assertion_type: None,
}),
};
assert_eq!(
serde_urlencoded::to_string(req).unwrap(),
"body=some_body&client_id=abcd%24%2B%2B"
);
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: Some(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_assertion: None,
client_assertion_type: None,
}),
};
assert_eq!(
serde_urlencoded::to_string(req).unwrap(),
"body=some_body&client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F"
);
let req = RequestWithClientCredentials {
body: Body { body: REQUEST_BODY },
credentials: Some(BodyClientCredentials {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_assertion: Some(CLIENT_SECRET.to_owned()),
client_assertion_type: Some(JwtBearerClientAssertionType),
}),
};
assert_eq!(
serde_urlencoded::to_string(req).unwrap(),
"body=some_body&client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"
);
}
#[tokio::test]
async fn build_request_none() {
let credentials = ClientCredentials::None {
@@ -677,3 +547,5 @@ mod test {
credentials.client_assertion_type.unwrap();
}
}
*/

View File

@@ -1,31 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::ops::RangeBounds;
use bytes::Buf;
use http::{Response, StatusCode};
use crate::error::ErrorBody;
pub fn http_error_mapper<T>(response: Response<T>) -> Option<ErrorBody>
where
T: Buf,
{
let body = response.into_body();
serde_json::from_reader(body.reader()).ok()
}
pub fn http_all_error_status_codes() -> impl RangeBounds<StatusCode> {
let Ok(client_errors_start_code) = StatusCode::from_u16(400) else {
unreachable!()
};
let Ok(server_errors_end_code) = StatusCode::from_u16(599) else {
unreachable!()
};
client_errors_start_code..=server_errors_end_code
}

View File

@@ -7,8 +7,6 @@
use std::collections::HashMap;
use chrono::{DateTime, Duration, Utc};
use http_body_util::Full;
use mas_http::{BodyToBytesResponseLayer, BoxCloneSyncService};
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::{
claims::{self, hash_token},
@@ -17,21 +15,14 @@ use mas_jose::{
jwt::{JsonWebSignatureHeader, Jwt},
};
use mas_keystore::{JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
use mas_oidc_client::{
http_service::HttpService,
types::{
client_credentials::{ClientCredentials, JwtSigningFn, JwtSigningMethod},
IdToken,
},
use mas_oidc_client::types::{
client_credentials::{ClientCredentials, JwtSigningFn, JwtSigningMethod},
IdToken,
};
use rand::{
distributions::{Alphanumeric, DistString},
SeedableRng,
};
use tower::{
util::{MapErrLayer, MapRequestLayer},
BoxError, Layer,
};
use url::Url;
use wiremock::MockServer;
@@ -41,7 +32,6 @@ mod types;
const REDIRECT_URI: &str = "http://localhost/";
const CLIENT_ID: &str = "client!+ID";
const CLIENT_SECRET: &str = "SECRET?%Gclient";
const REQUEST_URI: &str = "REQUESTur1";
const AUTHORIZATION_CODE: &str = "authC0D3";
const CODE_VERIFIER: &str = "cODEv3R1f1ER";
const NONCE: &str = "No0o0o0once";

View File

@@ -1,127 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::collections::HashMap;
use mas_oidc_client::requests::account_management::{
build_account_management_url, AccountManagementActionFull,
};
use url::Url;
#[test]
fn build_url() {
let account_management_uri = Url::parse("http://localhost/account_management/").unwrap();
// No params
let url = build_account_management_url(account_management_uri.clone(), None, None).unwrap();
assert_eq!(url.query(), None);
// Action without device ID.
let url = build_account_management_url(
account_management_uri.clone(),
Some(AccountManagementActionFull::Profile),
None,
)
.unwrap();
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
assert_eq!(query_pairs.len(), 1);
assert_eq!(query_pairs.get("action").unwrap(), "org.matrix.profile");
// Action with device ID.
let url = build_account_management_url(
account_management_uri.clone(),
Some(AccountManagementActionFull::SessionEnd {
device_id: "mydevice".to_owned(),
}),
None,
)
.unwrap();
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
assert_eq!(query_pairs.len(), 2);
assert_eq!(query_pairs.get("action").unwrap(), "org.matrix.session_end");
assert_eq!(query_pairs.get("device_id").unwrap(), "mydevice");
// ID Token hint.
let url = build_account_management_url(
account_management_uri.clone(),
None,
Some("anidtokenthat.might.looksomethinglikethis".to_owned()),
)
.unwrap();
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
assert_eq!(query_pairs.len(), 1);
assert_eq!(
query_pairs.get("id_token_hint").unwrap(),
"anidtokenthat.might.looksomethinglikethis"
);
// Action without device ID and ID Token hint.
let url = build_account_management_url(
account_management_uri.clone(),
Some(AccountManagementActionFull::AccountDeactivate),
Some("anotheridtokenthat.might.looksomethinglikethis".to_owned()),
)
.unwrap();
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
assert_eq!(query_pairs.len(), 2);
assert_eq!(
query_pairs.get("action").unwrap(),
"org.matrix.account_deactivate"
);
assert_eq!(
query_pairs.get("id_token_hint").unwrap(),
"anotheridtokenthat.might.looksomethinglikethis"
);
// Action with device ID and ID Token hint.
let url = build_account_management_url(
account_management_uri,
Some(AccountManagementActionFull::SessionView {
device_id: "myseconddevice".to_owned(),
}),
Some("athirdidtokenthat.might.looksomethinglikethis".to_owned()),
)
.unwrap();
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
assert_eq!(query_pairs.len(), 3);
assert_eq!(
query_pairs.get("action").unwrap(),
"org.matrix.session_view"
);
assert_eq!(query_pairs.get("device_id").unwrap(), "myseconddevice");
assert_eq!(
query_pairs.get("id_token_hint").unwrap(),
"athirdidtokenthat.might.looksomethinglikethis"
);
// Account management URI with a query already.
let account_management_uri_with_query =
Url::parse("http://localhost/account_management?param=value").unwrap();
let url = build_account_management_url(
account_management_uri_with_query,
Some(AccountManagementActionFull::SessionsList),
Some("afinalidtokenthat.might.looksomethinglikethis".to_owned()),
)
.unwrap();
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
assert_eq!(query_pairs.len(), 3);
assert_eq!(
query_pairs.get("action").unwrap(),
"org.matrix.sessions_list"
);
assert_eq!(
query_pairs.get("id_token_hint").unwrap(),
"afinalidtokenthat.might.looksomethinglikethis"
);
}

View File

@@ -4,34 +4,26 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::{
collections::HashMap,
num::NonZeroU32,
sync::{Arc, Mutex},
};
use std::{collections::HashMap, num::NonZeroU32};
use assert_matches::assert_matches;
use chrono::Duration;
use mas_iana::oauth::{
OAuthAccessTokenType, OAuthClientAuthenticationMethod, PkceCodeChallengeMethod,
};
use mas_jose::{claims::ClaimError, jwk::PublicJsonWebKeySet};
use mas_oidc_client::{
error::{
AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError,
},
error::{IdTokenError, TokenAuthorizationCodeError},
requests::{
authorization_code::{
access_token_with_authorization_code, build_authorization_url,
build_par_authorization_url, AuthorizationRequestData, AuthorizationValidationData,
AuthorizationRequestData, AuthorizationValidationData,
},
jose::JwtVerificationData,
},
types::scope::{ScopeExt, ScopeToken},
};
use oauth2_types::requests::{AccessTokenResponse, Display, Prompt, PushedAuthorizationResponse};
use oauth2_types::requests::{AccessTokenResponse, Display, Prompt};
use rand::SeedableRng;
use tokio::sync::oneshot;
use url::Url;
use wiremock::{
matchers::{method, path},
@@ -40,7 +32,7 @@ use wiremock::{
use crate::{
client_credentials, id_token, init_test, now, ACCESS_TOKEN, AUTHORIZATION_CODE, CLIENT_ID,
CODE_VERIFIER, ID_TOKEN_SIGNING_ALG, NONCE, REDIRECT_URI, REQUEST_URI,
CODE_VERIFIER, ID_TOKEN_SIGNING_ALG, NONCE, REDIRECT_URI,
};
#[test]
@@ -137,115 +129,6 @@ fn pass_full_authorization_url() {
assert_eq!(query_pairs.get("code_challenge_method"), None);
}
#[tokio::test]
async fn pass_pushed_authorization_request() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None);
let authorization_endpoint = issuer.join("authorize").unwrap();
let par_endpoint = issuer.join("par").unwrap();
let redirect_uri = Url::parse(REDIRECT_URI).unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let (sender, receiver) = oneshot::channel();
let sender_mutex = Arc::new(Mutex::new(Some(sender)));
Mock::given(method("POST"))
.and(path("/par"))
.and(move |req: &Request| {
let body = form_urlencoded::parse(&req.body)
.into_owned()
.collect::<HashMap<_, _>>();
if let Some(sender) = sender_mutex.lock().unwrap().take() {
sender.send(body).unwrap();
true
} else {
false
}
})
.respond_with(
ResponseTemplate::new(200).set_body_json(PushedAuthorizationResponse {
request_uri: REQUEST_URI.to_owned(),
expires_in: Duration::microseconds(30 * 1000 * 1000),
}),
)
.mount(&mock_server)
.await;
let (url, validation_data) = build_par_authorization_url(
&http_service,
client_credentials,
&par_endpoint,
authorization_endpoint,
AuthorizationRequestData::new(
CLIENT_ID.to_owned(),
[ScopeToken::Openid].into_iter().collect(),
redirect_uri,
)
.with_code_challenge_methods_supported(vec![PkceCodeChallengeMethod::S256]),
now(),
&mut rng,
)
.await
.unwrap();
assert_eq!(validation_data.state, "OrJ8xbWovSpJUTKz");
assert_eq!(
validation_data.code_challenge_verifier.unwrap(),
"TSgZ_hr3TJPjhq4aDp34K_8ksjLwaa1xDcPiRGBcjhM"
);
let request_pairs = receiver.await.unwrap();
assert_eq!(url.path(), "/authorize");
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
assert_eq!(query_pairs.get("request_uri").unwrap(), REQUEST_URI,);
assert_eq!(query_pairs.get("client_id").unwrap(), CLIENT_ID);
assert_eq!(request_pairs.get("scope").unwrap(), "openid");
assert_eq!(request_pairs.get("response_type").unwrap(), "code");
assert_eq!(request_pairs.get("client_id").unwrap(), CLIENT_ID);
assert_eq!(request_pairs.get("redirect_uri").unwrap(), REDIRECT_URI);
assert_eq!(*request_pairs.get("state").unwrap(), validation_data.state);
assert_eq!(request_pairs.get("nonce").unwrap(), "ox0PigY5l9xl5uTL");
let code_challenge = request_pairs.get("code_challenge").unwrap();
assert!(code_challenge.len() >= 43);
assert_eq!(request_pairs.get("code_challenge_method").unwrap(), "S256");
}
#[tokio::test]
async fn fail_pushed_authorization_request_404() {
let (http_service, _, issuer) = init_test().await;
let client_credentials =
client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None);
let authorization_endpoint = issuer.join("authorize").unwrap();
let par_endpoint = issuer.join("par").unwrap();
let redirect_uri = Url::parse(REDIRECT_URI).unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let error = build_par_authorization_url(
&http_service,
client_credentials,
&par_endpoint,
authorization_endpoint,
AuthorizationRequestData::new(
CLIENT_ID.to_owned(),
[ScopeToken::Openid].into_iter().collect(),
redirect_uri,
)
.with_code_challenge_methods_supported(vec![PkceCodeChallengeMethod::S256]),
now(),
&mut rng,
)
.await
.unwrap_err();
assert_matches!(
error,
AuthorizationError::PushedAuthorization(PushedAuthorizationError::Http(_))
);
}
/// Check if the given request to the token endpoint is valid.
fn is_valid_token_endpoint_request(req: &Request) -> bool {
let body = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();

View File

@@ -36,7 +36,7 @@ fn provider_metadata(issuer: &Url) -> ProviderMetadata {
#[tokio::test]
async fn pass_discover() {
let (http_service, mock_server, issuer) = init_test().await;
let (http_client, mock_server, issuer) = init_test().await;
Mock::given(method("GET"))
.and(path("/.well-known/openid-configuration"))
@@ -44,7 +44,9 @@ async fn pass_discover() {
.mount(&mock_server)
.await;
let provider_metadata = insecure_discover(&client, issuer.as_str()).await.unwrap();
let provider_metadata = insecure_discover(&http_client, issuer.as_str())
.await
.unwrap();
assert_eq!(provider_metadata.issuer(), issuer.as_str());
}
@@ -70,7 +72,7 @@ async fn fail_discover_not_json() {
let error = discover(&http_service, issuer.as_str()).await.unwrap_err();
assert_matches!(error, DiscoveryError::FromJson(_));
assert_matches!(error, DiscoveryError::Http(_));
}
#[tokio::test]

View File

@@ -1,100 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::collections::HashMap;
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_oidc_client::{
requests::introspection::introspect_token,
types::scope::{ScopeExt, ScopeToken},
};
use oauth2_types::requests::IntrospectionResponse;
use rand::SeedableRng;
use wiremock::{
matchers::{method, path},
Mock, Request, ResponseTemplate,
};
use crate::{client_credentials, init_test, now, ACCESS_TOKEN, CLIENT_ID, SUBJECT_IDENTIFIER};
#[tokio::test]
async fn pass_introspect_token() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None);
let introspection_endpoint = issuer.join("introspect").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
Mock::given(method("POST"))
.and(path("/introspect"))
.and(|req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("token")
.filter(|s| *s == ACCESS_TOKEN)
.is_none()
{
println!("Wrong or missing token");
return false;
}
if query_pairs
.get("token_type_hint")
.filter(|s| *s == "access_token")
.is_none()
{
println!("Wrong or missing token type hint");
return false;
}
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
true
})
.respond_with(
ResponseTemplate::new(200).set_body_json(IntrospectionResponse {
active: true,
scope: Some([ScopeToken::Profile].into_iter().collect()),
client_id: Some(CLIENT_ID.to_owned()),
username: None,
token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: None,
iat: None,
nbf: None,
sub: Some(SUBJECT_IDENTIFIER.to_owned()),
aud: Some(CLIENT_ID.to_owned()),
iss: Some(issuer.to_string()),
jti: None,
}),
)
.mount(&mock_server)
.await;
let response = introspect_token(
&http_service,
client_credentials.into(),
&introspection_endpoint,
ACCESS_TOKEN.to_owned(),
Some(OAuthTokenTypeHint::AccessToken),
now(),
&mut rng,
)
.await
.unwrap();
assert!(response.active);
assert_eq!(response.aud.unwrap(), CLIENT_ID);
assert!(response.scope.unwrap().contains_token(&ScopeToken::Profile));
assert_eq!(response.client_id.unwrap(), CLIENT_ID);
assert_eq!(response.iss.unwrap(), issuer.as_str());
assert_eq!(response.sub.unwrap(), SUBJECT_IDENTIFIER);
}

View File

@@ -4,14 +4,10 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
mod account_management;
mod authorization_code;
mod client_credentials;
mod discovery;
mod introspection;
mod jose;
mod refresh_token;
mod registration;
mod revocation;
mod rp_initiated_logout;
mod userinfo;

View File

@@ -1,250 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use assert_matches::assert_matches;
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
use mas_jose::jwk::PublicJsonWebKeySet;
use mas_oidc_client::{error::RegistrationError, requests::registration::register_client};
use oauth2_types::{
oidc::ApplicationType,
registration::{ClientMetadata, ClientRegistrationResponse, VerifiedClientMetadata},
};
use serde_json::json;
use url::Url;
use wiremock::{
matchers::{body_partial_json, method, path},
Mock, Request, ResponseTemplate,
};
use crate::{init_test, CLIENT_ID, CLIENT_SECRET, REDIRECT_URI};
/// Generate valid client metadata for the given authentication method.
fn client_metadata(auth_method: OAuthClientAuthenticationMethod) -> VerifiedClientMetadata {
let (signing_alg, jwks) = match &auth_method {
OAuthClientAuthenticationMethod::ClientSecretJwt => {
(Some(JsonWebSignatureAlg::Hs256), None)
}
OAuthClientAuthenticationMethod::PrivateKeyJwt => (
Some(JsonWebSignatureAlg::Es256),
Some(PublicJsonWebKeySet::default()),
),
_ => (None, None),
};
ClientMetadata {
redirect_uris: Some(vec![Url::parse(REDIRECT_URI).expect("Couldn't parse URL")]),
application_type: Some(ApplicationType::Native),
token_endpoint_auth_method: Some(auth_method),
token_endpoint_auth_signing_alg: signing_alg,
jwks,
..Default::default()
}
.validate()
.unwrap()
}
#[tokio::test]
async fn pass_register_client_none() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::None);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(body_partial_json(json!({
"redirect_uris": [REDIRECT_URI],
"token_endpoint_auth_method": "none",
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let response = register_client(&http_service, &registration_endpoint, client_metadata, None)
.await
.unwrap();
assert_eq!(response.client_id, CLIENT_ID);
assert_eq!(response.client_secret, None);
}
#[tokio::test]
async fn pass_register_client_client_secret_basic() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretBasic);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(body_partial_json(json!({
"redirect_uris": [REDIRECT_URI],
"token_endpoint_auth_method": "client_secret_basic",
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let response = register_client(&http_service, &registration_endpoint, client_metadata, None)
.await
.unwrap();
assert_eq!(response.client_id, CLIENT_ID);
assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET);
}
#[tokio::test]
async fn pass_register_client_client_secret_post() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretPost);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(body_partial_json(json!({
"redirect_uris": [REDIRECT_URI],
"token_endpoint_auth_method": "client_secret_post",
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let response = register_client(&http_service, &registration_endpoint, client_metadata, None)
.await
.unwrap();
assert_eq!(response.client_id, CLIENT_ID);
assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET);
}
#[tokio::test]
async fn pass_register_client_client_secret_jwt() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretJwt);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(body_partial_json(json!({
"redirect_uris": [REDIRECT_URI],
"token_endpoint_auth_method": "client_secret_jwt",
"token_endpoint_auth_signing_alg": "HS256",
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: Some(CLIENT_SECRET.to_owned()),
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let response = register_client(&http_service, &registration_endpoint, client_metadata, None)
.await
.unwrap();
assert_eq!(response.client_id, CLIENT_ID);
assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET);
}
#[tokio::test]
async fn pass_register_client_private_key_jwt() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::PrivateKeyJwt);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(|req: &Request| {
let Ok(metadata) = req.body_json::<ClientMetadata>() else {
return false;
};
*metadata.token_endpoint_auth_method() == OAuthClientAuthenticationMethod::PrivateKeyJwt
&& metadata.token_endpoint_auth_signing_alg == Some(JsonWebSignatureAlg::Es256)
&& metadata.jwks.is_some()
})
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let response = register_client(&http_service, &registration_endpoint, client_metadata, None)
.await
.unwrap();
assert_eq!(response.client_id, CLIENT_ID);
assert_eq!(response.client_secret, None);
}
#[tokio::test]
async fn fail_register_client_404() {
let (http_service, _, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::None);
let registration_endpoint = issuer.join("register").unwrap();
let error = register_client(&http_service, &registration_endpoint, client_metadata, None)
.await
.unwrap_err();
assert_matches!(error, RegistrationError::Http(_));
}
#[tokio::test]
async fn fail_register_client_missing_secret() {
let (http_service, mock_server, issuer) = init_test().await;
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretBasic);
let registration_endpoint = issuer.join("register").unwrap();
Mock::given(method("POST"))
.and(path("/register"))
.and(body_partial_json(json!({
"token_endpoint_auth_method": "client_secret_basic",
})))
.respond_with(
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
client_id: CLIENT_ID.to_owned(),
client_secret: None,
client_id_issued_at: None,
client_secret_expires_at: None,
}),
)
.mount(&mock_server)
.await;
let error = register_client(&http_service, &registration_endpoint, client_metadata, None)
.await
.unwrap_err();
assert_matches!(error, RegistrationError::MissingClientSecret);
}

View File

@@ -1,74 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2022-2024 Kévin Commaille.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::collections::HashMap;
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_oidc_client::requests::revocation::revoke_token;
use rand::SeedableRng;
use wiremock::{
matchers::{method, path},
Mock, Request, ResponseTemplate,
};
use crate::{client_credentials, init_test, ACCESS_TOKEN, CLIENT_ID};
#[tokio::test]
async fn pass_revoke_token() {
let (http_service, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None);
let revocation_endpoint = issuer.join("revoke").unwrap();
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
Mock::given(method("POST"))
.and(path("/revoke"))
.and(|req: &Request| {
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
if query_pairs
.get("token")
.filter(|s| *s == ACCESS_TOKEN)
.is_none()
{
println!("Wrong or missing refresh token");
return false;
}
if query_pairs
.get("token_type_hint")
.filter(|s| *s == "access_token")
.is_none()
{
println!("Wrong or missing token type hint");
return false;
}
if query_pairs
.get("client_id")
.filter(|s| *s == CLIENT_ID)
.is_none()
{
println!("Wrong or missing client ID");
return false;
}
true
})
.respond_with(ResponseTemplate::new(200))
.mount(&mock_server)
.await;
revoke_token(
&http_service,
client_credentials,
&revocation_endpoint,
ACCESS_TOKEN.to_owned(),
Some(OAuthTokenTypeHint::AccessToken),
crate::now(),
&mut rng,
)
.await
.unwrap();
}

View File

@@ -8,6 +8,7 @@ use std::collections::HashMap;
use assert_matches::assert_matches;
use base64ct::Encoding;
use http::header::AUTHORIZATION;
use mas_iana::oauth::{OAuthAccessTokenType, OAuthClientAuthenticationMethod};
use mas_jose::{
claims::{self, TimeOptions},
@@ -31,7 +32,7 @@ use crate::{client_credentials, init_test, now, ACCESS_TOKEN, CLIENT_ID, CLIENT_
#[tokio::test]
async fn pass_none() {
let (http_service, mock_server, issuer) = init_test().await;
let (http_client, mock_server, issuer) = init_test().await;
let client_credentials =
client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None);
let token_endpoint = issuer.join("token").unwrap();
@@ -67,7 +68,7 @@ async fn pass_none() {
.await;
access_token_with_client_credentials(
&http_service,
&http_client,
client_credentials,
&token_endpoint,
None,
@@ -80,7 +81,7 @@ async fn pass_none() {
#[tokio::test]
async fn pass_client_secret_basic() {
let (http_service, mock_server, issuer) = init_test().await;
let (http_client, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
&OAuthClientAuthenticationMethod::ClientSecretBasic,
&issuer,
@@ -94,10 +95,15 @@ async fn pass_client_secret_basic() {
let enc_user_pass =
base64ct::Base64::encode_string(format!("{username}:{password}").as_bytes());
let authorization_header = format!("Basic {enc_user_pass}");
eprintln!("{authorization_header}");
Mock::given(method("POST"))
.and(path("/token"))
.and(header("authorization", authorization_header.as_str()))
.and(|req: &Request| {
println!("{req:?}");
true
})
.and(header(AUTHORIZATION, authorization_header.as_str()))
.respond_with(
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
access_token: ACCESS_TOKEN.to_owned(),
@@ -112,7 +118,7 @@ async fn pass_client_secret_basic() {
.await;
access_token_with_client_credentials(
&http_service,
&http_client,
client_credentials,
&token_endpoint,
None,
@@ -125,7 +131,7 @@ async fn pass_client_secret_basic() {
#[tokio::test]
async fn pass_client_secret_post() {
let (http_service, mock_server, issuer) = init_test().await;
let (http_client, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
&OAuthClientAuthenticationMethod::ClientSecretPost,
&issuer,
@@ -172,7 +178,7 @@ async fn pass_client_secret_post() {
.await;
access_token_with_client_credentials(
&http_service,
&http_client,
client_credentials,
&token_endpoint,
None,
@@ -185,7 +191,7 @@ async fn pass_client_secret_post() {
#[tokio::test]
async fn pass_client_secret_jwt() {
let (http_service, mock_server, issuer) = init_test().await;
let (http_client, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
&OAuthClientAuthenticationMethod::ClientSecretJwt,
&issuer,
@@ -253,7 +259,7 @@ async fn pass_client_secret_jwt() {
.await;
access_token_with_client_credentials(
&http_service,
&http_client,
client_credentials,
&token_endpoint,
None,
@@ -266,7 +272,7 @@ async fn pass_client_secret_jwt() {
#[tokio::test]
async fn pass_private_key_jwt_with_keystore() {
let (http_service, mock_server, issuer) = init_test().await;
let (http_client, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
&OAuthClientAuthenticationMethod::PrivateKeyJwt,
&issuer,
@@ -341,7 +347,7 @@ async fn pass_private_key_jwt_with_keystore() {
.await;
access_token_with_client_credentials(
&http_service,
&http_client,
client_credentials,
&token_endpoint,
None,
@@ -354,7 +360,7 @@ async fn pass_private_key_jwt_with_keystore() {
#[tokio::test]
async fn pass_private_key_jwt_with_custom_signing() {
let (http_service, mock_server, issuer) = init_test().await;
let (http_client, mock_server, issuer) = init_test().await;
let client_credentials = client_credentials(
&OAuthClientAuthenticationMethod::PrivateKeyJwt,
&issuer,
@@ -410,7 +416,7 @@ async fn pass_private_key_jwt_with_custom_signing() {
.await;
access_token_with_client_credentials(
&http_service,
&http_client,
client_credentials,
&token_endpoint,
None,
@@ -423,7 +429,7 @@ async fn pass_private_key_jwt_with_custom_signing() {
#[tokio::test]
async fn fail_private_key_jwt_with_custom_signing() {
let (http_service, _, issuer) = init_test().await;
let (http_client, _, issuer) = init_test().await;
let client_credentials = client_credentials(
&OAuthClientAuthenticationMethod::PrivateKeyJwt,
&issuer,
@@ -433,7 +439,7 @@ async fn fail_private_key_jwt_with_custom_signing() {
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
let error = access_token_with_client_credentials(
&http_service,
&http_client,
client_credentials,
&token_endpoint,
None,