diff --git a/crates/handlers/src/graphql/mutations/user_email.rs b/crates/handlers/src/graphql/mutations/user_email.rs index 4a0c51853..19ff39804 100644 --- a/crates/handlers/src/graphql/mutations/user_email.rs +++ b/crates/handlers/src/graphql/mutations/user_email.rs @@ -398,6 +398,7 @@ impl UserEmailMutations { let state = ctx.state(); let id = NodeType::User.extract_ulid(&input.user_id)?; let requester = ctx.requester(); + let requester_fingerprint = ctx.requester_fingerprint(); let clock = state.clock(); let mut rng = state.rng(); @@ -427,6 +428,7 @@ impl UserEmailMutations { let res = policy .evaluate_email(mas_policy::EmailInput { email: &input.email, + requester: requester_fingerprint.into(), }) .await?; if !res.valid() { @@ -559,6 +561,7 @@ impl UserEmailMutations { let mut rng = state.rng(); let clock = state.clock(); let requester = ctx.requester(); + let requester_fingerprint = ctx.requester_fingerprint(); let limiter = state.limiter(); // Only allow calling this if the requester is a browser session @@ -617,6 +620,7 @@ impl UserEmailMutations { let res = policy .evaluate_email(mas_policy::EmailInput { email: &input.email, + requester: requester_fingerprint.into(), }) .await?; if !res.valid() { diff --git a/crates/handlers/src/oauth2/authorization/complete.rs b/crates/handlers/src/oauth2/authorization/complete.rs index ba2786240..205809c19 100644 --- a/crates/handlers/src/oauth2/authorization/complete.rs +++ b/crates/handlers/src/oauth2/authorization/complete.rs @@ -231,6 +231,9 @@ pub(crate) async fn complete( client, scope: &grant.scope, grant_type: mas_policy::GrantType::AuthorizationCode, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + }, }) .await?; diff --git a/crates/handlers/src/oauth2/consent.rs b/crates/handlers/src/oauth2/consent.rs index b2b99b372..b549eaff5 100644 --- a/crates/handlers/src/oauth2/consent.rs +++ b/crates/handlers/src/oauth2/consent.rs @@ -116,6 +116,9 @@ pub(crate) async fn get( client: &client, scope: &grant.scope, grant_type: mas_policy::GrantType::AuthorizationCode, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + }, }) .await?; @@ -195,6 +198,9 @@ pub(crate) async fn post( client: &client, scope: &grant.scope, grant_type: mas_policy::GrantType::AuthorizationCode, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + }, }) .await?; diff --git a/crates/handlers/src/oauth2/device/consent.rs b/crates/handlers/src/oauth2/device/consent.rs index 9f0045ad9..7674e7a4e 100644 --- a/crates/handlers/src/oauth2/device/consent.rs +++ b/crates/handlers/src/oauth2/device/consent.rs @@ -87,6 +87,9 @@ pub(crate) async fn get( client: &client, scope: &grant.scope, user: Some(&session.user), + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + }, }) .await?; if !res.valid() { @@ -167,6 +170,9 @@ pub(crate) async fn post( client: &client, scope: &grant.scope, user: Some(&session.user), + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + }, }) .await?; if !res.valid() { diff --git a/crates/handlers/src/oauth2/registration.rs b/crates/handlers/src/oauth2/registration.rs index 1dcba9f65..a32c77caf 100644 --- a/crates/handlers/src/oauth2/registration.rs +++ b/crates/handlers/src/oauth2/registration.rs @@ -25,7 +25,7 @@ use thiserror::Error; use tracing::info; use url::Url; -use crate::impl_from_error_for_route; +use crate::{impl_from_error_for_route, BoundActivityTracker}; #[derive(Debug, Error)] pub(crate) enum RouteError { @@ -195,6 +195,7 @@ pub(crate) async fn post( clock: BoxClock, mut repo: BoxRepository, mut policy: Policy, + activity_tracker: BoundActivityTracker, State(encrypter): State, body: Result, axum::extract::rejection::JsonRejection>, ) -> Result { @@ -247,6 +248,9 @@ pub(crate) async fn post( let res = policy .evaluate_client_registration(mas_policy::ClientRegistrationInput { client_metadata: &metadata, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + }, }) .await?; if !res.valid() { diff --git a/crates/handlers/src/oauth2/token.rs b/crates/handlers/src/oauth2/token.rs index 70705293e..e89eb34d7 100644 --- a/crates/handlers/src/oauth2/token.rs +++ b/crates/handlers/src/oauth2/token.rs @@ -681,6 +681,9 @@ async fn client_credentials_grant( client, scope: &scope, grant_type: mas_policy::GrantType::ClientCredentials, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + }, }) .await?; if !res.valid() { diff --git a/crates/handlers/src/rate_limit.rs b/crates/handlers/src/rate_limit.rs index eff30d86f..673cee0a6 100644 --- a/crates/handlers/src/rate_limit.rs +++ b/crates/handlers/src/rate_limit.rs @@ -53,6 +53,12 @@ pub struct RequesterFingerprint { ip: Option, } +impl From for mas_policy::Requester { + fn from(val: RequesterFingerprint) -> Self { + mas_policy::Requester { ip_address: val.ip } + } +} + impl std::fmt::Display for RequesterFingerprint { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { if let Some(ip) = self.ip { diff --git a/crates/handlers/src/upstream_oauth2/link.rs b/crates/handlers/src/upstream_oauth2/link.rs index ec22c2c52..cccb3d961 100644 --- a/crates/handlers/src/upstream_oauth2/link.rs +++ b/crates/handlers/src/upstream_oauth2/link.rs @@ -43,7 +43,8 @@ use super::{ UpstreamSessionsCookie, }; use crate::{ - impl_from_error_for_route, views::shared::OptionalPostAuthAction, PreferredLanguage, SiteConfig, + impl_from_error_for_route, views::shared::OptionalPostAuthAction, BoundActivityTracker, + PreferredLanguage, SiteConfig, }; const DEFAULT_LOCALPART_TEMPLATE: &str = "{{ user.preferred_username }}"; @@ -199,6 +200,7 @@ pub(crate) async fn get( State(url_builder): State, State(homeserver): State, cookie_jar: CookieJar, + activity_tracker: BoundActivityTracker, user_agent: Option>, Path(link_id): Path, ) -> Result { @@ -445,6 +447,9 @@ pub(crate) async fn get( registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2, username: &localpart, email: None, + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + }, }) .await?; @@ -502,6 +507,7 @@ pub(crate) async fn post( user_agent: Option>, mut policy: Policy, PreferredLanguage(locale): PreferredLanguage, + activity_tracker: BoundActivityTracker, State(templates): State, State(homeserver): State, State(url_builder): State, @@ -760,6 +766,9 @@ pub(crate) async fn post( registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2, username: &username, email: email.as_deref(), + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + }, }) .await?; diff --git a/crates/handlers/src/views/register/password.rs b/crates/handlers/src/views/register/password.rs index 0cc59cbe3..8ce2677ed 100644 --- a/crates/handlers/src/views/register/password.rs +++ b/crates/handlers/src/views/register/password.rs @@ -237,6 +237,9 @@ pub(crate) async fn post( registration_method: mas_policy::RegistrationMethod::Password, username: &form.username, email: Some(&form.email), + requester: mas_policy::Requester { + ip_address: activity_tracker.ip(), + }, }) .await?; diff --git a/crates/policy/src/lib.rs b/crates/policy/src/lib.rs index bd338714f..44da50d38 100644 --- a/crates/policy/src/lib.rs +++ b/crates/policy/src/lib.rs @@ -307,6 +307,7 @@ mod tests { registration_method: RegistrationMethod::Password, username: "hello", email: Some("hello@example.com"), + requester: Requester { ip_address: None }, }) .await .unwrap(); @@ -317,6 +318,7 @@ mod tests { registration_method: RegistrationMethod::Password, username: "hello", email: Some("hello@foo.element.io"), + requester: Requester { ip_address: None }, }) .await .unwrap(); @@ -327,6 +329,7 @@ mod tests { registration_method: RegistrationMethod::Password, username: "hello", email: Some("hello@staging.element.io"), + requester: Requester { ip_address: None }, }) .await .unwrap(); diff --git a/crates/policy/src/model.rs b/crates/policy/src/model.rs index 4e9539e2b..d89b16779 100644 --- a/crates/policy/src/model.rs +++ b/crates/policy/src/model.rs @@ -9,6 +9,8 @@ //! This is useful to generate JSON schemas for each input type, which can then //! be type-checked by Open Policy Agent. +use std::net::IpAddr; + use mas_data_model::{Client, User}; use oauth2_types::{registration::VerifiedClientMetadata, scope::Scope}; use serde::{Deserialize, Serialize}; @@ -92,6 +94,15 @@ impl EvaluationResult { } } +/// Identity of the requester +#[derive(Serialize, Debug, Default)] +#[serde(rename_all = "snake_case")] +#[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] +pub struct Requester { + /// IP address of the entity making the request + pub ip_address: Option, +} + #[derive(Serialize, Debug)] #[cfg_attr(feature = "jsonschema", derive(schemars::JsonSchema))] pub enum RegistrationMethod { @@ -113,6 +124,8 @@ pub struct RegisterInput<'a> { #[serde(skip_serializing_if = "Option::is_none")] pub email: Option<&'a str>, + + pub requester: Requester, } /// Input for the client registration policy. @@ -125,6 +138,7 @@ pub struct ClientRegistrationInput<'a> { schemars(with = "std::collections::HashMap") )] pub client_metadata: &'a VerifiedClientMetadata, + pub requester: Requester, } #[derive(Serialize, Debug)] @@ -158,6 +172,8 @@ pub struct AuthorizationGrantInput<'a> { pub scope: &'a Scope, pub grant_type: GrantType, + + pub requester: Requester, } /// Input for the email add policy. @@ -167,4 +183,5 @@ pub struct AuthorizationGrantInput<'a> { pub struct EmailInput<'a> { pub email: &'a str, + pub requester: Requester, } diff --git a/policies/register/register.rego b/policies/register/register.rego index 0fb36bf37..c3836d170 100644 --- a/policies/register/register.rego +++ b/policies/register/register.rego @@ -13,6 +13,26 @@ allow if { count(violation) == 0 } +# Normalize an IP address or CIDR to a CIDR +normalize_cidr(ip) := ip if contains(ip, "/") + +# If it's an IPv4, append /32 +normalize_cidr(ip) := sprintf("%s/32", [ip]) if { + not contains(ip, "/") + not contains(ip, ":") +} + +# If it's an IPv6, append /128 +normalize_cidr(ip) := sprintf("%s/128", [ip]) if { + not contains(ip, "/") + contains(ip, ":") +} + +is_ip_banned(ip) if { + some cidr in data.registration.banned_ips + net.cidr_contains(normalize_cidr(cidr), ip) +} + mxid(username, server_name) := sprintf("@%s:%s", [username, server_name]) # METADATA @@ -48,6 +68,10 @@ violation contains {"msg": "unknown registration method"} if { not input.registration_method in ["password", "upstream-oauth2"] } +violation contains {"msg": "IP address is banned"} if { + is_ip_banned(input.requester.ip_address) +} + # Check that we supplied an email for password registration violation contains {"field": "email", "msg": "email required for password-based registration"} if { input.registration_method == "password" diff --git a/policies/register/register_test.rego b/policies/register/register_test.rego index 26e119248..cf4324fa2 100644 --- a/policies/register/register_test.rego +++ b/policies/register/register_test.rego @@ -74,3 +74,19 @@ test_invalid_username if { test_numeric_username if { not register.allow with input as {"username": "1234", "registration_method": "upstream-oauth2"} } + +test_ip_ban if { + not register.allow with input as { + "username": "hello", + "registration_method": "upstream-oauth2", + "requester": {"ip_address": "1.1.1.1"}, + } + with data.registration.banned_ips as ["1.1.1.1"] + + not register.allow with input as { + "username": "hello", + "registration_method": "upstream-oauth2", + "requester": {"ip_address": "1.1.1.1"}, + } + with data.registration.banned_ips as ["1.0.0.0/8"] +} diff --git a/policies/schema/authorization_grant_input.json b/policies/schema/authorization_grant_input.json index 9b2f77403..5c431b431 100644 --- a/policies/schema/authorization_grant_input.json +++ b/policies/schema/authorization_grant_input.json @@ -6,6 +6,7 @@ "required": [ "client", "grant_type", + "requester", "scope" ], "properties": { @@ -22,6 +23,9 @@ }, "grant_type": { "$ref": "#/definitions/GrantType" + }, + "requester": { + "$ref": "#/definitions/Requester" } }, "definitions": { @@ -32,6 +36,17 @@ "client_credentials", "urn:ietf:params:oauth:grant-type:device_code" ] + }, + "Requester": { + "description": "Identity of the requester", + "type": "object", + "properties": { + "ip_address": { + "description": "IP address of the entity making the request", + "type": "string", + "format": "ip" + } + } } } } \ No newline at end of file diff --git a/policies/schema/client_registration_input.json b/policies/schema/client_registration_input.json index cc9957a85..096c56a0a 100644 --- a/policies/schema/client_registration_input.json +++ b/policies/schema/client_registration_input.json @@ -4,12 +4,29 @@ "description": "Input for the client registration policy.", "type": "object", "required": [ - "client_metadata" + "client_metadata", + "requester" ], "properties": { "client_metadata": { "type": "object", "additionalProperties": true + }, + "requester": { + "$ref": "#/definitions/Requester" + } + }, + "definitions": { + "Requester": { + "description": "Identity of the requester", + "type": "object", + "properties": { + "ip_address": { + "description": "IP address of the entity making the request", + "type": "string", + "format": "ip" + } + } } } } \ No newline at end of file diff --git a/policies/schema/email_input.json b/policies/schema/email_input.json index 19f4af523..384c359b6 100644 --- a/policies/schema/email_input.json +++ b/policies/schema/email_input.json @@ -4,11 +4,28 @@ "description": "Input for the email add policy.", "type": "object", "required": [ - "email" + "email", + "requester" ], "properties": { "email": { "type": "string" + }, + "requester": { + "$ref": "#/definitions/Requester" + } + }, + "definitions": { + "Requester": { + "description": "Identity of the requester", + "type": "object", + "properties": { + "ip_address": { + "description": "IP address of the entity making the request", + "type": "string", + "format": "ip" + } + } } } } \ No newline at end of file diff --git a/policies/schema/register_input.json b/policies/schema/register_input.json index e1d796ea1..27a19b78c 100644 --- a/policies/schema/register_input.json +++ b/policies/schema/register_input.json @@ -2,49 +2,44 @@ "$schema": "http://json-schema.org/draft-07/schema#", "title": "RegisterInput", "description": "Input for the user registration policy.", - "oneOf": [ - { - "type": "object", - "required": [ - "email", - "registration_method", - "username" - ], - "properties": { - "registration_method": { - "type": "string", - "enum": [ - "password" - ] - }, - "username": { - "type": "string" - }, - "email": { - "type": "string" - } - } + "type": "object", + "required": [ + "registration_method", + "requester", + "username" + ], + "properties": { + "registration_method": { + "$ref": "#/definitions/RegistrationMethod" }, - { + "username": { + "type": "string" + }, + "email": { + "type": "string" + }, + "requester": { + "$ref": "#/definitions/Requester" + } + }, + "definitions": { + "RegistrationMethod": { + "type": "string", + "enum": [ + "password", + "upstream-oauth2" + ] + }, + "Requester": { + "description": "Identity of the requester", "type": "object", - "required": [ - "registration_method", - "username" - ], "properties": { - "registration_method": { + "ip_address": { + "description": "IP address of the entity making the request", "type": "string", - "enum": [ - "upstream-oauth2" - ] - }, - "username": { - "type": "string" - }, - "email": { - "type": "string" + "format": "ip" } } } - ] + } } \ No newline at end of file