Allow the homeserver to perform introspection using a shared secret

This commit is contained in:
Quentin Gliech
2025-07-22 11:18:08 +02:00
parent 0f4cd9d540
commit 01d5a2cca2
7 changed files with 205 additions and 45 deletions

View File

@@ -9,13 +9,12 @@ use std::collections::HashMap;
use axum::{
BoxError, Json,
extract::{
Form, FromRequest, FromRequestParts,
Form, FromRequest,
rejection::{FailedToDeserializeForm, FormRejection},
},
response::IntoResponse,
};
use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
use headers::{Authorization, authorization::Basic};
use headers::authorization::{Basic, Bearer, Credentials as _};
use http::{Request, StatusCode};
use mas_data_model::{Client, JwksOrJwksUri};
use mas_http::RequestBuilderExt;
@@ -60,17 +59,30 @@ pub enum Credentials {
client_id: String,
jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
},
BearerToken {
token: String,
},
}
impl Credentials {
/// Get the `client_id` of the credentials
#[must_use]
pub fn client_id(&self) -> &str {
pub fn client_id(&self) -> Option<&str> {
match self {
Credentials::None { client_id }
| Credentials::ClientSecretBasic { client_id, .. }
| Credentials::ClientSecretPost { client_id, .. }
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
| Credentials::ClientAssertionJwtBearer { client_id, .. } => Some(client_id),
Credentials::BearerToken { .. } => None,
}
}
/// Get the bearer token from the credentials.
#[must_use]
pub fn bearer_token(&self) -> Option<&str> {
match self {
Credentials::BearerToken { token } => Some(token),
_ => None,
}
}
@@ -89,6 +101,7 @@ impl Credentials {
| Credentials::ClientSecretBasic { client_id, .. }
| Credentials::ClientSecretPost { client_id, .. }
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
Credentials::BearerToken { .. } => return Ok(None),
};
repo.oauth2_client().find_by_client_id(client_id).await
@@ -239,7 +252,7 @@ pub struct ClientAuthorization<F = ()> {
impl<F> ClientAuthorization<F> {
/// Get the `client_id` from the credentials.
#[must_use]
pub fn client_id(&self) -> &str {
pub fn client_id(&self) -> Option<&str> {
self.credentials.client_id()
}
}
@@ -360,26 +373,37 @@ where
req: Request<axum::body::Body>,
state: &S,
) -> Result<Self, Self::Rejection> {
// Split the request into parts so we can extract some headers
let (mut parts, body) = req.into_parts();
enum Authorization {
Basic(String, String),
Bearer(String),
}
let header =
TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;
// Sadly, the typed-header 'Authorization' doesn't let us check for both
// Basic and Bearer at the same time, so we need to parse them manually
let authorization = if let Some(header) = req.headers().get(http::header::AUTHORIZATION) {
let bytes = header.as_bytes();
if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") {
let Some(decoded) = Basic::decode(header) else {
return Err(ClientAuthorizationError::InvalidHeader);
};
// Take the Authorization header
let credentials_from_header = match header {
Ok(header) => Some((header.username().to_owned(), header.password().to_owned())),
Err(err) => match err.reason() {
// If it's missing it is fine
TypedHeaderRejectionReason::Missing => None,
// If the header could not be parsed, return the error
_ => return Err(ClientAuthorizationError::InvalidHeader),
},
Some(Authorization::Basic(
decoded.username().to_owned(),
decoded.password().to_owned(),
))
} else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") {
let Some(decoded) = Bearer::decode(header) else {
return Err(ClientAuthorizationError::InvalidHeader);
};
Some(Authorization::Bearer(decoded.token().to_owned()))
} else {
return Err(ClientAuthorizationError::InvalidHeader);
}
} else {
None
};
// Reconstruct the request from the parts
let req = Request::from_parts(parts, body);
// Take the form value
let (
client_id_from_form,
@@ -407,13 +431,19 @@ where
// And now, figure out the actual auth method
let credentials = match (
credentials_from_header,
authorization,
client_id_from_form,
client_secret_from_form,
client_assertion_type,
client_assertion,
) {
(Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
(
Some(Authorization::Basic(client_id, client_secret)),
client_id_from_form,
None,
None,
None,
) => {
if let Some(client_id_from_form) = client_id_from_form {
// If the client_id was in the body, verify it matches with the header
if client_id != client_id_from_form {
@@ -483,6 +513,11 @@ where
});
}
(Some(Authorization::Bearer(token)), None, None, None, None) => {
// Got a bearer token
Credentials::BearerToken { token }
}
(None, None, None, None, None) => {
// Special case when there are no credentials anywhere
return Err(ClientAuthorizationError::MissingCredentials);
@@ -677,4 +712,29 @@ mod tests {
jwt.verify_with_shared_secret(b"client-secret".to_vec())
.unwrap();
}
#[tokio::test]
async fn bearer_token_test() {
let req = Request::builder()
.method(Method::POST)
.header(
http::header::CONTENT_TYPE,
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
)
.header(http::header::AUTHORIZATION, "Bearer token")
.body(Body::new("foo=bar".to_owned()))
.unwrap();
assert_eq!(
ClientAuthorization::<serde_json::Value>::from_request(req, &())
.await
.unwrap(),
ClientAuthorization {
credentials: Credentials::BearerToken {
token: "token".to_owned(),
},
form: Some(serde_json::json!({"foo": "bar"})),
}
);
}
}

View File

@@ -4,7 +4,7 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::sync::LazyLock;
use std::sync::{Arc, LazyLock};
use axum::{Json, extract::State, http::HeaderValue, response::IntoResponse};
use hyper::{HeaderMap, StatusCode};
@@ -15,6 +15,7 @@ use mas_axum_utils::{
use mas_data_model::{Device, TokenFormatError, TokenType};
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_keystore::Encrypter;
use mas_matrix::HomeserverConnection;
use mas_storage::{
BoxClock, BoxRepository, Clock,
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
@@ -102,8 +103,14 @@ pub enum RouteError {
#[error("bad request")]
BadRequest,
#[error("failed to verify token")]
FailedToVerifyToken(#[source] anyhow::Error),
#[error(transparent)]
ClientCredentialsVerification(#[from] CredentialsVerificationError),
#[error("bearer token presented is invalid")]
InvalidBearerToken,
}
impl IntoResponse for RouteError {
@@ -114,13 +121,15 @@ impl IntoResponse for RouteError {
| Self::CantLoadCompatSession(_)
| Self::CantLoadOAuthSession(_)
| Self::CantLoadUser(_)
| Self::FailedToVerifyToken(_)
);
let response = match self {
e @ (Self::Internal(_)
| Self::CantLoadCompatSession(_)
| Self::CantLoadOAuthSession(_)
| Self::CantLoadUser(_)) => (
| Self::CantLoadUser(_)
| Self::FailedToVerifyToken(_)) => (
StatusCode::INTERNAL_SERVER_ERROR,
Json(
ClientError::from(ClientErrorCode::ServerError).with_description(e.to_string()),
@@ -140,6 +149,14 @@ impl IntoResponse for RouteError {
),
)
.into_response(),
e @ Self::InvalidBearerToken => (
StatusCode::UNAUTHORIZED,
Json(
ClientError::from(ClientErrorCode::AccessDenied)
.with_description(e.to_string()),
),
)
.into_response(),
Self::UnknownToken(_)
| Self::UnexpectedTokenType
@@ -195,7 +212,7 @@ const SYNAPSE_ADMIN_SCOPE: ScopeToken = ScopeToken::from_static("urn:synapse:adm
#[tracing::instrument(
name = "handlers.oauth2.introspection.post",
fields(client.id = client_authorization.client_id()),
fields(client.id = credentials.client_id()),
skip_all,
)]
#[allow(clippy::too_many_lines)]
@@ -205,28 +222,41 @@ pub(crate) async fn post(
mut repo: BoxRepository,
activity_tracker: ActivityTracker,
State(encrypter): State<Encrypter>,
State(homeserver): State<Arc<dyn HomeserverConnection>>,
headers: HeaderMap,
client_authorization: ClientAuthorization<IntrospectionRequest>,
ClientAuthorization { credentials, form }: ClientAuthorization<IntrospectionRequest>,
) -> Result<impl IntoResponse, RouteError> {
let client = client_authorization
.credentials
.fetch(&mut repo)
.await?
.ok_or(RouteError::ClientNotFound)?;
let method = match &client.token_endpoint_auth_method {
None | Some(OAuthClientAuthenticationMethod::None) => {
return Err(RouteError::NotAllowed(client.id));
if let Some(token) = credentials.bearer_token() {
// If the client presented a bearer token, we check with the homeserver
// connection if it is allowed to use the introspection endpoint
if !homeserver
.verify_token(token)
.await
.map_err(RouteError::FailedToVerifyToken)?
{
return Err(RouteError::InvalidBearerToken);
}
Some(c) => c,
};
} else {
// Otherwise, it presented regular client credentials, so we verify them
let client = credentials
.fetch(&mut repo)
.await?
.ok_or(RouteError::ClientNotFound)?;
client_authorization
.credentials
.verify(&http_client, &encrypter, method, &client)
.await?;
// Only confidential clients are allowed to introspect
let method = match &client.token_endpoint_auth_method {
None | Some(OAuthClientAuthenticationMethod::None) => {
return Err(RouteError::NotAllowed(client.id));
}
Some(c) => c,
};
let Some(form) = client_authorization.form else {
credentials
.verify(&http_client, &encrypter, method, &client)
.await?;
}
let Some(form) = form else {
return Err(RouteError::BadRequest);
};
@@ -578,10 +608,11 @@ mod tests {
use hyper::{Request, StatusCode};
use mas_data_model::{AccessToken, RefreshToken};
use mas_iana::oauth::OAuthTokenTypeHint;
use mas_matrix::{HomeserverConnection, ProvisionRequest};
use mas_matrix::{HomeserverConnection, MockHomeserverConnection, ProvisionRequest};
use mas_router::{OAuth2Introspection, OAuth2RegistrationEndpoint, SimpleRoute};
use mas_storage::Clock;
use oauth2_types::{
errors::{ClientError, ClientErrorCode},
registration::ClientRegistrationResponse,
requests::IntrospectionResponse,
scope::{OPENID, Scope},
@@ -984,4 +1015,29 @@ mod tests {
let response: IntrospectionResponse = response.json();
assert!(response.active);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_introspect_with_bearer_token(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
// Check that talking to the introspection endpoint with the bearer token from
// the MockHomeserverConnection doens't error out
let request = Request::post(OAuth2Introspection::PATH)
.bearer(MockHomeserverConnection::VALID_BEARER_TOKEN)
.form(json!({ "token": "some_token" }));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: IntrospectionResponse = response.json();
assert!(!response.active);
// Check with another token, we should get a 401
let request = Request::post(OAuth2Introspection::PATH)
.bearer("another_token")
.form(json!({ "token": "some_token" }));
let response = state.request(request).await;
response.assert_status(StatusCode::UNAUTHORIZED);
let response: ClientError = response.json();
assert_eq!(response.error, ClientErrorCode::AccessDenied);
}
}

View File

@@ -160,6 +160,11 @@ impl HomeserverConnection for SynapseConnection {
&self.homeserver
}
#[tracing::instrument(name = "homeserver.verify_token", skip_all, err(Debug))]
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
Ok(self.access_token == token)
}
#[tracing::instrument(
name = "homeserver.query_user",
skip_all,

View File

@@ -66,6 +66,11 @@ impl HomeserverConnection for SynapseConnection {
&self.homeserver
}
#[tracing::instrument(name = "homeserver.verify_token", skip_all, err(Debug))]
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
Ok(self.access_token == token)
}
#[tracing::instrument(
name = "homeserver.query_user",
skip_all,

View File

@@ -207,6 +207,20 @@ pub trait HomeserverConnection: Send + Sync {
Some(mxid.localpart())
}
/// Verify a bearer token coming from the homeserver for homeserver to MAS
/// interactions
///
/// Returns `true` if the token is valid, `false` otherwise.
///
/// # Parameters
///
/// * `token` - The token to verify.
///
/// # Errors
///
/// Returns an error if the token failed to verify.
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error>;
/// Query the state of a user on the homeserver.
///
/// # Parameters
@@ -384,6 +398,10 @@ impl<T: HomeserverConnection + Send + Sync + ?Sized> HomeserverConnection for &T
(**self).homeserver()
}
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
(**self).verify_token(token).await
}
async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
(**self).query_user(localpart).await
}
@@ -462,6 +480,10 @@ impl<T: HomeserverConnection + ?Sized> HomeserverConnection for Arc<T> {
(**self).homeserver()
}
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
(**self).verify_token(token).await
}
async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
(**self).query_user(localpart).await
}

View File

@@ -31,6 +31,10 @@ pub struct HomeserverConnection {
}
impl HomeserverConnection {
/// A valid bearer token that will be accepted by
/// [`crate::HomeserverConnection::verify_token`].
pub const VALID_BEARER_TOKEN: &str = "mock_homeserver_bearer_token";
/// Create a new mock connection.
pub fn new<H>(homeserver: H) -> Self
where
@@ -54,6 +58,10 @@ impl crate::HomeserverConnection for HomeserverConnection {
&self.homeserver
}
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
Ok(token == Self::VALID_BEARER_TOKEN)
}
async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
let mxid = self.mxid(localpart);
let users = self.users.read().await;

View File

@@ -28,6 +28,10 @@ impl<C: HomeserverConnection> HomeserverConnection for ReadOnlyHomeserverConnect
self.inner.homeserver()
}
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
self.inner.verify_token(token).await
}
async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
self.inner.query_user(localpart).await
}