Allow the homeserver to perform introspection using a shared secret
This commit is contained in:
@@ -9,13 +9,12 @@ use std::collections::HashMap;
|
||||
use axum::{
|
||||
BoxError, Json,
|
||||
extract::{
|
||||
Form, FromRequest, FromRequestParts,
|
||||
Form, FromRequest,
|
||||
rejection::{FailedToDeserializeForm, FormRejection},
|
||||
},
|
||||
response::IntoResponse,
|
||||
};
|
||||
use axum_extra::typed_header::{TypedHeader, TypedHeaderRejectionReason};
|
||||
use headers::{Authorization, authorization::Basic};
|
||||
use headers::authorization::{Basic, Bearer, Credentials as _};
|
||||
use http::{Request, StatusCode};
|
||||
use mas_data_model::{Client, JwksOrJwksUri};
|
||||
use mas_http::RequestBuilderExt;
|
||||
@@ -60,17 +59,30 @@ pub enum Credentials {
|
||||
client_id: String,
|
||||
jwt: Box<Jwt<'static, HashMap<String, serde_json::Value>>>,
|
||||
},
|
||||
BearerToken {
|
||||
token: String,
|
||||
},
|
||||
}
|
||||
|
||||
impl Credentials {
|
||||
/// Get the `client_id` of the credentials
|
||||
#[must_use]
|
||||
pub fn client_id(&self) -> &str {
|
||||
pub fn client_id(&self) -> Option<&str> {
|
||||
match self {
|
||||
Credentials::None { client_id }
|
||||
| Credentials::ClientSecretBasic { client_id, .. }
|
||||
| Credentials::ClientSecretPost { client_id, .. }
|
||||
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
|
||||
| Credentials::ClientAssertionJwtBearer { client_id, .. } => Some(client_id),
|
||||
Credentials::BearerToken { .. } => None,
|
||||
}
|
||||
}
|
||||
|
||||
/// Get the bearer token from the credentials.
|
||||
#[must_use]
|
||||
pub fn bearer_token(&self) -> Option<&str> {
|
||||
match self {
|
||||
Credentials::BearerToken { token } => Some(token),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -89,6 +101,7 @@ impl Credentials {
|
||||
| Credentials::ClientSecretBasic { client_id, .. }
|
||||
| Credentials::ClientSecretPost { client_id, .. }
|
||||
| Credentials::ClientAssertionJwtBearer { client_id, .. } => client_id,
|
||||
Credentials::BearerToken { .. } => return Ok(None),
|
||||
};
|
||||
|
||||
repo.oauth2_client().find_by_client_id(client_id).await
|
||||
@@ -239,7 +252,7 @@ pub struct ClientAuthorization<F = ()> {
|
||||
impl<F> ClientAuthorization<F> {
|
||||
/// Get the `client_id` from the credentials.
|
||||
#[must_use]
|
||||
pub fn client_id(&self) -> &str {
|
||||
pub fn client_id(&self) -> Option<&str> {
|
||||
self.credentials.client_id()
|
||||
}
|
||||
}
|
||||
@@ -360,26 +373,37 @@ where
|
||||
req: Request<axum::body::Body>,
|
||||
state: &S,
|
||||
) -> Result<Self, Self::Rejection> {
|
||||
// Split the request into parts so we can extract some headers
|
||||
let (mut parts, body) = req.into_parts();
|
||||
enum Authorization {
|
||||
Basic(String, String),
|
||||
Bearer(String),
|
||||
}
|
||||
|
||||
let header =
|
||||
TypedHeader::<Authorization<Basic>>::from_request_parts(&mut parts, state).await;
|
||||
// Sadly, the typed-header 'Authorization' doesn't let us check for both
|
||||
// Basic and Bearer at the same time, so we need to parse them manually
|
||||
let authorization = if let Some(header) = req.headers().get(http::header::AUTHORIZATION) {
|
||||
let bytes = header.as_bytes();
|
||||
if bytes.len() >= 6 && bytes[..6].eq_ignore_ascii_case(b"Basic ") {
|
||||
let Some(decoded) = Basic::decode(header) else {
|
||||
return Err(ClientAuthorizationError::InvalidHeader);
|
||||
};
|
||||
|
||||
// Take the Authorization header
|
||||
let credentials_from_header = match header {
|
||||
Ok(header) => Some((header.username().to_owned(), header.password().to_owned())),
|
||||
Err(err) => match err.reason() {
|
||||
// If it's missing it is fine
|
||||
TypedHeaderRejectionReason::Missing => None,
|
||||
// If the header could not be parsed, return the error
|
||||
_ => return Err(ClientAuthorizationError::InvalidHeader),
|
||||
},
|
||||
Some(Authorization::Basic(
|
||||
decoded.username().to_owned(),
|
||||
decoded.password().to_owned(),
|
||||
))
|
||||
} else if bytes.len() >= 7 && bytes[..7].eq_ignore_ascii_case(b"Bearer ") {
|
||||
let Some(decoded) = Bearer::decode(header) else {
|
||||
return Err(ClientAuthorizationError::InvalidHeader);
|
||||
};
|
||||
|
||||
Some(Authorization::Bearer(decoded.token().to_owned()))
|
||||
} else {
|
||||
return Err(ClientAuthorizationError::InvalidHeader);
|
||||
}
|
||||
} else {
|
||||
None
|
||||
};
|
||||
|
||||
// Reconstruct the request from the parts
|
||||
let req = Request::from_parts(parts, body);
|
||||
|
||||
// Take the form value
|
||||
let (
|
||||
client_id_from_form,
|
||||
@@ -407,13 +431,19 @@ where
|
||||
|
||||
// And now, figure out the actual auth method
|
||||
let credentials = match (
|
||||
credentials_from_header,
|
||||
authorization,
|
||||
client_id_from_form,
|
||||
client_secret_from_form,
|
||||
client_assertion_type,
|
||||
client_assertion,
|
||||
) {
|
||||
(Some((client_id, client_secret)), client_id_from_form, None, None, None) => {
|
||||
(
|
||||
Some(Authorization::Basic(client_id, client_secret)),
|
||||
client_id_from_form,
|
||||
None,
|
||||
None,
|
||||
None,
|
||||
) => {
|
||||
if let Some(client_id_from_form) = client_id_from_form {
|
||||
// If the client_id was in the body, verify it matches with the header
|
||||
if client_id != client_id_from_form {
|
||||
@@ -483,6 +513,11 @@ where
|
||||
});
|
||||
}
|
||||
|
||||
(Some(Authorization::Bearer(token)), None, None, None, None) => {
|
||||
// Got a bearer token
|
||||
Credentials::BearerToken { token }
|
||||
}
|
||||
|
||||
(None, None, None, None, None) => {
|
||||
// Special case when there are no credentials anywhere
|
||||
return Err(ClientAuthorizationError::MissingCredentials);
|
||||
@@ -677,4 +712,29 @@ mod tests {
|
||||
jwt.verify_with_shared_secret(b"client-secret".to_vec())
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn bearer_token_test() {
|
||||
let req = Request::builder()
|
||||
.method(Method::POST)
|
||||
.header(
|
||||
http::header::CONTENT_TYPE,
|
||||
mime::APPLICATION_WWW_FORM_URLENCODED.as_ref(),
|
||||
)
|
||||
.header(http::header::AUTHORIZATION, "Bearer token")
|
||||
.body(Body::new("foo=bar".to_owned()))
|
||||
.unwrap();
|
||||
|
||||
assert_eq!(
|
||||
ClientAuthorization::<serde_json::Value>::from_request(req, &())
|
||||
.await
|
||||
.unwrap(),
|
||||
ClientAuthorization {
|
||||
credentials: Credentials::BearerToken {
|
||||
token: "token".to_owned(),
|
||||
},
|
||||
form: Some(serde_json::json!({"foo": "bar"})),
|
||||
}
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
|
||||
// Please see LICENSE files in the repository root for full details.
|
||||
|
||||
use std::sync::LazyLock;
|
||||
use std::sync::{Arc, LazyLock};
|
||||
|
||||
use axum::{Json, extract::State, http::HeaderValue, response::IntoResponse};
|
||||
use hyper::{HeaderMap, StatusCode};
|
||||
@@ -15,6 +15,7 @@ use mas_axum_utils::{
|
||||
use mas_data_model::{Device, TokenFormatError, TokenType};
|
||||
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
|
||||
use mas_keystore::Encrypter;
|
||||
use mas_matrix::HomeserverConnection;
|
||||
use mas_storage::{
|
||||
BoxClock, BoxRepository, Clock,
|
||||
compat::{CompatAccessTokenRepository, CompatRefreshTokenRepository, CompatSessionRepository},
|
||||
@@ -102,8 +103,14 @@ pub enum RouteError {
|
||||
#[error("bad request")]
|
||||
BadRequest,
|
||||
|
||||
#[error("failed to verify token")]
|
||||
FailedToVerifyToken(#[source] anyhow::Error),
|
||||
|
||||
#[error(transparent)]
|
||||
ClientCredentialsVerification(#[from] CredentialsVerificationError),
|
||||
|
||||
#[error("bearer token presented is invalid")]
|
||||
InvalidBearerToken,
|
||||
}
|
||||
|
||||
impl IntoResponse for RouteError {
|
||||
@@ -114,13 +121,15 @@ impl IntoResponse for RouteError {
|
||||
| Self::CantLoadCompatSession(_)
|
||||
| Self::CantLoadOAuthSession(_)
|
||||
| Self::CantLoadUser(_)
|
||||
| Self::FailedToVerifyToken(_)
|
||||
);
|
||||
|
||||
let response = match self {
|
||||
e @ (Self::Internal(_)
|
||||
| Self::CantLoadCompatSession(_)
|
||||
| Self::CantLoadOAuthSession(_)
|
||||
| Self::CantLoadUser(_)) => (
|
||||
| Self::CantLoadUser(_)
|
||||
| Self::FailedToVerifyToken(_)) => (
|
||||
StatusCode::INTERNAL_SERVER_ERROR,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::ServerError).with_description(e.to_string()),
|
||||
@@ -140,6 +149,14 @@ impl IntoResponse for RouteError {
|
||||
),
|
||||
)
|
||||
.into_response(),
|
||||
e @ Self::InvalidBearerToken => (
|
||||
StatusCode::UNAUTHORIZED,
|
||||
Json(
|
||||
ClientError::from(ClientErrorCode::AccessDenied)
|
||||
.with_description(e.to_string()),
|
||||
),
|
||||
)
|
||||
.into_response(),
|
||||
|
||||
Self::UnknownToken(_)
|
||||
| Self::UnexpectedTokenType
|
||||
@@ -195,7 +212,7 @@ const SYNAPSE_ADMIN_SCOPE: ScopeToken = ScopeToken::from_static("urn:synapse:adm
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "handlers.oauth2.introspection.post",
|
||||
fields(client.id = client_authorization.client_id()),
|
||||
fields(client.id = credentials.client_id()),
|
||||
skip_all,
|
||||
)]
|
||||
#[allow(clippy::too_many_lines)]
|
||||
@@ -205,28 +222,41 @@ pub(crate) async fn post(
|
||||
mut repo: BoxRepository,
|
||||
activity_tracker: ActivityTracker,
|
||||
State(encrypter): State<Encrypter>,
|
||||
State(homeserver): State<Arc<dyn HomeserverConnection>>,
|
||||
headers: HeaderMap,
|
||||
client_authorization: ClientAuthorization<IntrospectionRequest>,
|
||||
ClientAuthorization { credentials, form }: ClientAuthorization<IntrospectionRequest>,
|
||||
) -> Result<impl IntoResponse, RouteError> {
|
||||
let client = client_authorization
|
||||
.credentials
|
||||
.fetch(&mut repo)
|
||||
.await?
|
||||
.ok_or(RouteError::ClientNotFound)?;
|
||||
|
||||
let method = match &client.token_endpoint_auth_method {
|
||||
None | Some(OAuthClientAuthenticationMethod::None) => {
|
||||
return Err(RouteError::NotAllowed(client.id));
|
||||
if let Some(token) = credentials.bearer_token() {
|
||||
// If the client presented a bearer token, we check with the homeserver
|
||||
// connection if it is allowed to use the introspection endpoint
|
||||
if !homeserver
|
||||
.verify_token(token)
|
||||
.await
|
||||
.map_err(RouteError::FailedToVerifyToken)?
|
||||
{
|
||||
return Err(RouteError::InvalidBearerToken);
|
||||
}
|
||||
Some(c) => c,
|
||||
};
|
||||
} else {
|
||||
// Otherwise, it presented regular client credentials, so we verify them
|
||||
let client = credentials
|
||||
.fetch(&mut repo)
|
||||
.await?
|
||||
.ok_or(RouteError::ClientNotFound)?;
|
||||
|
||||
client_authorization
|
||||
.credentials
|
||||
.verify(&http_client, &encrypter, method, &client)
|
||||
.await?;
|
||||
// Only confidential clients are allowed to introspect
|
||||
let method = match &client.token_endpoint_auth_method {
|
||||
None | Some(OAuthClientAuthenticationMethod::None) => {
|
||||
return Err(RouteError::NotAllowed(client.id));
|
||||
}
|
||||
Some(c) => c,
|
||||
};
|
||||
|
||||
let Some(form) = client_authorization.form else {
|
||||
credentials
|
||||
.verify(&http_client, &encrypter, method, &client)
|
||||
.await?;
|
||||
}
|
||||
|
||||
let Some(form) = form else {
|
||||
return Err(RouteError::BadRequest);
|
||||
};
|
||||
|
||||
@@ -578,10 +608,11 @@ mod tests {
|
||||
use hyper::{Request, StatusCode};
|
||||
use mas_data_model::{AccessToken, RefreshToken};
|
||||
use mas_iana::oauth::OAuthTokenTypeHint;
|
||||
use mas_matrix::{HomeserverConnection, ProvisionRequest};
|
||||
use mas_matrix::{HomeserverConnection, MockHomeserverConnection, ProvisionRequest};
|
||||
use mas_router::{OAuth2Introspection, OAuth2RegistrationEndpoint, SimpleRoute};
|
||||
use mas_storage::Clock;
|
||||
use oauth2_types::{
|
||||
errors::{ClientError, ClientErrorCode},
|
||||
registration::ClientRegistrationResponse,
|
||||
requests::IntrospectionResponse,
|
||||
scope::{OPENID, Scope},
|
||||
@@ -984,4 +1015,29 @@ mod tests {
|
||||
let response: IntrospectionResponse = response.json();
|
||||
assert!(response.active);
|
||||
}
|
||||
|
||||
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
|
||||
async fn test_introspect_with_bearer_token(pool: PgPool) {
|
||||
setup();
|
||||
let state = TestState::from_pool(pool).await.unwrap();
|
||||
|
||||
// Check that talking to the introspection endpoint with the bearer token from
|
||||
// the MockHomeserverConnection doens't error out
|
||||
let request = Request::post(OAuth2Introspection::PATH)
|
||||
.bearer(MockHomeserverConnection::VALID_BEARER_TOKEN)
|
||||
.form(json!({ "token": "some_token" }));
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::OK);
|
||||
let response: IntrospectionResponse = response.json();
|
||||
assert!(!response.active);
|
||||
|
||||
// Check with another token, we should get a 401
|
||||
let request = Request::post(OAuth2Introspection::PATH)
|
||||
.bearer("another_token")
|
||||
.form(json!({ "token": "some_token" }));
|
||||
let response = state.request(request).await;
|
||||
response.assert_status(StatusCode::UNAUTHORIZED);
|
||||
let response: ClientError = response.json();
|
||||
assert_eq!(response.error, ClientErrorCode::AccessDenied);
|
||||
}
|
||||
}
|
||||
|
||||
@@ -160,6 +160,11 @@ impl HomeserverConnection for SynapseConnection {
|
||||
&self.homeserver
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "homeserver.verify_token", skip_all, err(Debug))]
|
||||
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
|
||||
Ok(self.access_token == token)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "homeserver.query_user",
|
||||
skip_all,
|
||||
|
||||
@@ -66,6 +66,11 @@ impl HomeserverConnection for SynapseConnection {
|
||||
&self.homeserver
|
||||
}
|
||||
|
||||
#[tracing::instrument(name = "homeserver.verify_token", skip_all, err(Debug))]
|
||||
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
|
||||
Ok(self.access_token == token)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "homeserver.query_user",
|
||||
skip_all,
|
||||
|
||||
@@ -207,6 +207,20 @@ pub trait HomeserverConnection: Send + Sync {
|
||||
Some(mxid.localpart())
|
||||
}
|
||||
|
||||
/// Verify a bearer token coming from the homeserver for homeserver to MAS
|
||||
/// interactions
|
||||
///
|
||||
/// Returns `true` if the token is valid, `false` otherwise.
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `token` - The token to verify.
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the token failed to verify.
|
||||
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error>;
|
||||
|
||||
/// Query the state of a user on the homeserver.
|
||||
///
|
||||
/// # Parameters
|
||||
@@ -384,6 +398,10 @@ impl<T: HomeserverConnection + Send + Sync + ?Sized> HomeserverConnection for &T
|
||||
(**self).homeserver()
|
||||
}
|
||||
|
||||
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
|
||||
(**self).verify_token(token).await
|
||||
}
|
||||
|
||||
async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
|
||||
(**self).query_user(localpart).await
|
||||
}
|
||||
@@ -462,6 +480,10 @@ impl<T: HomeserverConnection + ?Sized> HomeserverConnection for Arc<T> {
|
||||
(**self).homeserver()
|
||||
}
|
||||
|
||||
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
|
||||
(**self).verify_token(token).await
|
||||
}
|
||||
|
||||
async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
|
||||
(**self).query_user(localpart).await
|
||||
}
|
||||
|
||||
@@ -31,6 +31,10 @@ pub struct HomeserverConnection {
|
||||
}
|
||||
|
||||
impl HomeserverConnection {
|
||||
/// A valid bearer token that will be accepted by
|
||||
/// [`crate::HomeserverConnection::verify_token`].
|
||||
pub const VALID_BEARER_TOKEN: &str = "mock_homeserver_bearer_token";
|
||||
|
||||
/// Create a new mock connection.
|
||||
pub fn new<H>(homeserver: H) -> Self
|
||||
where
|
||||
@@ -54,6 +58,10 @@ impl crate::HomeserverConnection for HomeserverConnection {
|
||||
&self.homeserver
|
||||
}
|
||||
|
||||
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
|
||||
Ok(token == Self::VALID_BEARER_TOKEN)
|
||||
}
|
||||
|
||||
async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
|
||||
let mxid = self.mxid(localpart);
|
||||
let users = self.users.read().await;
|
||||
|
||||
@@ -28,6 +28,10 @@ impl<C: HomeserverConnection> HomeserverConnection for ReadOnlyHomeserverConnect
|
||||
self.inner.homeserver()
|
||||
}
|
||||
|
||||
async fn verify_token(&self, token: &str) -> Result<bool, anyhow::Error> {
|
||||
self.inner.verify_token(token).await
|
||||
}
|
||||
|
||||
async fn query_user(&self, localpart: &str) -> Result<MatrixUser, anyhow::Error> {
|
||||
self.inner.query_user(localpart).await
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user