Allow banning registrations by IP address

This commit is contained in:
Quentin Gliech
2025-02-17 10:18:11 +01:00
parent fa85d60652
commit b1b7bf5725
17 changed files with 190 additions and 42 deletions

View File

@@ -398,6 +398,7 @@ impl UserEmailMutations {
let state = ctx.state();
let id = NodeType::User.extract_ulid(&input.user_id)?;
let requester = ctx.requester();
let requester_fingerprint = ctx.requester_fingerprint();
let clock = state.clock();
let mut rng = state.rng();
@@ -427,6 +428,7 @@ impl UserEmailMutations {
let res = policy
.evaluate_email(mas_policy::EmailInput {
email: &input.email,
requester: requester_fingerprint.into(),
})
.await?;
if !res.valid() {
@@ -559,6 +561,7 @@ impl UserEmailMutations {
let mut rng = state.rng();
let clock = state.clock();
let requester = ctx.requester();
let requester_fingerprint = ctx.requester_fingerprint();
let limiter = state.limiter();
// Only allow calling this if the requester is a browser session
@@ -617,6 +620,7 @@ impl UserEmailMutations {
let res = policy
.evaluate_email(mas_policy::EmailInput {
email: &input.email,
requester: requester_fingerprint.into(),
})
.await?;
if !res.valid() {

View File

@@ -231,6 +231,9 @@ pub(crate) async fn complete(
client,
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
},
})
.await?;

View File

@@ -116,6 +116,9 @@ pub(crate) async fn get(
client: &client,
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
},
})
.await?;
@@ -195,6 +198,9 @@ pub(crate) async fn post(
client: &client,
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
},
})
.await?;

View File

@@ -87,6 +87,9 @@ pub(crate) async fn get(
client: &client,
scope: &grant.scope,
user: Some(&session.user),
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
},
})
.await?;
if !res.valid() {
@@ -167,6 +170,9 @@ pub(crate) async fn post(
client: &client,
scope: &grant.scope,
user: Some(&session.user),
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
},
})
.await?;
if !res.valid() {

View File

@@ -25,7 +25,7 @@ use thiserror::Error;
use tracing::info;
use url::Url;
use crate::impl_from_error_for_route;
use crate::{impl_from_error_for_route, BoundActivityTracker};
#[derive(Debug, Error)]
pub(crate) enum RouteError {
@@ -195,6 +195,7 @@ pub(crate) async fn post(
clock: BoxClock,
mut repo: BoxRepository,
mut policy: Policy,
activity_tracker: BoundActivityTracker,
State(encrypter): State<Encrypter>,
body: Result<Json<ClientMetadata>, axum::extract::rejection::JsonRejection>,
) -> Result<impl IntoResponse, RouteError> {
@@ -247,6 +248,9 @@ pub(crate) async fn post(
let res = policy
.evaluate_client_registration(mas_policy::ClientRegistrationInput {
client_metadata: &metadata,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
},
})
.await?;
if !res.valid() {

View File

@@ -681,6 +681,9 @@ async fn client_credentials_grant(
client,
scope: &scope,
grant_type: mas_policy::GrantType::ClientCredentials,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
},
})
.await?;
if !res.valid() {

View File

@@ -53,6 +53,12 @@ pub struct RequesterFingerprint {
ip: Option<IpAddr>,
}
impl From<RequesterFingerprint> for mas_policy::Requester {
fn from(val: RequesterFingerprint) -> Self {
mas_policy::Requester { ip_address: val.ip }
}
}
impl std::fmt::Display for RequesterFingerprint {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
if let Some(ip) = self.ip {

View File

@@ -43,7 +43,8 @@ use super::{
UpstreamSessionsCookie,
};
use crate::{
impl_from_error_for_route, views::shared::OptionalPostAuthAction, PreferredLanguage, SiteConfig,
impl_from_error_for_route, views::shared::OptionalPostAuthAction, BoundActivityTracker,
PreferredLanguage, SiteConfig,
};
const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}";
@@ -199,6 +200,7 @@ pub(crate) async fn get(
State(url_builder): State<UrlBuilder>,
State(homeserver): State<BoxHomeserverConnection>,
cookie_jar: CookieJar,
activity_tracker: BoundActivityTracker,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Path(link_id): Path<Ulid>,
) -> Result<impl IntoResponse, RouteError> {
@@ -445,6 +447,9 @@ pub(crate) async fn get(
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
username: &localpart,
email: None,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
},
})
.await?;
@@ -502,6 +507,7 @@ pub(crate) async fn post(
user_agent: Option<TypedHeader<headers::UserAgent>>,
mut policy: Policy,
PreferredLanguage(locale): PreferredLanguage,
activity_tracker: BoundActivityTracker,
State(templates): State<Templates>,
State(homeserver): State<BoxHomeserverConnection>,
State(url_builder): State<UrlBuilder>,
@@ -760,6 +766,9 @@ pub(crate) async fn post(
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
username: &username,
email: email.as_deref(),
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
},
})
.await?;

View File

@@ -237,6 +237,9 @@ pub(crate) async fn post(
registration_method: mas_policy::RegistrationMethod::Password,
username: &form.username,
email: Some(&form.email),
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
},
})
.await?;

