Replace HTTP client in oidc-client with reqwest
This commit is contained in:
@@ -51,5 +51,3 @@ tokio.workspace = true
|
||||
wiremock = "0.6.2"
|
||||
http-body-util.workspace = true
|
||||
rustls.workspace = true
|
||||
|
||||
mas-http = { workspace = true, features = ["client"] }
|
||||
|
||||
@@ -6,11 +6,9 @@
|
||||
|
||||
//! The error types used in this crate.
|
||||
|
||||
use std::{str::Utf8Error, sync::Arc};
|
||||
|
||||
use headers::authorization::InvalidBearerToken;
|
||||
use http::{header::ToStrError, StatusCode};
|
||||
use mas_http::{catch_http_codes, form_urlencoded_request, json_request, json_response};
|
||||
use http::StatusCode;
|
||||
use mas_http::{catch_http_codes, form_urlencoded_request, json_response};
|
||||
use mas_jose::{
|
||||
claims::ClaimError,
|
||||
jwa::InvalidAlgorithm,
|
||||
@@ -33,9 +31,6 @@ pub enum Error {
|
||||
/// An error occurred fetching the provider JWKS.
|
||||
Jwks(#[from] JwksError),
|
||||
|
||||
/// An error occurred during client registration.
|
||||
Registration(#[from] RegistrationError),
|
||||
|
||||
/// An error occurred building the authorization URL.
|
||||
Authorization(#[from] AuthorizationError),
|
||||
|
||||
@@ -48,17 +43,11 @@ pub enum Error {
|
||||
/// An error occurred refreshing an access token.
|
||||
TokenRefresh(#[from] TokenRefreshError),
|
||||
|
||||
/// An error occurred revoking a token.
|
||||
TokenRevoke(#[from] TokenRevokeError),
|
||||
|
||||
/// An error occurred requesting user info.
|
||||
UserInfo(#[from] UserInfoError),
|
||||
|
||||
/// An error occurred introspecting a token.
|
||||
Introspection(#[from] IntrospectionError),
|
||||
|
||||
/// An error occurred building the account management URL.
|
||||
AccountManagement(#[from] AccountManagementError),
|
||||
}
|
||||
|
||||
/// All possible errors when fetching provider metadata.
|
||||
@@ -81,135 +70,6 @@ pub enum DiscoveryError {
|
||||
Disabled,
|
||||
}
|
||||
|
||||
/// All possible errors when registering the client.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum RegistrationError {
|
||||
/// An error occurred building the request.
|
||||
#[error(transparent)]
|
||||
IntoHttp(#[from] http::Error),
|
||||
|
||||
/// An error occurred serializing the request or deserializing the response.
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
/// The server returned an HTTP error status code.
|
||||
#[error(transparent)]
|
||||
Http(#[from] HttpError),
|
||||
|
||||
/// No client secret was received although one was expected because of the
|
||||
/// authentication method.
|
||||
#[error("missing client secret in response")]
|
||||
MissingClientSecret,
|
||||
|
||||
/// An error occurred sending the request.
|
||||
#[error(transparent)]
|
||||
Service(BoxError),
|
||||
}
|
||||
|
||||
impl<S> From<json_request::Error<S>> for RegistrationError
|
||||
where
|
||||
S: Into<RegistrationError>,
|
||||
{
|
||||
fn from(err: json_request::Error<S>) -> Self {
|
||||
match err {
|
||||
json_request::Error::Serialize { inner } => inner.into(),
|
||||
json_request::Error::Service { inner } => inner.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> From<json_response::Error<S>> for RegistrationError
|
||||
where
|
||||
S: Into<RegistrationError>,
|
||||
{
|
||||
fn from(err: json_response::Error<S>) -> Self {
|
||||
match err {
|
||||
json_response::Error::Deserialize { inner } => inner.into(),
|
||||
json_response::Error::Service { inner } => inner.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for RegistrationError
|
||||
where
|
||||
S: Into<BoxError>,
|
||||
{
|
||||
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
|
||||
match err {
|
||||
catch_http_codes::Error::HttpError { status_code, inner } => {
|
||||
HttpError::new(status_code, inner).into()
|
||||
}
|
||||
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// All possible errors when making a pushed authorization request.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum PushedAuthorizationError {
|
||||
/// An error occurred serializing the request.
|
||||
#[error(transparent)]
|
||||
UrlEncoded(#[from] serde_urlencoded::ser::Error),
|
||||
|
||||
/// An error occurred building the request.
|
||||
#[error(transparent)]
|
||||
IntoHttp(#[from] http::Error),
|
||||
|
||||
/// An error occurred adding the client credentials to the request.
|
||||
#[error(transparent)]
|
||||
Credentials(#[from] CredentialsError),
|
||||
|
||||
/// The server returned an HTTP error status code.
|
||||
#[error(transparent)]
|
||||
Http(#[from] HttpError),
|
||||
|
||||
/// An error occurred deserializing the response.
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
/// An error occurred sending the request.
|
||||
#[error(transparent)]
|
||||
Service(BoxError),
|
||||
}
|
||||
|
||||
impl<S> From<form_urlencoded_request::Error<S>> for PushedAuthorizationError
|
||||
where
|
||||
S: Into<PushedAuthorizationError>,
|
||||
{
|
||||
fn from(err: form_urlencoded_request::Error<S>) -> Self {
|
||||
match err {
|
||||
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
|
||||
form_urlencoded_request::Error::Service { inner } => inner.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> From<json_response::Error<S>> for PushedAuthorizationError
|
||||
where
|
||||
S: Into<PushedAuthorizationError>,
|
||||
{
|
||||
fn from(err: json_response::Error<S>) -> Self {
|
||||
match err {
|
||||
json_response::Error::Deserialize { inner } => inner.into(),
|
||||
json_response::Error::Service { inner } => inner.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for PushedAuthorizationError
|
||||
where
|
||||
S: Into<BoxError>,
|
||||
{
|
||||
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
|
||||
match err {
|
||||
catch_http_codes::Error::HttpError { status_code, inner } => {
|
||||
HttpError::new(status_code, inner).into()
|
||||
}
|
||||
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// All possible errors when authorizing the client.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AuthorizationError {
|
||||
@@ -220,76 +80,18 @@ pub enum AuthorizationError {
|
||||
/// An error occurred serializing the request.
|
||||
#[error(transparent)]
|
||||
UrlEncoded(#[from] serde_urlencoded::ser::Error),
|
||||
|
||||
/// An error occurred making the PAR request.
|
||||
#[error(transparent)]
|
||||
PushedAuthorization(#[from] PushedAuthorizationError),
|
||||
}
|
||||
|
||||
/// All possible errors when requesting an access token.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TokenRequestError {
|
||||
/// An error occurred building the request.
|
||||
/// The HTTP client returned an error.
|
||||
#[error(transparent)]
|
||||
IntoHttp(#[from] http::Error),
|
||||
Http(#[from] reqwest::Error),
|
||||
|
||||
/// An error occurred adding the client credentials to the request.
|
||||
/// Error while injecting the client credentials into the request.
|
||||
#[error(transparent)]
|
||||
Credentials(#[from] CredentialsError),
|
||||
|
||||
/// An error occurred serializing the request.
|
||||
#[error(transparent)]
|
||||
UrlEncoded(#[from] serde_urlencoded::ser::Error),
|
||||
|
||||
/// The server returned an HTTP error status code.
|
||||
#[error(transparent)]
|
||||
Http(#[from] HttpError),
|
||||
|
||||
/// An error occurred deserializing the response.
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
/// An error occurred sending the request.
|
||||
#[error(transparent)]
|
||||
Service(BoxError),
|
||||
}
|
||||
|
||||
impl<S> From<form_urlencoded_request::Error<S>> for TokenRequestError
|
||||
where
|
||||
S: Into<TokenRequestError>,
|
||||
{
|
||||
fn from(err: form_urlencoded_request::Error<S>) -> Self {
|
||||
match err {
|
||||
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
|
||||
form_urlencoded_request::Error::Service { inner } => inner.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> From<json_response::Error<S>> for TokenRequestError
|
||||
where
|
||||
S: Into<TokenRequestError>,
|
||||
{
|
||||
fn from(err: json_response::Error<S>) -> Self {
|
||||
match err {
|
||||
json_response::Error::Deserialize { inner } => inner.into(),
|
||||
json_response::Error::Service { inner } => inner.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for TokenRequestError
|
||||
where
|
||||
S: Into<BoxError>,
|
||||
{
|
||||
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
|
||||
match err {
|
||||
catch_http_codes::Error::HttpError { status_code, inner } => {
|
||||
HttpError::new(status_code, inner).into()
|
||||
}
|
||||
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// All possible errors when exchanging a code for an access token.
|
||||
@@ -316,95 +118,13 @@ pub enum TokenRefreshError {
|
||||
IdToken(#[from] IdTokenError),
|
||||
}
|
||||
|
||||
/// All possible errors when revoking a token.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum TokenRevokeError {
|
||||
/// An error occurred building the request.
|
||||
#[error(transparent)]
|
||||
IntoHttp(#[from] http::Error),
|
||||
|
||||
/// An error occurred adding the client credentials to the request.
|
||||
#[error(transparent)]
|
||||
Credentials(#[from] CredentialsError),
|
||||
|
||||
/// An error occurred serializing the request.
|
||||
#[error(transparent)]
|
||||
UrlEncoded(#[from] serde_urlencoded::ser::Error),
|
||||
|
||||
/// An error occurred deserializing the error response.
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
/// The server returned an HTTP error status code.
|
||||
#[error(transparent)]
|
||||
Http(#[from] HttpError),
|
||||
|
||||
/// An error occurred sending the request.
|
||||
#[error(transparent)]
|
||||
Service(BoxError),
|
||||
}
|
||||
|
||||
impl<S> From<form_urlencoded_request::Error<S>> for TokenRevokeError
|
||||
where
|
||||
S: Into<TokenRevokeError>,
|
||||
{
|
||||
fn from(err: form_urlencoded_request::Error<S>) -> Self {
|
||||
match err {
|
||||
form_urlencoded_request::Error::Serialize { inner } => inner.into(),
|
||||
form_urlencoded_request::Error::Service { inner } => inner.into(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for TokenRevokeError
|
||||
where
|
||||
S: Into<BoxError>,
|
||||
{
|
||||
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
|
||||
match err {
|
||||
catch_http_codes::Error::HttpError { status_code, inner } => {
|
||||
HttpError::new(status_code, inner).into()
|
||||
}
|
||||
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// All possible errors when requesting user info.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum UserInfoError {
|
||||
/// An error occurred getting the provider metadata.
|
||||
#[error(transparent)]
|
||||
Discovery(#[from] Arc<DiscoveryError>),
|
||||
|
||||
/// The provider doesn't support requesting user info.
|
||||
#[error("missing UserInfo support")]
|
||||
MissingUserInfoSupport,
|
||||
|
||||
/// No token is available to get info from.
|
||||
#[error("missing token")]
|
||||
MissingToken,
|
||||
|
||||
/// No client metadata is available.
|
||||
#[error("missing client metadata")]
|
||||
MissingClientMetadata,
|
||||
|
||||
/// The access token is invalid.
|
||||
#[error(transparent)]
|
||||
Token(#[from] InvalidBearerToken),
|
||||
|
||||
/// An error occurred building the request.
|
||||
#[error(transparent)]
|
||||
IntoHttp(#[from] http::Error),
|
||||
|
||||
/// The content-type header is missing from the response.
|
||||
#[error("missing response content-type")]
|
||||
MissingResponseContentType,
|
||||
|
||||
/// The content-type header could not be decoded.
|
||||
#[error("could not decoded response content-type: {0}")]
|
||||
DecodeResponseContentType(#[from] ToStrError),
|
||||
|
||||
/// The content-type is not valid.
|
||||
#[error("invalid response content-type")]
|
||||
InvalidResponseContentTypeValue,
|
||||
@@ -418,39 +138,13 @@ pub enum UserInfoError {
|
||||
got: String,
|
||||
},
|
||||
|
||||
/// An error occurred reading the response.
|
||||
#[error(transparent)]
|
||||
FromUtf8(#[from] Utf8Error),
|
||||
|
||||
/// An error occurred deserializing the JSON or error response.
|
||||
#[error(transparent)]
|
||||
Json(#[from] serde_json::Error),
|
||||
|
||||
/// An error occurred verifying the Id Token.
|
||||
#[error(transparent)]
|
||||
IdToken(#[from] IdTokenError),
|
||||
|
||||
/// The server returned an HTTP error status code.
|
||||
#[error(transparent)]
|
||||
Http(#[from] HttpError),
|
||||
|
||||
/// An error occurred sending the request.
|
||||
#[error(transparent)]
|
||||
Service(BoxError),
|
||||
}
|
||||
|
||||
impl<S> From<catch_http_codes::Error<S, Option<ErrorBody>>> for UserInfoError
|
||||
where
|
||||
S: Into<BoxError>,
|
||||
{
|
||||
fn from(err: catch_http_codes::Error<S, Option<ErrorBody>>) -> Self {
|
||||
match err {
|
||||
catch_http_codes::Error::HttpError { status_code, inner } => {
|
||||
HttpError::new(status_code, inner).into()
|
||||
}
|
||||
catch_http_codes::Error::Service { inner } => Self::Service(inner.into()),
|
||||
}
|
||||
}
|
||||
Http(#[from] reqwest::Error),
|
||||
}
|
||||
|
||||
/// All possible errors when introspecting a token.
|
||||
@@ -644,11 +338,3 @@ pub enum CredentialsError {
|
||||
#[error(transparent)]
|
||||
Custom(BoxError),
|
||||
}
|
||||
|
||||
/// All errors that can occur when building the account management URL.
|
||||
#[derive(Debug, Error)]
|
||||
pub enum AccountManagementError {
|
||||
/// An error occurred serializing the parameters.
|
||||
#[error(transparent)]
|
||||
UrlEncoded(#[from] serde_urlencoded::ser::Error),
|
||||
}
|
||||
|
||||
@@ -52,7 +52,6 @@ pub mod error;
|
||||
pub mod http_service;
|
||||
pub mod requests;
|
||||
pub mod types;
|
||||
mod utils;
|
||||
|
||||
use std::fmt;
|
||||
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2024 Kévin Commaille.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
//! Methods related to the account management URL.
|
||||
//!
|
||||
//! This is a Matrix extension introduced in [MSC2965](https://github.com/matrix-org/matrix-spec-proposals/pull/2965).
|
||||
|
||||
use serde::Serialize;
|
||||
use serde_with::skip_serializing_none;
|
||||
use url::Url;
|
||||
|
||||
use crate::error::AccountManagementError;
|
||||
|
||||
/// An account management action that a user can take, including a device ID for
|
||||
/// the actions that support it.
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
#[serde(tag = "action")]
|
||||
#[non_exhaustive]
|
||||
pub enum AccountManagementActionFull {
|
||||
/// `org.matrix.profile`
|
||||
///
|
||||
/// The user wishes to view their profile (name, avatar, contact details).
|
||||
#[serde(rename = "org.matrix.profile")]
|
||||
Profile,
|
||||
|
||||
/// `org.matrix.sessions_list`
|
||||
///
|
||||
/// The user wishes to view a list of their sessions.
|
||||
#[serde(rename = "org.matrix.sessions_list")]
|
||||
SessionsList,
|
||||
|
||||
/// `org.matrix.session_view`
|
||||
///
|
||||
/// The user wishes to view the details of a specific session.
|
||||
#[serde(rename = "org.matrix.session_view")]
|
||||
SessionView {
|
||||
/// The ID of the session to view the details of.
|
||||
device_id: String,
|
||||
},
|
||||
|
||||
/// `org.matrix.session_end`
|
||||
///
|
||||
/// The user wishes to end/log out of a specific session.
|
||||
#[serde(rename = "org.matrix.session_end")]
|
||||
SessionEnd {
|
||||
/// The ID of the session to end.
|
||||
device_id: String,
|
||||
},
|
||||
|
||||
/// `org.matrix.account_deactivate`
|
||||
///
|
||||
/// The user wishes to deactivate their account.
|
||||
#[serde(rename = "org.matrix.account_deactivate")]
|
||||
AccountDeactivate,
|
||||
|
||||
/// `org.matrix.cross_signing_reset`
|
||||
///
|
||||
/// The user wishes to reset their cross-signing keys.
|
||||
#[serde(rename = "org.matrix.cross_signing_reset")]
|
||||
CrossSigningReset,
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize)]
|
||||
struct AccountManagementData {
|
||||
#[serde(flatten)]
|
||||
action: Option<AccountManagementActionFull>,
|
||||
id_token_hint: Option<String>,
|
||||
}
|
||||
|
||||
/// Build the URL for accessing the account management capabilities.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `account_management_uri` - The URL to access the issuer's account
|
||||
/// management capabilities.
|
||||
///
|
||||
/// * `action` - The action that the user wishes to take.
|
||||
///
|
||||
/// * `id_token_hint` - An ID Token that was previously issued to the client,
|
||||
/// used as a hint for which user is requesting to manage their account.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A URL to be opened in a web browser where the end-user will be able to
|
||||
/// access the account management capabilities of the issuer.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if serializing the URL fails.
|
||||
pub fn build_account_management_url(
|
||||
mut account_management_uri: Url,
|
||||
action: Option<AccountManagementActionFull>,
|
||||
id_token_hint: Option<String>,
|
||||
) -> Result<Url, AccountManagementError> {
|
||||
let data = AccountManagementData {
|
||||
action,
|
||||
id_token_hint,
|
||||
};
|
||||
let extra_query = serde_urlencoded::to_string(data)?;
|
||||
|
||||
if !extra_query.is_empty() {
|
||||
// Add our parameters to the query, because the URL might already have one.
|
||||
let mut full_query = account_management_uri
|
||||
.query()
|
||||
.map(ToOwned::to_owned)
|
||||
.unwrap_or_default();
|
||||
|
||||
if !full_query.is_empty() {
|
||||
full_query.push('&');
|
||||
}
|
||||
full_query.push_str(&extra_query);
|
||||
|
||||
account_management_uri.set_query(Some(&full_query));
|
||||
}
|
||||
|
||||
Ok(account_management_uri)
|
||||
}
|
||||
@@ -12,9 +12,7 @@ use std::{collections::HashSet, num::NonZeroU32};
|
||||
|
||||
use base64ct::{Base64UrlUnpadded, Encoding};
|
||||
use chrono::{DateTime, Utc};
|
||||
use http::header::CONTENT_TYPE;
|
||||
use language_tags::LanguageTag;
|
||||
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
|
||||
use mas_iana::oauth::{OAuthAuthorizationEndpointResponseType, PkceCodeChallengeMethod};
|
||||
use mas_jose::claims::{self, TokenHash};
|
||||
use oauth2_types::{
|
||||
@@ -22,7 +20,7 @@ use oauth2_types::{
|
||||
prelude::CodeChallengeMethodExt,
|
||||
requests::{
|
||||
AccessTokenRequest, AccessTokenResponse, AuthorizationCodeGrant, AuthorizationRequest,
|
||||
Display, Prompt, PushedAuthorizationResponse,
|
||||
Display, Prompt,
|
||||
},
|
||||
scope::Scope,
|
||||
};
|
||||
@@ -32,22 +30,17 @@ use rand::{
|
||||
};
|
||||
use serde::Serialize;
|
||||
use serde_with::skip_serializing_none;
|
||||
use tower::{Layer, Service, ServiceExt};
|
||||
use url::Url;
|
||||
|
||||
use super::jose::JwtVerificationData;
|
||||
use crate::{
|
||||
error::{
|
||||
AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError,
|
||||
},
|
||||
http_service::HttpService,
|
||||
error::{AuthorizationError, IdTokenError, TokenAuthorizationCodeError},
|
||||
requests::{jose::verify_id_token, token::request_access_token},
|
||||
types::{
|
||||
client_credentials::ClientCredentials,
|
||||
scope::{ScopeExt, ScopeToken},
|
||||
IdToken,
|
||||
},
|
||||
utils::{http_all_error_status_codes, http_error_mapper},
|
||||
};
|
||||
|
||||
/// The data necessary to build an authorization request.
|
||||
@@ -320,7 +313,6 @@ fn build_authorization_request(
|
||||
///
|
||||
/// [`VerifiedClientMetadata`]: oauth2_types::registration::VerifiedClientMetadata
|
||||
/// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode
|
||||
#[allow(clippy::too_many_lines)]
|
||||
pub fn build_authorization_url(
|
||||
authorization_endpoint: Url,
|
||||
authorization_data: AuthorizationRequestData,
|
||||
@@ -353,115 +345,6 @@ pub fn build_authorization_url(
|
||||
Ok((authorization_url, validation_data))
|
||||
}
|
||||
|
||||
/// Make a [Pushed Authorization Request] and build the URL for authenticating
|
||||
/// at the Authorization endpoint.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `http_service` - The service to use for making HTTP requests.
|
||||
///
|
||||
/// * `client_credentials` - The credentials obtained when registering the
|
||||
/// client.
|
||||
///
|
||||
/// * `par_endpoint` - The URL of the issuer's Pushed Authorization Request
|
||||
/// endpoint.
|
||||
///
|
||||
/// * `authorization_endpoint` - The URL of the issuer's Authorization endpoint.
|
||||
///
|
||||
/// * `authorization_data` - The data necessary to build the authorization
|
||||
/// request.
|
||||
///
|
||||
/// * `now` - The current time.
|
||||
///
|
||||
/// * `rng` - A random number generator.
|
||||
///
|
||||
/// # Returns
|
||||
///
|
||||
/// A URL to be opened in a web browser where the end-user will be able to
|
||||
/// authorize the given scope, and the [`AuthorizationValidationData`] to
|
||||
/// validate this request.
|
||||
///
|
||||
/// The redirect URI will receive parameters in its query:
|
||||
///
|
||||
/// * A successful response will receive a `code` and a `state`.
|
||||
///
|
||||
/// * If the authorization fails, it should receive an `error` parameter with a
|
||||
/// [`ClientErrorCode`] and optionally an `error_description`.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the request fails, the response is invalid or building
|
||||
/// the URL fails.
|
||||
///
|
||||
/// [Pushed Authorization Request]: https://oauth.net/2/pushed-authorization-requests/
|
||||
/// [`ClientErrorCode`]: oauth2_types::errors::ClientErrorCode
|
||||
#[tracing::instrument(skip_all, fields(par_endpoint))]
|
||||
pub async fn build_par_authorization_url(
|
||||
http_service: &HttpService,
|
||||
client_credentials: ClientCredentials,
|
||||
par_endpoint: &Url,
|
||||
authorization_endpoint: Url,
|
||||
authorization_data: AuthorizationRequestData,
|
||||
now: DateTime<Utc>,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<(Url, AuthorizationValidationData), AuthorizationError> {
|
||||
tracing::debug!(
|
||||
scope = ?authorization_data.scope,
|
||||
"Authorizing with a PAR..."
|
||||
);
|
||||
|
||||
let client_id = client_credentials.client_id().to_owned();
|
||||
|
||||
let (authorization_request, validation_data) =
|
||||
build_authorization_request(authorization_data, rng)?;
|
||||
|
||||
let par_request = http::Request::post(par_endpoint.as_str())
|
||||
.header(CONTENT_TYPE, mime::APPLICATION_WWW_FORM_URLENCODED.as_ref())
|
||||
.body(authorization_request)
|
||||
.map_err(PushedAuthorizationError::from)?;
|
||||
|
||||
let par_request = client_credentials
|
||||
.apply_to_request(par_request, now, rng)
|
||||
.map_err(PushedAuthorizationError::from)?;
|
||||
|
||||
let service = (
|
||||
FormUrlencodedRequestLayer::default(),
|
||||
JsonResponseLayer::<PushedAuthorizationResponse>::default(),
|
||||
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
|
||||
)
|
||||
.layer(http_service.clone());
|
||||
|
||||
let par_response = service
|
||||
.ready_oneshot()
|
||||
.await
|
||||
.map_err(PushedAuthorizationError::from)?
|
||||
.call(par_request)
|
||||
.await
|
||||
.map_err(PushedAuthorizationError::from)?
|
||||
.into_body();
|
||||
|
||||
let authorization_query = serde_urlencoded::to_string([
|
||||
("request_uri", par_response.request_uri.as_str()),
|
||||
("client_id", &client_id),
|
||||
])?;
|
||||
|
||||
let mut authorization_url = authorization_endpoint;
|
||||
|
||||
// Add our parameters to the query, because the URL might already have one.
|
||||
let mut full_query = authorization_url
|
||||
.query()
|
||||
.map(ToOwned::to_owned)
|
||||
.unwrap_or_default();
|
||||
if !full_query.is_empty() {
|
||||
full_query.push('&');
|
||||
}
|
||||
full_query.push_str(&authorization_query);
|
||||
|
||||
authorization_url.set_query(Some(&full_query));
|
||||
|
||||
Ok((authorization_url, validation_data))
|
||||
}
|
||||
|
||||
/// Exchange an authorization code for an access token.
|
||||
///
|
||||
/// This should be used as the first step for logging in, and to request a
|
||||
|
||||
@@ -17,7 +17,7 @@ use rand::Rng;
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
error::TokenRequestError, http_service::HttpService, requests::token::request_access_token,
|
||||
error::TokenRequestError, requests::token::request_access_token,
|
||||
types::client_credentials::ClientCredentials,
|
||||
};
|
||||
|
||||
@@ -46,7 +46,7 @@ use crate::{
|
||||
/// Returns an error if the request fails or the response is invalid.
|
||||
#[tracing::instrument(skip_all, fields(token_endpoint))]
|
||||
pub async fn access_token_with_client_credentials(
|
||||
http_service: &HttpService,
|
||||
http_client: &reqwest::Client,
|
||||
client_credentials: ClientCredentials,
|
||||
token_endpoint: &Url,
|
||||
scope: Option<Scope>,
|
||||
@@ -56,7 +56,7 @@ pub async fn access_token_with_client_credentials(
|
||||
tracing::debug!("Requesting access token with client credentials...");
|
||||
|
||||
request_access_token(
|
||||
http_service,
|
||||
http_client,
|
||||
client_credentials,
|
||||
token_endpoint,
|
||||
AccessTokenRequest::ClientCredentials(ClientCredentialsGrant { scope }),
|
||||
|
||||
@@ -8,6 +8,7 @@
|
||||
//!
|
||||
//! [Discovery]: https://openid.net/specs/openid-connect-discovery-1_0.html
|
||||
|
||||
use mas_http::RequestBuilderExt;
|
||||
use oauth2_types::oidc::{ProviderMetadata, VerifiedProviderMetadata};
|
||||
use url::Url;
|
||||
|
||||
@@ -34,7 +35,7 @@ async fn discover_inner(
|
||||
|
||||
let response = client
|
||||
.get(config_url.as_str())
|
||||
.send()
|
||||
.send_traced()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
|
||||
@@ -1,145 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2022-2024 Kévin Commaille.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
//! Requests for [Token Introspection].
|
||||
//!
|
||||
//! [Token Introspection]: https://www.rfc-editor.org/rfc/rfc7662
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use headers::{Authorization, HeaderMapExt};
|
||||
use http::Request;
|
||||
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
|
||||
use mas_iana::oauth::OAuthTokenTypeHint;
|
||||
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse};
|
||||
use rand::Rng;
|
||||
use serde::Serialize;
|
||||
use tower::{Layer, Service, ServiceExt};
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
error::IntrospectionError,
|
||||
http_service::HttpService,
|
||||
types::client_credentials::{ClientCredentials, RequestWithClientCredentials},
|
||||
utils::{http_all_error_status_codes, http_error_mapper},
|
||||
};
|
||||
|
||||
/// The method used to authenticate at the introspection endpoint.
|
||||
pub enum IntrospectionAuthentication<'a> {
|
||||
/// Using client authentication.
|
||||
Credentials(ClientCredentials),
|
||||
|
||||
/// Using a bearer token.
|
||||
BearerToken(&'a str),
|
||||
}
|
||||
|
||||
impl<'a> IntrospectionAuthentication<'a> {
|
||||
/// Constructs an `IntrospectionAuthentication` from the given client
|
||||
/// credentials.
|
||||
#[must_use]
|
||||
pub fn with_client_credentials(credentials: ClientCredentials) -> Self {
|
||||
Self::Credentials(credentials)
|
||||
}
|
||||
|
||||
/// Constructs an `IntrospectionAuthentication` from the given bearer token.
|
||||
#[must_use]
|
||||
pub fn with_bearer_token(token: &'a str) -> Self {
|
||||
Self::BearerToken(token)
|
||||
}
|
||||
|
||||
fn apply_to_request<T: Serialize>(
|
||||
self,
|
||||
request: Request<T>,
|
||||
now: DateTime<Utc>,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<Request<RequestWithClientCredentials<T>>, IntrospectionError> {
|
||||
let res = match self {
|
||||
IntrospectionAuthentication::Credentials(client_credentials) => {
|
||||
client_credentials.apply_to_request(request, now, rng)?
|
||||
}
|
||||
IntrospectionAuthentication::BearerToken(access_token) => {
|
||||
let (mut parts, body) = request.into_parts();
|
||||
|
||||
parts
|
||||
.headers
|
||||
.typed_insert(Authorization::bearer(access_token)?);
|
||||
|
||||
let body = RequestWithClientCredentials {
|
||||
body,
|
||||
credentials: None,
|
||||
};
|
||||
|
||||
http::Request::from_parts(parts, body)
|
||||
}
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> From<ClientCredentials> for IntrospectionAuthentication<'a> {
|
||||
fn from(credentials: ClientCredentials) -> Self {
|
||||
Self::with_client_credentials(credentials)
|
||||
}
|
||||
}
|
||||
|
||||
/// Obtain information about a token.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `http_service` - The service to use for making HTTP requests.
|
||||
///
|
||||
/// * `authentication` - The method used to authenticate the request.
|
||||
///
|
||||
/// * `revocation_endpoint` - The URL of the issuer's Revocation endpoint.
|
||||
///
|
||||
/// * `token` - The token to introspect.
|
||||
///
|
||||
/// * `token_type_hint` - Hint about the type of the token.
|
||||
///
|
||||
/// * `now` - The current time.
|
||||
///
|
||||
/// * `rng` - A random number generator.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the request fails or the response is invalid.
|
||||
#[tracing::instrument(skip_all, fields(introspection_endpoint))]
|
||||
pub async fn introspect_token(
|
||||
http_service: &HttpService,
|
||||
authentication: IntrospectionAuthentication<'_>,
|
||||
introspection_endpoint: &Url,
|
||||
token: String,
|
||||
token_type_hint: Option<OAuthTokenTypeHint>,
|
||||
now: DateTime<Utc>,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<IntrospectionResponse, IntrospectionError> {
|
||||
tracing::debug!("Introspecting token…");
|
||||
|
||||
let introspection_request = IntrospectionRequest {
|
||||
token,
|
||||
token_type_hint,
|
||||
};
|
||||
let introspection_request =
|
||||
http::Request::post(introspection_endpoint.as_str()).body(introspection_request)?;
|
||||
|
||||
let introspection_request = authentication.apply_to_request(introspection_request, now, rng)?;
|
||||
|
||||
let service = (
|
||||
FormUrlencodedRequestLayer::default(),
|
||||
JsonResponseLayer::<IntrospectionResponse>::default(),
|
||||
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
|
||||
)
|
||||
.layer(http_service.clone());
|
||||
|
||||
let introspection_response = service
|
||||
.ready_oneshot()
|
||||
.await?
|
||||
.call(introspection_request)
|
||||
.await?
|
||||
.into_body();
|
||||
|
||||
Ok(introspection_response)
|
||||
}
|
||||
@@ -9,6 +9,7 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_http::RequestBuilderExt;
|
||||
use mas_iana::jose::JsonWebSignatureAlg;
|
||||
use mas_jose::{
|
||||
claims::{self, TimeOptions},
|
||||
@@ -43,7 +44,7 @@ pub async fn fetch_jwks(
|
||||
|
||||
let response: PublicJsonWebKeySet = client
|
||||
.get(jwks_uri.as_str())
|
||||
.send()
|
||||
.send_traced()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
|
||||
@@ -6,15 +6,11 @@
|
||||
|
||||
//! Methods to interact with OpenID Connect and OAuth2.0 endpoints.
|
||||
|
||||
pub mod account_management;
|
||||
pub mod authorization_code;
|
||||
pub mod client_credentials;
|
||||
pub mod discovery;
|
||||
pub mod introspection;
|
||||
pub mod jose;
|
||||
pub mod refresh_token;
|
||||
pub mod registration;
|
||||
pub mod revocation;
|
||||
pub mod rp_initiated_logout;
|
||||
pub mod token;
|
||||
pub mod userinfo;
|
||||
|
||||
@@ -20,7 +20,6 @@ use url::Url;
|
||||
use super::jose::JwtVerificationData;
|
||||
use crate::{
|
||||
error::{IdTokenError, TokenRefreshError},
|
||||
http_service::HttpService,
|
||||
requests::{jose::verify_id_token, token::request_access_token},
|
||||
types::{client_credentials::ClientCredentials, IdToken},
|
||||
};
|
||||
@@ -68,7 +67,7 @@ use crate::{
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
#[tracing::instrument(skip_all, fields(token_endpoint))]
|
||||
pub async fn refresh_access_token(
|
||||
http_service: &HttpService,
|
||||
http_client: &reqwest::Client,
|
||||
client_credentials: ClientCredentials,
|
||||
token_endpoint: &Url,
|
||||
refresh_token: String,
|
||||
@@ -81,7 +80,7 @@ pub async fn refresh_access_token(
|
||||
tracing::debug!("Refreshing access token…");
|
||||
|
||||
let token_response = request_access_token(
|
||||
http_service,
|
||||
http_client,
|
||||
client_credentials,
|
||||
token_endpoint,
|
||||
AccessTokenRequest::RefreshToken(RefreshTokenGrant {
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2022-2024 Kévin Commaille.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
//! Requests for [Dynamic Registration].
|
||||
//!
|
||||
//! [Dynamic Registration]: https://openid.net/specs/openid-connect-registration-1_0.html
|
||||
|
||||
use mas_http::{CatchHttpCodesLayer, JsonRequestLayer, JsonResponseLayer};
|
||||
use mas_iana::oauth::OAuthClientAuthenticationMethod;
|
||||
use oauth2_types::registration::{ClientRegistrationResponse, VerifiedClientMetadata};
|
||||
use serde::Serialize;
|
||||
use serde_with::skip_serializing_none;
|
||||
use tower::{Layer, Service, ServiceExt};
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
error::RegistrationError,
|
||||
http_service::HttpService,
|
||||
utils::{http_all_error_status_codes, http_error_mapper},
|
||||
};
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize)]
|
||||
struct RegistrationRequest {
|
||||
#[serde(flatten)]
|
||||
client_metadata: VerifiedClientMetadata,
|
||||
software_statement: Option<String>,
|
||||
}
|
||||
|
||||
/// Register a client with an OpenID Provider.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `http_service` - The service to use for making HTTP requests.
|
||||
///
|
||||
/// * `registration_endpoint` - The URL of the issuer's Registration endpoint.
|
||||
///
|
||||
/// * `client_metadata` - The metadata to register with the issuer.
|
||||
///
|
||||
/// * `software_statement` - A JWT that asserts metadata values about the client
|
||||
/// software that should be signed.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the request fails or the response is invalid.
|
||||
#[tracing::instrument(skip_all, fields(registration_endpoint))]
|
||||
pub async fn register_client(
|
||||
http_service: &HttpService,
|
||||
registration_endpoint: &Url,
|
||||
client_metadata: VerifiedClientMetadata,
|
||||
software_statement: Option<String>,
|
||||
) -> Result<ClientRegistrationResponse, RegistrationError> {
|
||||
tracing::debug!("Registering client...");
|
||||
|
||||
let should_receive_secret = matches!(
|
||||
client_metadata.token_endpoint_auth_method(),
|
||||
OAuthClientAuthenticationMethod::ClientSecretPost
|
||||
| OAuthClientAuthenticationMethod::ClientSecretBasic
|
||||
| OAuthClientAuthenticationMethod::ClientSecretJwt
|
||||
);
|
||||
|
||||
let body = RegistrationRequest {
|
||||
client_metadata,
|
||||
software_statement,
|
||||
};
|
||||
|
||||
let registration_req = http::Request::post(registration_endpoint.as_str()).body(body)?;
|
||||
|
||||
let service = (
|
||||
JsonRequestLayer::default(),
|
||||
JsonResponseLayer::<ClientRegistrationResponse>::default(),
|
||||
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
|
||||
)
|
||||
.layer(http_service.clone());
|
||||
|
||||
let response = service
|
||||
.ready_oneshot()
|
||||
.await?
|
||||
.call(registration_req)
|
||||
.await?
|
||||
.into_body();
|
||||
|
||||
if should_receive_secret && response.client_secret.is_none() {
|
||||
return Err(RegistrationError::MissingClientSecret);
|
||||
}
|
||||
|
||||
Ok(response)
|
||||
}
|
||||
@@ -1,82 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2022-2024 Kévin Commaille.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
//! Requests for [Token Revocation].
|
||||
//!
|
||||
//! [Token Revocation]: https://www.rfc-editor.org/rfc/rfc7009.html
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer};
|
||||
use mas_iana::oauth::OAuthTokenTypeHint;
|
||||
use oauth2_types::requests::IntrospectionRequest;
|
||||
use rand::Rng;
|
||||
use tower::{Layer, Service, ServiceExt};
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
error::TokenRevokeError,
|
||||
http_service::HttpService,
|
||||
types::client_credentials::ClientCredentials,
|
||||
utils::{http_all_error_status_codes, http_error_mapper},
|
||||
};
|
||||
|
||||
/// Revoke a token.
|
||||
///
|
||||
/// # Arguments
|
||||
///
|
||||
/// * `http_service` - The service to use for making HTTP requests.
|
||||
///
|
||||
/// * `client_credentials` - The credentials obtained when registering the
|
||||
/// client.
|
||||
///
|
||||
/// * `revocation_endpoint` - The URL of the issuer's Revocation endpoint.
|
||||
///
|
||||
/// * `token` - The token to revoke.
|
||||
///
|
||||
/// * `token_type_hint` - Hint about the type of the token.
|
||||
///
|
||||
/// * `now` - The current time.
|
||||
///
|
||||
/// * `rng` - A random number generator.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the request fails or the response is invalid.
|
||||
#[tracing::instrument(skip_all, fields(revocation_endpoint))]
|
||||
pub async fn revoke_token(
|
||||
http_service: &HttpService,
|
||||
client_credentials: ClientCredentials,
|
||||
revocation_endpoint: &Url,
|
||||
token: String,
|
||||
token_type_hint: Option<OAuthTokenTypeHint>,
|
||||
now: DateTime<Utc>,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<(), TokenRevokeError> {
|
||||
tracing::debug!("Revoking token…");
|
||||
|
||||
let request = IntrospectionRequest {
|
||||
token,
|
||||
token_type_hint,
|
||||
};
|
||||
|
||||
let revocation_request = http::Request::post(revocation_endpoint.as_str()).body(request)?;
|
||||
|
||||
let revocation_request = client_credentials.apply_to_request(revocation_request, now, rng)?;
|
||||
|
||||
let service = (
|
||||
FormUrlencodedRequestLayer::default(),
|
||||
CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper),
|
||||
)
|
||||
.layer(http_service.clone());
|
||||
|
||||
service
|
||||
.ready_oneshot()
|
||||
.await?
|
||||
.call(revocation_request)
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -7,18 +7,12 @@
|
||||
//! Requests for the Token endpoint.
|
||||
|
||||
use chrono::{DateTime, Utc};
|
||||
use mas_http::{CatchHttpCodesLayer, FormUrlencodedRequestLayer, JsonResponseLayer};
|
||||
use mas_http::RequestBuilderExt;
|
||||
use oauth2_types::requests::{AccessTokenRequest, AccessTokenResponse};
|
||||
use rand::Rng;
|
||||
use tower::{Layer, Service, ServiceExt};
|
||||
use url::Url;
|
||||
|
||||
use crate::{
|
||||
error::TokenRequestError,
|
||||
http_service::HttpService,
|
||||
types::client_credentials::ClientCredentials,
|
||||
utils::{http_all_error_status_codes, http_error_mapper},
|
||||
};
|
||||
use crate::{error::TokenRequestError, types::client_credentials::ClientCredentials};
|
||||
|
||||
/// Request an access token.
|
||||
///
|
||||
@@ -51,13 +45,15 @@ pub async fn request_access_token(
|
||||
) -> Result<AccessTokenResponse, TokenRequestError> {
|
||||
tracing::debug!(?request, "Requesting access token...");
|
||||
|
||||
let token_request = http_client.post(token_endpoint.as_str()).form(&request);
|
||||
let token_request = http_client.post(token_endpoint.as_str());
|
||||
|
||||
let token_request = client_credentials.apply_to_request(token_request, now, rng)?;
|
||||
|
||||
let res = service.ready_oneshot().await?.call(token_request).await?;
|
||||
|
||||
let token_response = res.into_body();
|
||||
let token_response = client_credentials
|
||||
.authenticated_form(token_request, &request, now, rng)?
|
||||
.send_traced()
|
||||
.await?
|
||||
.error_for_status()?
|
||||
.json()
|
||||
.await?;
|
||||
|
||||
Ok(token_response)
|
||||
}
|
||||
|
||||
@@ -10,23 +10,19 @@
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use bytes::Bytes;
|
||||
use headers::{Authorization, ContentType, HeaderMapExt, HeaderValue};
|
||||
use headers::{ContentType, HeaderMapExt, HeaderValue};
|
||||
use http::header::ACCEPT;
|
||||
use mas_http::CatchHttpCodesLayer;
|
||||
use mas_http::RequestBuilderExt;
|
||||
use mas_jose::claims;
|
||||
use mime::Mime;
|
||||
use serde_json::Value;
|
||||
use tower::{Layer, Service, ServiceExt};
|
||||
use url::Url;
|
||||
|
||||
use super::jose::JwtVerificationData;
|
||||
use crate::{
|
||||
error::{IdTokenError, UserInfoError},
|
||||
http_service::HttpService,
|
||||
requests::jose::verify_signed_jwt,
|
||||
types::IdToken,
|
||||
utils::{http_all_error_status_codes, http_error_mapper},
|
||||
};
|
||||
|
||||
/// Obtain information about an authenticated end-user.
|
||||
@@ -59,7 +55,7 @@ use crate::{
|
||||
/// [`Claim`]: mas_jose::claims::Claim
|
||||
#[tracing::instrument(skip_all, fields(userinfo_endpoint))]
|
||||
pub async fn fetch_userinfo(
|
||||
http_service: &HttpService,
|
||||
http_client: &reqwest::Client,
|
||||
userinfo_endpoint: &Url,
|
||||
access_token: &str,
|
||||
jwt_verification_data: Option<JwtVerificationData<'_>>,
|
||||
@@ -67,29 +63,18 @@ pub async fn fetch_userinfo(
|
||||
) -> Result<HashMap<String, Value>, UserInfoError> {
|
||||
tracing::debug!("Obtaining user info…");
|
||||
|
||||
let mut userinfo_request = http::Request::get(userinfo_endpoint.as_str());
|
||||
|
||||
let expected_content_type = if jwt_verification_data.is_some() {
|
||||
"application/jwt"
|
||||
} else {
|
||||
mime::APPLICATION_JSON.as_ref()
|
||||
};
|
||||
|
||||
if let Some(headers) = userinfo_request.headers_mut() {
|
||||
headers.typed_insert(Authorization::bearer(access_token)?);
|
||||
headers.insert(ACCEPT, HeaderValue::from_static(expected_content_type));
|
||||
}
|
||||
let userinfo_request = http_client
|
||||
.get(userinfo_endpoint.as_str())
|
||||
.bearer_auth(access_token)
|
||||
.header(ACCEPT, HeaderValue::from_static(expected_content_type));
|
||||
|
||||
let userinfo_request = userinfo_request.body(Bytes::new())?;
|
||||
|
||||
let service = CatchHttpCodesLayer::new(http_all_error_status_codes(), http_error_mapper)
|
||||
.layer(http_service.clone());
|
||||
|
||||
let userinfo_response = service
|
||||
.ready_oneshot()
|
||||
.await?
|
||||
.call(userinfo_request)
|
||||
.await?;
|
||||
let userinfo_response = userinfo_request.send_traced().await?.error_for_status()?;
|
||||
|
||||
let content_type: Mime = userinfo_response
|
||||
.headers()
|
||||
@@ -105,15 +90,14 @@ pub async fn fetch_userinfo(
|
||||
});
|
||||
}
|
||||
|
||||
let response_body = std::str::from_utf8(userinfo_response.body())?;
|
||||
|
||||
let mut claims = if let Some(verification_data) = jwt_verification_data {
|
||||
verify_signed_jwt(response_body, verification_data)
|
||||
let response_body = userinfo_response.text().await?;
|
||||
verify_signed_jwt(&response_body, verification_data)
|
||||
.map_err(IdTokenError::from)?
|
||||
.into_parts()
|
||||
.1
|
||||
} else {
|
||||
serde_json::from_str(response_body)?
|
||||
userinfo_response.json().await?
|
||||
};
|
||||
|
||||
let mut auth_claims = auth_id_token.payload().clone();
|
||||
|
||||
@@ -10,8 +10,6 @@ use std::{collections::HashMap, fmt, sync::Arc};
|
||||
|
||||
use base64ct::{Base64UrlUnpadded, Encoding};
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use headers::{Authorization, HeaderMapExt};
|
||||
use http::Request;
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
#[cfg(feature = "keystore")]
|
||||
use mas_jose::constraints::Constrainable;
|
||||
@@ -25,7 +23,6 @@ use mas_keystore::Keystore;
|
||||
use rand::Rng;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use serde_with::skip_serializing_none;
|
||||
use tower::BoxError;
|
||||
use url::Url;
|
||||
|
||||
@@ -175,45 +172,113 @@ impl ClientCredentials {
|
||||
}
|
||||
}
|
||||
|
||||
/// Apply these `ClientCredentials` to the given request.
|
||||
pub(crate) fn apply_to_request<T: Serialize>(
|
||||
self,
|
||||
/// Apply these [`ClientCredentials`] to the given request with the given
|
||||
/// form.
|
||||
pub(crate) fn authenticated_form<T: Serialize>(
|
||||
&self,
|
||||
request: reqwest::RequestBuilder,
|
||||
form: &T,
|
||||
now: DateTime<Utc>,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<reqwest::RequestBuilder, CredentialsError> {
|
||||
// TODO: get the form in params, augment it and serialize
|
||||
let credentials = RequestClientCredentials::try_from_credentials(self, now, rng)?;
|
||||
let request = match self {
|
||||
ClientCredentials::None { client_id } => request.form(&RequestWithClientCredentials {
|
||||
body: form,
|
||||
client_id,
|
||||
client_secret: None,
|
||||
client_assertion: None,
|
||||
client_assertion_type: None,
|
||||
}),
|
||||
|
||||
let (parts, body) = request.into_parts();
|
||||
let mut body = RequestWithClientCredentials {
|
||||
body,
|
||||
credentials: None,
|
||||
};
|
||||
|
||||
let request = match credentials {
|
||||
RequestClientCredentials::Body(credentials) => {
|
||||
body.credentials = Some(credentials);
|
||||
Request::from_parts(parts, body)
|
||||
}
|
||||
RequestClientCredentials::Header(credentials) => {
|
||||
let HeaderClientCredentials {
|
||||
ClientCredentials::ClientSecretBasic {
|
||||
client_id,
|
||||
client_secret,
|
||||
} => request.basic_auth(client_id, Some(client_secret)).form(
|
||||
&RequestWithClientCredentials {
|
||||
body: form,
|
||||
client_id,
|
||||
client_secret,
|
||||
} = credentials;
|
||||
client_secret: None,
|
||||
client_assertion: None,
|
||||
client_assertion_type: None,
|
||||
},
|
||||
),
|
||||
|
||||
let mut request = Request::from_parts(parts, body);
|
||||
ClientCredentials::ClientSecretPost {
|
||||
client_id,
|
||||
client_secret,
|
||||
} => request.form(&RequestWithClientCredentials {
|
||||
body: form,
|
||||
client_id,
|
||||
client_secret: Some(client_secret),
|
||||
client_assertion: None,
|
||||
client_assertion_type: None,
|
||||
}),
|
||||
|
||||
// Encode the values with `application/x-www-form-urlencoded`.
|
||||
let client_id =
|
||||
form_urlencoded::byte_serialize(client_id.as_bytes()).collect::<String>();
|
||||
let client_secret =
|
||||
form_urlencoded::byte_serialize(client_secret.as_bytes()).collect::<String>();
|
||||
ClientCredentials::ClientSecretJwt {
|
||||
client_id,
|
||||
client_secret,
|
||||
signing_algorithm,
|
||||
token_endpoint,
|
||||
} => {
|
||||
let claims =
|
||||
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
|
||||
let key = SymmetricKey::new_for_alg(
|
||||
client_secret.as_bytes().to_vec(),
|
||||
signing_algorithm,
|
||||
)?;
|
||||
let header = JsonWebSignatureHeader::new(signing_algorithm.clone());
|
||||
|
||||
let auth = Authorization::basic(&client_id, &client_secret);
|
||||
request.headers_mut().typed_insert(auth);
|
||||
let jwt = Jwt::sign(header, claims, &key)?;
|
||||
|
||||
request
|
||||
request.form(&RequestWithClientCredentials {
|
||||
body: form,
|
||||
client_id,
|
||||
client_secret: None,
|
||||
client_assertion: Some(jwt.as_str()),
|
||||
client_assertion_type: Some(JwtBearerClientAssertionType),
|
||||
})
|
||||
}
|
||||
|
||||
ClientCredentials::PrivateKeyJwt {
|
||||
client_id,
|
||||
jwt_signing_method,
|
||||
signing_algorithm,
|
||||
token_endpoint,
|
||||
} => {
|
||||
let claims =
|
||||
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
|
||||
|
||||
let client_assertion = match jwt_signing_method {
|
||||
#[cfg(feature = "keystore")]
|
||||
JwtSigningMethod::Keystore(keystore) => {
|
||||
let key = keystore
|
||||
.signing_key_for_algorithm(signing_algorithm)
|
||||
.ok_or(CredentialsError::NoPrivateKeyFound)?;
|
||||
let signer = key
|
||||
.params()
|
||||
.signing_key_for_alg(signing_algorithm)
|
||||
.map_err(|_| CredentialsError::JwtWrongAlgorithm)?;
|
||||
let mut header = JsonWebSignatureHeader::new(signing_algorithm.clone());
|
||||
|
||||
if let Some(kid) = key.kid() {
|
||||
header = header.with_kid(kid);
|
||||
}
|
||||
|
||||
Jwt::sign(header, claims, &signer)?.to_string()
|
||||
}
|
||||
JwtSigningMethod::Custom(jwt_signing_fn) => {
|
||||
jwt_signing_fn(claims, signing_algorithm.clone())
|
||||
.map_err(CredentialsError::Custom)?
|
||||
}
|
||||
};
|
||||
|
||||
request.form(&RequestWithClientCredentials {
|
||||
body: form,
|
||||
client_id,
|
||||
client_secret: None,
|
||||
client_assertion: Some(&client_assertion),
|
||||
client_assertion_type: Some(JwtBearerClientAssertionType),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
@@ -264,123 +329,7 @@ impl fmt::Debug for ClientCredentials {
|
||||
|
||||
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize)]
|
||||
#[serde(rename = "urn:ietf:params:oauth:client-assertion-type:jwt-bearer")]
|
||||
pub(crate) struct JwtBearerClientAssertionType;
|
||||
|
||||
enum RequestClientCredentials {
|
||||
Body(BodyClientCredentials),
|
||||
Header(HeaderClientCredentials),
|
||||
}
|
||||
|
||||
impl RequestClientCredentials {
|
||||
fn try_from_credentials(
|
||||
credentials: ClientCredentials,
|
||||
now: DateTime<Utc>,
|
||||
rng: &mut impl Rng,
|
||||
) -> Result<Self, CredentialsError> {
|
||||
let res = match credentials {
|
||||
ClientCredentials::None { client_id } => Self::Body(BodyClientCredentials {
|
||||
client_id,
|
||||
client_secret: None,
|
||||
client_assertion: None,
|
||||
client_assertion_type: None,
|
||||
}),
|
||||
ClientCredentials::ClientSecretBasic {
|
||||
client_id,
|
||||
client_secret,
|
||||
} => Self::Header(HeaderClientCredentials {
|
||||
client_id,
|
||||
client_secret,
|
||||
}),
|
||||
ClientCredentials::ClientSecretPost {
|
||||
client_id,
|
||||
client_secret,
|
||||
} => Self::Body(BodyClientCredentials {
|
||||
client_id,
|
||||
client_secret: Some(client_secret),
|
||||
client_assertion: None,
|
||||
client_assertion_type: None,
|
||||
}),
|
||||
ClientCredentials::ClientSecretJwt {
|
||||
client_id,
|
||||
client_secret,
|
||||
signing_algorithm,
|
||||
token_endpoint,
|
||||
} => {
|
||||
let claims =
|
||||
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
|
||||
let key = SymmetricKey::new_for_alg(client_secret.into(), &signing_algorithm)?;
|
||||
let header = JsonWebSignatureHeader::new(signing_algorithm);
|
||||
|
||||
let jwt = Jwt::sign(header, claims, &key)?;
|
||||
|
||||
Self::Body(BodyClientCredentials {
|
||||
client_id,
|
||||
client_secret: None,
|
||||
client_assertion: Some(jwt.to_string()),
|
||||
client_assertion_type: Some(JwtBearerClientAssertionType),
|
||||
})
|
||||
}
|
||||
ClientCredentials::PrivateKeyJwt {
|
||||
client_id,
|
||||
jwt_signing_method,
|
||||
signing_algorithm,
|
||||
token_endpoint,
|
||||
} => {
|
||||
let claims =
|
||||
prepare_claims(client_id.clone(), token_endpoint.to_string(), now, rng)?;
|
||||
|
||||
let client_assertion = match jwt_signing_method {
|
||||
#[cfg(feature = "keystore")]
|
||||
JwtSigningMethod::Keystore(keystore) => {
|
||||
let key = keystore
|
||||
.signing_key_for_algorithm(&signing_algorithm)
|
||||
.ok_or(CredentialsError::NoPrivateKeyFound)?;
|
||||
let signer = key
|
||||
.params()
|
||||
.signing_key_for_alg(&signing_algorithm)
|
||||
.map_err(|_| CredentialsError::JwtWrongAlgorithm)?;
|
||||
let mut header = JsonWebSignatureHeader::new(signing_algorithm);
|
||||
|
||||
if let Some(kid) = key.kid() {
|
||||
header = header.with_kid(kid);
|
||||
}
|
||||
|
||||
Jwt::sign(header, claims, &signer)?.to_string()
|
||||
}
|
||||
JwtSigningMethod::Custom(jwt_signing_fn) => {
|
||||
jwt_signing_fn(claims, signing_algorithm)
|
||||
.map_err(CredentialsError::Custom)?
|
||||
}
|
||||
};
|
||||
|
||||
Self::Body(BodyClientCredentials {
|
||||
client_id,
|
||||
client_secret: None,
|
||||
client_assertion: Some(client_assertion),
|
||||
client_assertion_type: Some(JwtBearerClientAssertionType),
|
||||
})
|
||||
}
|
||||
};
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::struct_field_names)] // All the fields start with `client_`
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
|
||||
pub(crate) struct BodyClientCredentials {
|
||||
client_id: String,
|
||||
client_secret: Option<String>,
|
||||
client_assertion: Option<String>,
|
||||
client_assertion_type: Option<JwtBearerClientAssertionType>,
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
struct HeaderClientCredentials {
|
||||
client_id: String,
|
||||
client_secret: String,
|
||||
}
|
||||
struct JwtBearerClientAssertionType;
|
||||
|
||||
fn prepare_claims(
|
||||
iss: String,
|
||||
@@ -409,14 +358,20 @@ fn prepare_claims(
|
||||
|
||||
/// A request with client credentials added to it.
|
||||
#[derive(Clone, Serialize)]
|
||||
#[skip_serializing_none]
|
||||
pub struct RequestWithClientCredentials<T> {
|
||||
struct RequestWithClientCredentials<'a, T> {
|
||||
#[serde(flatten)]
|
||||
pub(crate) body: T,
|
||||
#[serde(flatten)]
|
||||
pub(crate) credentials: Option<BodyClientCredentials>,
|
||||
body: T,
|
||||
|
||||
client_id: &'a str,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
client_secret: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
client_assertion: Option<&'a str>,
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
client_assertion_type: Option<JwtBearerClientAssertionType>,
|
||||
}
|
||||
|
||||
/*
|
||||
#[cfg(test)]
|
||||
mod test {
|
||||
use assert_matches::assert_matches;
|
||||
@@ -442,91 +397,6 @@ mod test {
|
||||
Utc::now()
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_credentials() {
|
||||
assert_eq!(
|
||||
serde_urlencoded::to_string(BodyClientCredentials {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: None,
|
||||
client_assertion: None,
|
||||
client_assertion_type: None,
|
||||
})
|
||||
.unwrap(),
|
||||
"client_id=abcd%24%2B%2B"
|
||||
);
|
||||
assert_eq!(
|
||||
serde_urlencoded::to_string(BodyClientCredentials {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: Some(CLIENT_SECRET.to_owned()),
|
||||
client_assertion: None,
|
||||
client_assertion_type: None,
|
||||
})
|
||||
.unwrap(),
|
||||
"client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F"
|
||||
);
|
||||
assert_eq!(
|
||||
serde_urlencoded::to_string(BodyClientCredentials {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: None,
|
||||
client_assertion: Some(CLIENT_SECRET.to_owned()),
|
||||
client_assertion_type: Some(JwtBearerClientAssertionType)
|
||||
})
|
||||
.unwrap(),
|
||||
"client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn serialize_request_with_credentials() {
|
||||
let req = RequestWithClientCredentials {
|
||||
body: Body { body: REQUEST_BODY },
|
||||
credentials: None,
|
||||
};
|
||||
assert_eq!(serde_urlencoded::to_string(req).unwrap(), "body=some_body");
|
||||
|
||||
let req = RequestWithClientCredentials {
|
||||
body: Body { body: REQUEST_BODY },
|
||||
credentials: Some(BodyClientCredentials {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: None,
|
||||
client_assertion: None,
|
||||
client_assertion_type: None,
|
||||
}),
|
||||
};
|
||||
assert_eq!(
|
||||
serde_urlencoded::to_string(req).unwrap(),
|
||||
"body=some_body&client_id=abcd%24%2B%2B"
|
||||
);
|
||||
|
||||
let req = RequestWithClientCredentials {
|
||||
body: Body { body: REQUEST_BODY },
|
||||
credentials: Some(BodyClientCredentials {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: Some(CLIENT_SECRET.to_owned()),
|
||||
client_assertion: None,
|
||||
client_assertion_type: None,
|
||||
}),
|
||||
};
|
||||
assert_eq!(
|
||||
serde_urlencoded::to_string(req).unwrap(),
|
||||
"body=some_body&client_id=abcd%24%2B%2B&client_secret=xyz%21%3B%3F"
|
||||
);
|
||||
|
||||
let req = RequestWithClientCredentials {
|
||||
body: Body { body: REQUEST_BODY },
|
||||
credentials: Some(BodyClientCredentials {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: None,
|
||||
client_assertion: Some(CLIENT_SECRET.to_owned()),
|
||||
client_assertion_type: Some(JwtBearerClientAssertionType),
|
||||
}),
|
||||
};
|
||||
assert_eq!(
|
||||
serde_urlencoded::to_string(req).unwrap(),
|
||||
"body=some_body&client_id=abcd%24%2B%2B&client_assertion=xyz%21%3B%3F&client_assertion_type=urn%3Aietf%3Aparams%3Aoauth%3Aclient-assertion-type%3Ajwt-bearer"
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn build_request_none() {
|
||||
let credentials = ClientCredentials::None {
|
||||
@@ -677,3 +547,5 @@ mod test {
|
||||
credentials.client_assertion_type.unwrap();
|
||||
}
|
||||
}
|
||||
|
||||
*/
|
||||
|
||||
@@ -1,31 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2022-2024 Kévin Commaille.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::ops::RangeBounds;
|
||||
|
||||
use bytes::Buf;
|
||||
use http::{Response, StatusCode};
|
||||
|
||||
use crate::error::ErrorBody;
|
||||
|
||||
pub fn http_error_mapper<T>(response: Response<T>) -> Option<ErrorBody>
|
||||
where
|
||||
T: Buf,
|
||||
{
|
||||
let body = response.into_body();
|
||||
serde_json::from_reader(body.reader()).ok()
|
||||
}
|
||||
|
||||
pub fn http_all_error_status_codes() -> impl RangeBounds<StatusCode> {
|
||||
let Ok(client_errors_start_code) = StatusCode::from_u16(400) else {
|
||||
unreachable!()
|
||||
};
|
||||
let Ok(server_errors_end_code) = StatusCode::from_u16(599) else {
|
||||
unreachable!()
|
||||
};
|
||||
|
||||
client_errors_start_code..=server_errors_end_code
|
||||
}
|
||||
@@ -7,8 +7,6 @@
|
||||
use std::collections::HashMap;
|
||||
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use http_body_util::Full;
|
||||
use mas_http::{BodyToBytesResponseLayer, BoxCloneSyncService};
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use mas_jose::{
|
||||
claims::{self, hash_token},
|
||||
@@ -17,21 +15,14 @@ use mas_jose::{
|
||||
jwt::{JsonWebSignatureHeader, Jwt},
|
||||
};
|
||||
use mas_keystore::{JsonWebKey, JsonWebKeySet, Keystore, PrivateKey};
|
||||
use mas_oidc_client::{
|
||||
http_service::HttpService,
|
||||
types::{
|
||||
client_credentials::{ClientCredentials, JwtSigningFn, JwtSigningMethod},
|
||||
IdToken,
|
||||
},
|
||||
use mas_oidc_client::types::{
|
||||
client_credentials::{ClientCredentials, JwtSigningFn, JwtSigningMethod},
|
||||
IdToken,
|
||||
};
|
||||
use rand::{
|
||||
distributions::{Alphanumeric, DistString},
|
||||
SeedableRng,
|
||||
};
|
||||
use tower::{
|
||||
util::{MapErrLayer, MapRequestLayer},
|
||||
BoxError, Layer,
|
||||
};
|
||||
use url::Url;
|
||||
use wiremock::MockServer;
|
||||
|
||||
@@ -41,7 +32,6 @@ mod types;
|
||||
const REDIRECT_URI: &str = "http://localhost/";
|
||||
const CLIENT_ID: &str = "client!+ID";
|
||||
const CLIENT_SECRET: &str = "SECRET?%Gclient";
|
||||
const REQUEST_URI: &str = "REQUESTur1";
|
||||
const AUTHORIZATION_CODE: &str = "authC0D3";
|
||||
const CODE_VERIFIER: &str = "cODEv3R1f1ER";
|
||||
const NONCE: &str = "No0o0o0once";
|
||||
|
||||
@@ -1,127 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2024 Kévin Commaille.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use mas_oidc_client::requests::account_management::{
|
||||
build_account_management_url, AccountManagementActionFull,
|
||||
};
|
||||
use url::Url;
|
||||
|
||||
#[test]
|
||||
fn build_url() {
|
||||
let account_management_uri = Url::parse("http://localhost/account_management/").unwrap();
|
||||
|
||||
// No params
|
||||
let url = build_account_management_url(account_management_uri.clone(), None, None).unwrap();
|
||||
|
||||
assert_eq!(url.query(), None);
|
||||
|
||||
// Action without device ID.
|
||||
let url = build_account_management_url(
|
||||
account_management_uri.clone(),
|
||||
Some(AccountManagementActionFull::Profile),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
|
||||
assert_eq!(query_pairs.len(), 1);
|
||||
assert_eq!(query_pairs.get("action").unwrap(), "org.matrix.profile");
|
||||
|
||||
// Action with device ID.
|
||||
let url = build_account_management_url(
|
||||
account_management_uri.clone(),
|
||||
Some(AccountManagementActionFull::SessionEnd {
|
||||
device_id: "mydevice".to_owned(),
|
||||
}),
|
||||
None,
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
|
||||
assert_eq!(query_pairs.len(), 2);
|
||||
assert_eq!(query_pairs.get("action").unwrap(), "org.matrix.session_end");
|
||||
assert_eq!(query_pairs.get("device_id").unwrap(), "mydevice");
|
||||
|
||||
// ID Token hint.
|
||||
let url = build_account_management_url(
|
||||
account_management_uri.clone(),
|
||||
None,
|
||||
Some("anidtokenthat.might.looksomethinglikethis".to_owned()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
|
||||
assert_eq!(query_pairs.len(), 1);
|
||||
assert_eq!(
|
||||
query_pairs.get("id_token_hint").unwrap(),
|
||||
"anidtokenthat.might.looksomethinglikethis"
|
||||
);
|
||||
|
||||
// Action without device ID and ID Token hint.
|
||||
let url = build_account_management_url(
|
||||
account_management_uri.clone(),
|
||||
Some(AccountManagementActionFull::AccountDeactivate),
|
||||
Some("anotheridtokenthat.might.looksomethinglikethis".to_owned()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
|
||||
assert_eq!(query_pairs.len(), 2);
|
||||
assert_eq!(
|
||||
query_pairs.get("action").unwrap(),
|
||||
"org.matrix.account_deactivate"
|
||||
);
|
||||
assert_eq!(
|
||||
query_pairs.get("id_token_hint").unwrap(),
|
||||
"anotheridtokenthat.might.looksomethinglikethis"
|
||||
);
|
||||
|
||||
// Action with device ID and ID Token hint.
|
||||
let url = build_account_management_url(
|
||||
account_management_uri,
|
||||
Some(AccountManagementActionFull::SessionView {
|
||||
device_id: "myseconddevice".to_owned(),
|
||||
}),
|
||||
Some("athirdidtokenthat.might.looksomethinglikethis".to_owned()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
|
||||
assert_eq!(query_pairs.len(), 3);
|
||||
assert_eq!(
|
||||
query_pairs.get("action").unwrap(),
|
||||
"org.matrix.session_view"
|
||||
);
|
||||
assert_eq!(query_pairs.get("device_id").unwrap(), "myseconddevice");
|
||||
assert_eq!(
|
||||
query_pairs.get("id_token_hint").unwrap(),
|
||||
"athirdidtokenthat.might.looksomethinglikethis"
|
||||
);
|
||||
|
||||
// Account management URI with a query already.
|
||||
let account_management_uri_with_query =
|
||||
Url::parse("http://localhost/account_management?param=value").unwrap();
|
||||
|
||||
let url = build_account_management_url(
|
||||
account_management_uri_with_query,
|
||||
Some(AccountManagementActionFull::SessionsList),
|
||||
Some("afinalidtokenthat.might.looksomethinglikethis".to_owned()),
|
||||
)
|
||||
.unwrap();
|
||||
|
||||
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
|
||||
assert_eq!(query_pairs.len(), 3);
|
||||
assert_eq!(
|
||||
query_pairs.get("action").unwrap(),
|
||||
"org.matrix.sessions_list"
|
||||
);
|
||||
assert_eq!(
|
||||
query_pairs.get("id_token_hint").unwrap(),
|
||||
"afinalidtokenthat.might.looksomethinglikethis"
|
||||
);
|
||||
}
|
||||
@@ -4,34 +4,26 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::{
|
||||
collections::HashMap,
|
||||
num::NonZeroU32,
|
||||
sync::{Arc, Mutex},
|
||||
};
|
||||
use std::{collections::HashMap, num::NonZeroU32};
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use chrono::Duration;
|
||||
use mas_iana::oauth::{
|
||||
OAuthAccessTokenType, OAuthClientAuthenticationMethod, PkceCodeChallengeMethod,
|
||||
};
|
||||
use mas_jose::{claims::ClaimError, jwk::PublicJsonWebKeySet};
|
||||
use mas_oidc_client::{
|
||||
error::{
|
||||
AuthorizationError, IdTokenError, PushedAuthorizationError, TokenAuthorizationCodeError,
|
||||
},
|
||||
error::{IdTokenError, TokenAuthorizationCodeError},
|
||||
requests::{
|
||||
authorization_code::{
|
||||
access_token_with_authorization_code, build_authorization_url,
|
||||
build_par_authorization_url, AuthorizationRequestData, AuthorizationValidationData,
|
||||
AuthorizationRequestData, AuthorizationValidationData,
|
||||
},
|
||||
jose::JwtVerificationData,
|
||||
},
|
||||
types::scope::{ScopeExt, ScopeToken},
|
||||
};
|
||||
use oauth2_types::requests::{AccessTokenResponse, Display, Prompt, PushedAuthorizationResponse};
|
||||
use oauth2_types::requests::{AccessTokenResponse, Display, Prompt};
|
||||
use rand::SeedableRng;
|
||||
use tokio::sync::oneshot;
|
||||
use url::Url;
|
||||
use wiremock::{
|
||||
matchers::{method, path},
|
||||
@@ -40,7 +32,7 @@ use wiremock::{
|
||||
|
||||
use crate::{
|
||||
client_credentials, id_token, init_test, now, ACCESS_TOKEN, AUTHORIZATION_CODE, CLIENT_ID,
|
||||
CODE_VERIFIER, ID_TOKEN_SIGNING_ALG, NONCE, REDIRECT_URI, REQUEST_URI,
|
||||
CODE_VERIFIER, ID_TOKEN_SIGNING_ALG, NONCE, REDIRECT_URI,
|
||||
};
|
||||
|
||||
#[test]
|
||||
@@ -137,115 +129,6 @@ fn pass_full_authorization_url() {
|
||||
assert_eq!(query_pairs.get("code_challenge_method"), None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_pushed_authorization_request() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let client_credentials =
|
||||
client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None);
|
||||
let authorization_endpoint = issuer.join("authorize").unwrap();
|
||||
let par_endpoint = issuer.join("par").unwrap();
|
||||
let redirect_uri = Url::parse(REDIRECT_URI).unwrap();
|
||||
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
|
||||
|
||||
let (sender, receiver) = oneshot::channel();
|
||||
let sender_mutex = Arc::new(Mutex::new(Some(sender)));
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/par"))
|
||||
.and(move |req: &Request| {
|
||||
let body = form_urlencoded::parse(&req.body)
|
||||
.into_owned()
|
||||
.collect::<HashMap<_, _>>();
|
||||
if let Some(sender) = sender_mutex.lock().unwrap().take() {
|
||||
sender.send(body).unwrap();
|
||||
true
|
||||
} else {
|
||||
false
|
||||
}
|
||||
})
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(PushedAuthorizationResponse {
|
||||
request_uri: REQUEST_URI.to_owned(),
|
||||
expires_in: Duration::microseconds(30 * 1000 * 1000),
|
||||
}),
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let (url, validation_data) = build_par_authorization_url(
|
||||
&http_service,
|
||||
client_credentials,
|
||||
&par_endpoint,
|
||||
authorization_endpoint,
|
||||
AuthorizationRequestData::new(
|
||||
CLIENT_ID.to_owned(),
|
||||
[ScopeToken::Openid].into_iter().collect(),
|
||||
redirect_uri,
|
||||
)
|
||||
.with_code_challenge_methods_supported(vec![PkceCodeChallengeMethod::S256]),
|
||||
now(),
|
||||
&mut rng,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(validation_data.state, "OrJ8xbWovSpJUTKz");
|
||||
assert_eq!(
|
||||
validation_data.code_challenge_verifier.unwrap(),
|
||||
"TSgZ_hr3TJPjhq4aDp34K_8ksjLwaa1xDcPiRGBcjhM"
|
||||
);
|
||||
|
||||
let request_pairs = receiver.await.unwrap();
|
||||
|
||||
assert_eq!(url.path(), "/authorize");
|
||||
let query_pairs = url.query_pairs().collect::<HashMap<_, _>>();
|
||||
assert_eq!(query_pairs.get("request_uri").unwrap(), REQUEST_URI,);
|
||||
assert_eq!(query_pairs.get("client_id").unwrap(), CLIENT_ID);
|
||||
|
||||
assert_eq!(request_pairs.get("scope").unwrap(), "openid");
|
||||
assert_eq!(request_pairs.get("response_type").unwrap(), "code");
|
||||
assert_eq!(request_pairs.get("client_id").unwrap(), CLIENT_ID);
|
||||
assert_eq!(request_pairs.get("redirect_uri").unwrap(), REDIRECT_URI);
|
||||
assert_eq!(*request_pairs.get("state").unwrap(), validation_data.state);
|
||||
assert_eq!(request_pairs.get("nonce").unwrap(), "ox0PigY5l9xl5uTL");
|
||||
let code_challenge = request_pairs.get("code_challenge").unwrap();
|
||||
assert!(code_challenge.len() >= 43);
|
||||
assert_eq!(request_pairs.get("code_challenge_method").unwrap(), "S256");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fail_pushed_authorization_request_404() {
|
||||
let (http_service, _, issuer) = init_test().await;
|
||||
let client_credentials =
|
||||
client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None);
|
||||
let authorization_endpoint = issuer.join("authorize").unwrap();
|
||||
let par_endpoint = issuer.join("par").unwrap();
|
||||
let redirect_uri = Url::parse(REDIRECT_URI).unwrap();
|
||||
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
|
||||
|
||||
let error = build_par_authorization_url(
|
||||
&http_service,
|
||||
client_credentials,
|
||||
&par_endpoint,
|
||||
authorization_endpoint,
|
||||
AuthorizationRequestData::new(
|
||||
CLIENT_ID.to_owned(),
|
||||
[ScopeToken::Openid].into_iter().collect(),
|
||||
redirect_uri,
|
||||
)
|
||||
.with_code_challenge_methods_supported(vec![PkceCodeChallengeMethod::S256]),
|
||||
now(),
|
||||
&mut rng,
|
||||
)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
assert_matches!(
|
||||
error,
|
||||
AuthorizationError::PushedAuthorization(PushedAuthorizationError::Http(_))
|
||||
);
|
||||
}
|
||||
|
||||
/// Check if the given request to the token endpoint is valid.
|
||||
fn is_valid_token_endpoint_request(req: &Request) -> bool {
|
||||
let body = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
|
||||
|
||||
@@ -36,7 +36,7 @@ fn provider_metadata(issuer: &Url) -> ProviderMetadata {
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_discover() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let (http_client, mock_server, issuer) = init_test().await;
|
||||
|
||||
Mock::given(method("GET"))
|
||||
.and(path("/.well-known/openid-configuration"))
|
||||
@@ -44,7 +44,9 @@ async fn pass_discover() {
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let provider_metadata = insecure_discover(&client, issuer.as_str()).await.unwrap();
|
||||
let provider_metadata = insecure_discover(&http_client, issuer.as_str())
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(provider_metadata.issuer(), issuer.as_str());
|
||||
}
|
||||
@@ -70,7 +72,7 @@ async fn fail_discover_not_json() {
|
||||
|
||||
let error = discover(&http_service, issuer.as_str()).await.unwrap_err();
|
||||
|
||||
assert_matches!(error, DiscoveryError::FromJson(_));
|
||||
assert_matches!(error, DiscoveryError::Http(_));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2022-2024 Kévin Commaille.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
|
||||
use mas_oidc_client::{
|
||||
requests::introspection::introspect_token,
|
||||
types::scope::{ScopeExt, ScopeToken},
|
||||
};
|
||||
use oauth2_types::requests::IntrospectionResponse;
|
||||
use rand::SeedableRng;
|
||||
use wiremock::{
|
||||
matchers::{method, path},
|
||||
Mock, Request, ResponseTemplate,
|
||||
};
|
||||
|
||||
use crate::{client_credentials, init_test, now, ACCESS_TOKEN, CLIENT_ID, SUBJECT_IDENTIFIER};
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_introspect_token() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let client_credentials =
|
||||
client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None);
|
||||
let introspection_endpoint = issuer.join("introspect").unwrap();
|
||||
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/introspect"))
|
||||
.and(|req: &Request| {
|
||||
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
|
||||
|
||||
if query_pairs
|
||||
.get("token")
|
||||
.filter(|s| *s == ACCESS_TOKEN)
|
||||
.is_none()
|
||||
{
|
||||
println!("Wrong or missing token");
|
||||
return false;
|
||||
}
|
||||
if query_pairs
|
||||
.get("token_type_hint")
|
||||
.filter(|s| *s == "access_token")
|
||||
.is_none()
|
||||
{
|
||||
println!("Wrong or missing token type hint");
|
||||
return false;
|
||||
}
|
||||
if query_pairs
|
||||
.get("client_id")
|
||||
.filter(|s| *s == CLIENT_ID)
|
||||
.is_none()
|
||||
{
|
||||
println!("Wrong or missing client ID");
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(IntrospectionResponse {
|
||||
active: true,
|
||||
scope: Some([ScopeToken::Profile].into_iter().collect()),
|
||||
client_id: Some(CLIENT_ID.to_owned()),
|
||||
username: None,
|
||||
token_type: Some(OAuthTokenTypeHint::AccessToken),
|
||||
exp: None,
|
||||
iat: None,
|
||||
nbf: None,
|
||||
sub: Some(SUBJECT_IDENTIFIER.to_owned()),
|
||||
aud: Some(CLIENT_ID.to_owned()),
|
||||
iss: Some(issuer.to_string()),
|
||||
jti: None,
|
||||
}),
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let response = introspect_token(
|
||||
&http_service,
|
||||
client_credentials.into(),
|
||||
&introspection_endpoint,
|
||||
ACCESS_TOKEN.to_owned(),
|
||||
Some(OAuthTokenTypeHint::AccessToken),
|
||||
now(),
|
||||
&mut rng,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert!(response.active);
|
||||
assert_eq!(response.aud.unwrap(), CLIENT_ID);
|
||||
assert!(response.scope.unwrap().contains_token(&ScopeToken::Profile));
|
||||
assert_eq!(response.client_id.unwrap(), CLIENT_ID);
|
||||
assert_eq!(response.iss.unwrap(), issuer.as_str());
|
||||
assert_eq!(response.sub.unwrap(), SUBJECT_IDENTIFIER);
|
||||
}
|
||||
@@ -4,14 +4,10 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
mod account_management;
|
||||
mod authorization_code;
|
||||
mod client_credentials;
|
||||
mod discovery;
|
||||
mod introspection;
|
||||
mod jose;
|
||||
mod refresh_token;
|
||||
mod registration;
|
||||
mod revocation;
|
||||
mod rp_initiated_logout;
|
||||
mod userinfo;
|
||||
|
||||
@@ -1,250 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2022-2024 Kévin Commaille.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use mas_iana::{jose::JsonWebSignatureAlg, oauth::OAuthClientAuthenticationMethod};
|
||||
use mas_jose::jwk::PublicJsonWebKeySet;
|
||||
use mas_oidc_client::{error::RegistrationError, requests::registration::register_client};
|
||||
use oauth2_types::{
|
||||
oidc::ApplicationType,
|
||||
registration::{ClientMetadata, ClientRegistrationResponse, VerifiedClientMetadata},
|
||||
};
|
||||
use serde_json::json;
|
||||
use url::Url;
|
||||
use wiremock::{
|
||||
matchers::{body_partial_json, method, path},
|
||||
Mock, Request, ResponseTemplate,
|
||||
};
|
||||
|
||||
use crate::{init_test, CLIENT_ID, CLIENT_SECRET, REDIRECT_URI};
|
||||
|
||||
/// Generate valid client metadata for the given authentication method.
|
||||
fn client_metadata(auth_method: OAuthClientAuthenticationMethod) -> VerifiedClientMetadata {
|
||||
let (signing_alg, jwks) = match &auth_method {
|
||||
OAuthClientAuthenticationMethod::ClientSecretJwt => {
|
||||
(Some(JsonWebSignatureAlg::Hs256), None)
|
||||
}
|
||||
OAuthClientAuthenticationMethod::PrivateKeyJwt => (
|
||||
Some(JsonWebSignatureAlg::Es256),
|
||||
Some(PublicJsonWebKeySet::default()),
|
||||
),
|
||||
_ => (None, None),
|
||||
};
|
||||
|
||||
ClientMetadata {
|
||||
redirect_uris: Some(vec![Url::parse(REDIRECT_URI).expect("Couldn't parse URL")]),
|
||||
application_type: Some(ApplicationType::Native),
|
||||
token_endpoint_auth_method: Some(auth_method),
|
||||
token_endpoint_auth_signing_alg: signing_alg,
|
||||
jwks,
|
||||
..Default::default()
|
||||
}
|
||||
.validate()
|
||||
.unwrap()
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_register_client_none() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::None);
|
||||
let registration_endpoint = issuer.join("register").unwrap();
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/register"))
|
||||
.and(body_partial_json(json!({
|
||||
"redirect_uris": [REDIRECT_URI],
|
||||
"token_endpoint_auth_method": "none",
|
||||
})))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: None,
|
||||
client_id_issued_at: None,
|
||||
client_secret_expires_at: None,
|
||||
}),
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let response = register_client(&http_service, ®istration_endpoint, client_metadata, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.client_id, CLIENT_ID);
|
||||
assert_eq!(response.client_secret, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_register_client_client_secret_basic() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretBasic);
|
||||
let registration_endpoint = issuer.join("register").unwrap();
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/register"))
|
||||
.and(body_partial_json(json!({
|
||||
"redirect_uris": [REDIRECT_URI],
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
})))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: Some(CLIENT_SECRET.to_owned()),
|
||||
client_id_issued_at: None,
|
||||
client_secret_expires_at: None,
|
||||
}),
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let response = register_client(&http_service, ®istration_endpoint, client_metadata, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.client_id, CLIENT_ID);
|
||||
assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_register_client_client_secret_post() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretPost);
|
||||
let registration_endpoint = issuer.join("register").unwrap();
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/register"))
|
||||
.and(body_partial_json(json!({
|
||||
"redirect_uris": [REDIRECT_URI],
|
||||
"token_endpoint_auth_method": "client_secret_post",
|
||||
})))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: Some(CLIENT_SECRET.to_owned()),
|
||||
client_id_issued_at: None,
|
||||
client_secret_expires_at: None,
|
||||
}),
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let response = register_client(&http_service, ®istration_endpoint, client_metadata, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.client_id, CLIENT_ID);
|
||||
assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_register_client_client_secret_jwt() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretJwt);
|
||||
let registration_endpoint = issuer.join("register").unwrap();
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/register"))
|
||||
.and(body_partial_json(json!({
|
||||
"redirect_uris": [REDIRECT_URI],
|
||||
"token_endpoint_auth_method": "client_secret_jwt",
|
||||
"token_endpoint_auth_signing_alg": "HS256",
|
||||
})))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: Some(CLIENT_SECRET.to_owned()),
|
||||
client_id_issued_at: None,
|
||||
client_secret_expires_at: None,
|
||||
}),
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let response = register_client(&http_service, ®istration_endpoint, client_metadata, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.client_id, CLIENT_ID);
|
||||
assert_eq!(response.client_secret.unwrap(), CLIENT_SECRET);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_register_client_private_key_jwt() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::PrivateKeyJwt);
|
||||
let registration_endpoint = issuer.join("register").unwrap();
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/register"))
|
||||
.and(|req: &Request| {
|
||||
let Ok(metadata) = req.body_json::<ClientMetadata>() else {
|
||||
return false;
|
||||
};
|
||||
|
||||
*metadata.token_endpoint_auth_method() == OAuthClientAuthenticationMethod::PrivateKeyJwt
|
||||
&& metadata.token_endpoint_auth_signing_alg == Some(JsonWebSignatureAlg::Es256)
|
||||
&& metadata.jwks.is_some()
|
||||
})
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: None,
|
||||
client_id_issued_at: None,
|
||||
client_secret_expires_at: None,
|
||||
}),
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let response = register_client(&http_service, ®istration_endpoint, client_metadata, None)
|
||||
.await
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(response.client_id, CLIENT_ID);
|
||||
assert_eq!(response.client_secret, None);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fail_register_client_404() {
|
||||
let (http_service, _, issuer) = init_test().await;
|
||||
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::None);
|
||||
let registration_endpoint = issuer.join("register").unwrap();
|
||||
|
||||
let error = register_client(&http_service, ®istration_endpoint, client_metadata, None)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
assert_matches!(error, RegistrationError::Http(_));
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fail_register_client_missing_secret() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let client_metadata = client_metadata(OAuthClientAuthenticationMethod::ClientSecretBasic);
|
||||
let registration_endpoint = issuer.join("register").unwrap();
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/register"))
|
||||
.and(body_partial_json(json!({
|
||||
"token_endpoint_auth_method": "client_secret_basic",
|
||||
})))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(ClientRegistrationResponse {
|
||||
client_id: CLIENT_ID.to_owned(),
|
||||
client_secret: None,
|
||||
client_id_issued_at: None,
|
||||
client_secret_expires_at: None,
|
||||
}),
|
||||
)
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
let error = register_client(&http_service, ®istration_endpoint, client_metadata, None)
|
||||
.await
|
||||
.unwrap_err();
|
||||
|
||||
assert_matches!(error, RegistrationError::MissingClientSecret);
|
||||
}
|
||||
@@ -1,74 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2022-2024 Kévin Commaille.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::collections::HashMap;
|
||||
|
||||
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
|
||||
use mas_oidc_client::requests::revocation::revoke_token;
|
||||
use rand::SeedableRng;
|
||||
use wiremock::{
|
||||
matchers::{method, path},
|
||||
Mock, Request, ResponseTemplate,
|
||||
};
|
||||
|
||||
use crate::{client_credentials, init_test, ACCESS_TOKEN, CLIENT_ID};
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_revoke_token() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let client_credentials =
|
||||
client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None);
|
||||
let revocation_endpoint = issuer.join("revoke").unwrap();
|
||||
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/revoke"))
|
||||
.and(|req: &Request| {
|
||||
let query_pairs = form_urlencoded::parse(&req.body).collect::<HashMap<_, _>>();
|
||||
|
||||
if query_pairs
|
||||
.get("token")
|
||||
.filter(|s| *s == ACCESS_TOKEN)
|
||||
.is_none()
|
||||
{
|
||||
println!("Wrong or missing refresh token");
|
||||
return false;
|
||||
}
|
||||
if query_pairs
|
||||
.get("token_type_hint")
|
||||
.filter(|s| *s == "access_token")
|
||||
.is_none()
|
||||
{
|
||||
println!("Wrong or missing token type hint");
|
||||
return false;
|
||||
}
|
||||
if query_pairs
|
||||
.get("client_id")
|
||||
.filter(|s| *s == CLIENT_ID)
|
||||
.is_none()
|
||||
{
|
||||
println!("Wrong or missing client ID");
|
||||
return false;
|
||||
}
|
||||
|
||||
true
|
||||
})
|
||||
.respond_with(ResponseTemplate::new(200))
|
||||
.mount(&mock_server)
|
||||
.await;
|
||||
|
||||
revoke_token(
|
||||
&http_service,
|
||||
client_credentials,
|
||||
&revocation_endpoint,
|
||||
ACCESS_TOKEN.to_owned(),
|
||||
Some(OAuthTokenTypeHint::AccessToken),
|
||||
crate::now(),
|
||||
&mut rng,
|
||||
)
|
||||
.await
|
||||
.unwrap();
|
||||
}
|
||||
@@ -8,6 +8,7 @@ use std::collections::HashMap;
|
||||
|
||||
use assert_matches::assert_matches;
|
||||
use base64ct::Encoding;
|
||||
use http::header::AUTHORIZATION;
|
||||
use mas_iana::oauth::{OAuthAccessTokenType, OAuthClientAuthenticationMethod};
|
||||
use mas_jose::{
|
||||
claims::{self, TimeOptions},
|
||||
@@ -31,7 +32,7 @@ use crate::{client_credentials, init_test, now, ACCESS_TOKEN, CLIENT_ID, CLIENT_
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_none() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let (http_client, mock_server, issuer) = init_test().await;
|
||||
let client_credentials =
|
||||
client_credentials(&OAuthClientAuthenticationMethod::None, &issuer, None);
|
||||
let token_endpoint = issuer.join("token").unwrap();
|
||||
@@ -67,7 +68,7 @@ async fn pass_none() {
|
||||
.await;
|
||||
|
||||
access_token_with_client_credentials(
|
||||
&http_service,
|
||||
&http_client,
|
||||
client_credentials,
|
||||
&token_endpoint,
|
||||
None,
|
||||
@@ -80,7 +81,7 @@ async fn pass_none() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_client_secret_basic() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let (http_client, mock_server, issuer) = init_test().await;
|
||||
let client_credentials = client_credentials(
|
||||
&OAuthClientAuthenticationMethod::ClientSecretBasic,
|
||||
&issuer,
|
||||
@@ -94,10 +95,15 @@ async fn pass_client_secret_basic() {
|
||||
let enc_user_pass =
|
||||
base64ct::Base64::encode_string(format!("{username}:{password}").as_bytes());
|
||||
let authorization_header = format!("Basic {enc_user_pass}");
|
||||
eprintln!("{authorization_header}");
|
||||
|
||||
Mock::given(method("POST"))
|
||||
.and(path("/token"))
|
||||
.and(header("authorization", authorization_header.as_str()))
|
||||
.and(|req: &Request| {
|
||||
println!("{req:?}");
|
||||
true
|
||||
})
|
||||
.and(header(AUTHORIZATION, authorization_header.as_str()))
|
||||
.respond_with(
|
||||
ResponseTemplate::new(200).set_body_json(AccessTokenResponse {
|
||||
access_token: ACCESS_TOKEN.to_owned(),
|
||||
@@ -112,7 +118,7 @@ async fn pass_client_secret_basic() {
|
||||
.await;
|
||||
|
||||
access_token_with_client_credentials(
|
||||
&http_service,
|
||||
&http_client,
|
||||
client_credentials,
|
||||
&token_endpoint,
|
||||
None,
|
||||
@@ -125,7 +131,7 @@ async fn pass_client_secret_basic() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_client_secret_post() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let (http_client, mock_server, issuer) = init_test().await;
|
||||
let client_credentials = client_credentials(
|
||||
&OAuthClientAuthenticationMethod::ClientSecretPost,
|
||||
&issuer,
|
||||
@@ -172,7 +178,7 @@ async fn pass_client_secret_post() {
|
||||
.await;
|
||||
|
||||
access_token_with_client_credentials(
|
||||
&http_service,
|
||||
&http_client,
|
||||
client_credentials,
|
||||
&token_endpoint,
|
||||
None,
|
||||
@@ -185,7 +191,7 @@ async fn pass_client_secret_post() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_client_secret_jwt() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let (http_client, mock_server, issuer) = init_test().await;
|
||||
let client_credentials = client_credentials(
|
||||
&OAuthClientAuthenticationMethod::ClientSecretJwt,
|
||||
&issuer,
|
||||
@@ -253,7 +259,7 @@ async fn pass_client_secret_jwt() {
|
||||
.await;
|
||||
|
||||
access_token_with_client_credentials(
|
||||
&http_service,
|
||||
&http_client,
|
||||
client_credentials,
|
||||
&token_endpoint,
|
||||
None,
|
||||
@@ -266,7 +272,7 @@ async fn pass_client_secret_jwt() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_private_key_jwt_with_keystore() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let (http_client, mock_server, issuer) = init_test().await;
|
||||
let client_credentials = client_credentials(
|
||||
&OAuthClientAuthenticationMethod::PrivateKeyJwt,
|
||||
&issuer,
|
||||
@@ -341,7 +347,7 @@ async fn pass_private_key_jwt_with_keystore() {
|
||||
.await;
|
||||
|
||||
access_token_with_client_credentials(
|
||||
&http_service,
|
||||
&http_client,
|
||||
client_credentials,
|
||||
&token_endpoint,
|
||||
None,
|
||||
@@ -354,7 +360,7 @@ async fn pass_private_key_jwt_with_keystore() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn pass_private_key_jwt_with_custom_signing() {
|
||||
let (http_service, mock_server, issuer) = init_test().await;
|
||||
let (http_client, mock_server, issuer) = init_test().await;
|
||||
let client_credentials = client_credentials(
|
||||
&OAuthClientAuthenticationMethod::PrivateKeyJwt,
|
||||
&issuer,
|
||||
@@ -410,7 +416,7 @@ async fn pass_private_key_jwt_with_custom_signing() {
|
||||
.await;
|
||||
|
||||
access_token_with_client_credentials(
|
||||
&http_service,
|
||||
&http_client,
|
||||
client_credentials,
|
||||
&token_endpoint,
|
||||
None,
|
||||
@@ -423,7 +429,7 @@ async fn pass_private_key_jwt_with_custom_signing() {
|
||||
|
||||
#[tokio::test]
|
||||
async fn fail_private_key_jwt_with_custom_signing() {
|
||||
let (http_service, _, issuer) = init_test().await;
|
||||
let (http_client, _, issuer) = init_test().await;
|
||||
let client_credentials = client_credentials(
|
||||
&OAuthClientAuthenticationMethod::PrivateKeyJwt,
|
||||
&issuer,
|
||||
@@ -433,7 +439,7 @@ async fn fail_private_key_jwt_with_custom_signing() {
|
||||
let mut rng = rand_chacha::ChaCha8Rng::seed_from_u64(42);
|
||||
|
||||
let error = access_token_with_client_credentials(
|
||||
&http_service,
|
||||
&http_client,
|
||||
client_credentials,
|
||||
&token_endpoint,
|
||||
None,
|
||||
|
||||
Reference in New Issue
Block a user