View File

@@ -307,6 +307,7 @@ mod tests {
registration_method: RegistrationMethod::Password,
username: "hello",
email: Some("hello@example.com"),
requester: Requester { ip_address: None },
})
.await
.unwrap();
@@ -317,6 +318,7 @@ mod tests {
registration_method: RegistrationMethod::Password,
username: "hello",
email: Some("hello@foo.element.io"),
requester: Requester { ip_address: None },
})
.await
.unwrap();
@@ -327,6 +329,7 @@ mod tests {
registration_method: RegistrationMethod::Password,
username: "hello",
email: Some("hello@staging.element.io"),
requester: Requester { ip_address: None },
})
.await
.unwrap();

View File

@@ -9,6 +9,8 @@
//! This is useful to generate JSON schemas for each input type, which can then
//! be type-checked by Open Policy Agent.
use std::net::IpAddr;
use mas_data_model::{Client, User};
use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope};
use serde::{Deserialize, Serialize};
@@ -92,6 +94,15 @@ impl EvaluationResult {
}
}
/// Identity of the requester
#[derive(Serialize, Debug, Default)]
#[serde(rename_all = "snake_case")]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub struct Requester {
/// IP address of the entity making the request
pub ip_address: Option<IpAddr>,
}
#[derive(Serialize, Debug)]
#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))]
pub enum RegistrationMethod {
@@ -113,6 +124,8 @@ pub struct RegisterInput<'a> {
#[serde(skip_serializing_if = "Option::is_none")]
pub email: Option<&'a str>,
pub requester: Requester,
}
/// Input for the client registration policy.
@@ -125,6 +138,7 @@ pub struct ClientRegistrationInput<'a> {
schemars(with = "std::collections::HashMap<String, serde_json::Value>")
)]
pub client_metadata: &'a VerifiedClientMetadata,
pub requester: Requester,
}
#[derive(Serialize, Debug)]
@@ -158,6 +172,8 @@ pub struct AuthorizationGrantInput<'a> {
pub scope: &'a Scope,
pub grant_type: GrantType,
pub requester: Requester,
}
/// Input for the email add policy.
@@ -167,4 +183,5 @@ pub struct AuthorizationGrantInput<'a> {
pub struct EmailInput<'a> {
pub email: &'a str,
pub requester: Requester,
}

View File

@@ -13,6 +13,26 @@ allow if {
count(violation) == 0
}
# Normalize an IP address or CIDR to a CIDR
normalize_cidr(ip) := ip if contains(ip, "/")
# If it's an IPv4, append /32
normalize_cidr(ip) := sprintf("%s/32", [ip]) if {
not contains(ip, "/")
not contains(ip, ":")
}
# If it's an IPv6, append /128
normalize_cidr(ip) := sprintf("%s/128", [ip]) if {
not contains(ip, "/")
contains(ip, ":")
}
is_ip_banned(ip) if {
some cidr in data.registration.banned_ips
net.cidr_contains(normalize_cidr(cidr), ip)
}
mxid(username, server_name) := sprintf("@%s:%s", [username, server_name])
# METADATA
@@ -48,6 +68,10 @@ violation contains {"msg": "unknown registration method"} if {
not input.registration_method in ["password", "upstream-oauth2"]
}
violation contains {"msg": "IP address is banned"} if {
is_ip_banned(input.requester.ip_address)
}
# Check that we supplied an email for password registration
violation contains {"field": "email", "msg": "email required for password-based registration"} if {
input.registration_method == "password"

View File

@@ -74,3 +74,19 @@ test_invalid_username if {
test_numeric_username if {
not register.allow with input as {"username": "1234", "registration_method": "upstream-oauth2"}
}
test_ip_ban if {
not register.allow with input as {
"username": "hello",
"registration_method": "upstream-oauth2",
"requester": {"ip_address": "1.1.1.1"},
}
with data.registration.banned_ips as ["1.1.1.1"]
not register.allow with input as {
"username": "hello",
"registration_method": "upstream-oauth2",
"requester": {"ip_address": "1.1.1.1"},
}
with data.registration.banned_ips as ["1.0.0.0/8"]
}

View File

@@ -6,6 +6,7 @@
"required": [
"client",
"grant_type",
"requester",
"scope"
],
"properties": {
@@ -22,6 +23,9 @@
},
"grant_type": {
"$ref": "#/definitions/GrantType"
},
"requester": {
"$ref": "#/definitions/Requester"
}
},
"definitions": {
@@ -32,6 +36,17 @@
"client_credentials",
"urn:ietf:params:oauth:grant-type:device_code"
]
},
"Requester": {
"description": "Identity of the requester",
"type": "object",
"properties": {
"ip_address": {
"description": "IP address of the entity making the request",
"type": "string",
"format": "ip"
}
}
}
}
}

View File

@@ -4,12 +4,29 @@
"description": "Input for the client registration policy.",
"type": "object",
"required": [
"client_metadata"
"client_metadata",
"requester"
],
"properties": {
"client_metadata": {
"type": "object",
"additionalProperties": true
},
"requester": {
"$ref": "#/definitions/Requester"
}
},
"definitions": {
"Requester": {
"description": "Identity of the requester",
"type": "object",
"properties": {
"ip_address": {
"description": "IP address of the entity making the request",
"type": "string",
"format": "ip"
}
}
}
}
}

View File

@@ -4,11 +4,28 @@
"description": "Input for the email add policy.",
"type": "object",
"required": [
"email"
"email",
"requester"
],
"properties": {
"email": {
"type": "string"
},
"requester": {
"$ref": "#/definitions/Requester"
}
},
"definitions": {
"Requester": {
"description": "Identity of the requester",
"type": "object",
"properties": {
"ip_address": {
"description": "IP address of the entity making the request",
"type": "string",
"format": "ip"
}
}
}
}
}

View File

@@ -2,49 +2,44 @@
"$schema": "http://json-schema.org/draft-07/schema#",
"title": "RegisterInput",
"description": "Input for the user registration policy.",
"oneOf": [
{
"type": "object",
"required": [
"email",
"registration_method",
"username"
],
"properties": {
"registration_method": {
"type": "string",
"enum": [
"password"
]
},
"username": {
"type": "string"
},
"email": {
"type": "string"
}
}
"type": "object",
"required": [
"registration_method",
"requester",
"username"
],
"properties": {
"registration_method": {
"$ref": "#/definitions/RegistrationMethod"
},
{
"username": {
"type": "string"
},
"email": {
"type": "string"
},
"requester": {
"$ref": "#/definitions/Requester"
}
},
"definitions": {
"RegistrationMethod": {
"type": "string",
"enum": [
"password",
"upstream-oauth2"
]
},
"Requester": {
"description": "Identity of the requester",
"type": "object",
"required": [
"registration_method",
"username"
],
"properties": {
"registration_method": {
"ip_address": {
"description": "IP address of the entity making the request",
"type": "string",
"enum": [
"upstream-oauth2"
]
},
"username": {
"type": "string"
},
"email": {
"type": "string"
"format": "ip"
}
}
}
]
}
}