Merge branch 'main' into keys_dir

This commit is contained in:
Kai A. Hiller
2025-11-18 16:47:16 +01:00
250 changed files with 20306 additions and 8692 deletions

1
.github/CODEOWNERS vendored Normal file
View File

@@ -0,0 +1 @@
* @element-hq/mas-maintainers

View File

@@ -10,9 +10,9 @@ runs:
using: composite using: composite
steps: steps:
- name: Install Node - name: Install Node
uses: actions/setup-node@v4.2.0 uses: actions/setup-node@v6.0.0
with: with:
node-version: "22" node-version: "24"
- name: Install dependencies - name: Install dependencies
run: npm ci run: npm ci

View File

@@ -12,7 +12,8 @@ runs:
- name: Install Open Policy Agent - name: Install Open Policy Agent
uses: open-policy-agent/setup-opa@v2.2.0 uses: open-policy-agent/setup-opa@v2.2.0
with: with:
version: 1.1.0 # Keep in sync with the Dockerfile and policies/Makefile
version: 1.8.0
- name: Build the policies - name: Build the policies
run: make run: make

View File

@@ -84,7 +84,7 @@ jobs:
chmod -R u=rwX,go=rX assets-dist/ chmod -R u=rwX,go=rX assets-dist/
- name: Upload assets - name: Upload assets
uses: actions/upload-artifact@v4.6.2 uses: actions/upload-artifact@v5.0.0
with: with:
name: assets name: assets
path: assets-dist path: assets-dist
@@ -143,7 +143,7 @@ jobs:
-p mas-cli -p mas-cli
- name: Upload binary artifact - name: Upload binary artifact
uses: actions/upload-artifact@v4.6.2 uses: actions/upload-artifact@v5.0.0
with: with:
name: binary-${{ matrix.target }} name: binary-${{ matrix.target }}
path: target/${{ matrix.target }}/release/mas-cli path: target/${{ matrix.target }}/release/mas-cli
@@ -162,19 +162,19 @@ jobs:
steps: steps:
- name: Download assets - name: Download assets
uses: actions/download-artifact@v5 uses: actions/download-artifact@v6
with: with:
name: assets name: assets
path: assets-dist path: assets-dist
- name: Download binary x86_64 - name: Download binary x86_64
uses: actions/download-artifact@v5 uses: actions/download-artifact@v6
with: with:
name: binary-x86_64-unknown-linux-gnu name: binary-x86_64-unknown-linux-gnu
path: binary-x86_64 path: binary-x86_64
- name: Download binary aarch64 - name: Download binary aarch64
uses: actions/download-artifact@v5 uses: actions/download-artifact@v6
with: with:
name: binary-aarch64-unknown-linux-gnu name: binary-aarch64-unknown-linux-gnu
path: binary-aarch64 path: binary-aarch64
@@ -192,13 +192,13 @@ jobs:
done done
- name: Upload aarch64 archive - name: Upload aarch64 archive
uses: actions/upload-artifact@v4.6.2 uses: actions/upload-artifact@v5.0.0
with: with:
name: mas-cli-aarch64-linux name: mas-cli-aarch64-linux
path: mas-cli-aarch64-linux.tar.gz path: mas-cli-aarch64-linux.tar.gz
- name: Upload x86_64 archive - name: Upload x86_64 archive
uses: actions/upload-artifact@v4.6.2 uses: actions/upload-artifact@v5.0.0
with: with:
name: mas-cli-x86_64-linux name: mas-cli-x86_64-linux
path: mas-cli-x86_64-linux.tar.gz path: mas-cli-x86_64-linux.tar.gz
@@ -226,7 +226,7 @@ jobs:
steps: steps:
- name: Docker meta - name: Docker meta
id: meta id: meta
uses: docker/metadata-action@v5.8.0 uses: docker/metadata-action@v5.9.0
with: with:
images: "${{ env.IMAGE }}" images: "${{ env.IMAGE }}"
bake-target: docker-metadata-action bake-target: docker-metadata-action
@@ -242,7 +242,7 @@ jobs:
- name: Docker meta (debug variant) - name: Docker meta (debug variant)
id: meta-debug id: meta-debug
uses: docker/metadata-action@v5.8.0 uses: docker/metadata-action@v5.9.0
with: with:
images: "${{ env.IMAGE }}" images: "${{ env.IMAGE }}"
bake-target: docker-metadata-action-debug bake-target: docker-metadata-action-debug
@@ -258,7 +258,7 @@ jobs:
type=sha type=sha
- name: Setup Cosign - name: Setup Cosign
uses: sigstore/cosign-installer@v3.9.2 uses: sigstore/cosign-installer@v4.0.0
- name: Set up Docker Buildx - name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3.11.1 uses: docker/setup-buildx-action@v3.11.1
@@ -268,7 +268,7 @@ jobs:
mirrors = ["mirror.gcr.io"] mirrors = ["mirror.gcr.io"]
- name: Login to GitHub Container Registry - name: Login to GitHub Container Registry
uses: docker/login-action@v3.5.0 uses: docker/login-action@v3.6.0
with: with:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
@@ -320,14 +320,14 @@ jobs:
- build-image - build-image
steps: steps:
- name: Download the artifacts from the previous job - name: Download the artifacts from the previous job
uses: actions/download-artifact@v5 uses: actions/download-artifact@v6
with: with:
pattern: mas-cli-* pattern: mas-cli-*
path: artifacts path: artifacts
merge-multiple: true merge-multiple: true
- name: Prepare a release - name: Prepare a release
uses: softprops/action-gh-release@v2.3.2 uses: softprops/action-gh-release@v2.4.2
with: with:
generate_release_notes: true generate_release_notes: true
body: | body: |
@@ -382,21 +382,21 @@ jobs:
.github/scripts .github/scripts
- name: Download the artifacts from the previous job - name: Download the artifacts from the previous job
uses: actions/download-artifact@v5 uses: actions/download-artifact@v6
with: with:
pattern: mas-cli-* pattern: mas-cli-*
path: artifacts path: artifacts
merge-multiple: true merge-multiple: true
- name: Update unstable git tag - name: Update unstable git tag
uses: actions/github-script@v7.0.1 uses: actions/github-script@v8.0.0
with: with:
script: | script: |
const script = require('./.github/scripts/update-unstable-tag.cjs'); const script = require('./.github/scripts/update-unstable-tag.cjs');
await script({ core, github, context }); await script({ core, github, context });
- name: Update unstable release - name: Update unstable release
uses: softprops/action-gh-release@v2.3.2 uses: softprops/action-gh-release@v2.4.2
with: with:
name: "Unstable build" name: "Unstable build"
tag_name: unstable tag_name: unstable
@@ -460,7 +460,7 @@ jobs:
.github/scripts .github/scripts
- name: Remove label and comment - name: Remove label and comment
uses: actions/github-script@v7.0.1 uses: actions/github-script@v8.0.0
env: env:
BUILD_IMAGE_MANIFEST: ${{ needs.build-image.outputs.metadata }} BUILD_IMAGE_MANIFEST: ${{ needs.build-image.outputs.metadata }}
with: with:

View File

@@ -41,7 +41,8 @@ jobs:
- name: Setup Regal - name: Setup Regal
uses: StyraInc/setup-regal@v1 uses: StyraInc/setup-regal@v1
with: with:
version: 0.29.2 # Keep in sync with policies/Makefile
version: 0.36.1
- name: Lint policies - name: Lint policies
working-directory: ./policies working-directory: ./policies
@@ -63,9 +64,9 @@ jobs:
uses: actions/checkout@v5 uses: actions/checkout@v5
- name: Install Node - name: Install Node
uses: actions/setup-node@v4.4.0 uses: actions/setup-node@v6.0.0
with: with:
node-version: 22 node-version: 24
- name: Install Node dependencies - name: Install Node dependencies
working-directory: ./frontend working-directory: ./frontend
@@ -87,9 +88,9 @@ jobs:
uses: actions/checkout@v5 uses: actions/checkout@v5
- name: Install Node - name: Install Node
uses: actions/setup-node@v4.4.0 uses: actions/setup-node@v6.0.0
with: with:
node-version: 22 node-version: 24
- name: Install Node dependencies - name: Install Node dependencies
working-directory: ./frontend working-directory: ./frontend
@@ -111,9 +112,9 @@ jobs:
uses: actions/checkout@v5 uses: actions/checkout@v5
- name: Install Node - name: Install Node
uses: actions/setup-node@v4.4.0 uses: actions/setup-node@v6.0.0
with: with:
node-version: 20 node-version: 24
- name: Install Node dependencies - name: Install Node dependencies
working-directory: ./frontend working-directory: ./frontend
@@ -256,7 +257,7 @@ jobs:
SQLX_OFFLINE: "1" SQLX_OFFLINE: "1"
- name: Upload archive to workflow - name: Upload archive to workflow
uses: actions/upload-artifact@v4.6.2 uses: actions/upload-artifact@v5.0.0
with: with:
name: nextest-archive name: nextest-archive
path: nextest-archive.tar.zst path: nextest-archive.tar.zst
@@ -304,7 +305,7 @@ jobs:
- uses: ./.github/actions/build-policies - uses: ./.github/actions/build-policies
- name: Download archive - name: Download archive
uses: actions/download-artifact@v5 uses: actions/download-artifact@v6
with: with:
name: nextest-archive name: nextest-archive

View File

@@ -38,7 +38,7 @@ jobs:
run: make coverage run: make coverage
- name: Upload to codecov.io - name: Upload to codecov.io
uses: codecov/codecov-action@v5.5.0 uses: codecov/codecov-action@v5.5.1
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
files: policies/coverage.json files: policies/coverage.json
@@ -65,7 +65,7 @@ jobs:
run: npm run coverage run: npm run coverage
- name: Upload to codecov.io - name: Upload to codecov.io
uses: codecov/codecov-action@v5.5.0 uses: codecov/codecov-action@v5.5.1
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
directory: frontend/coverage/ directory: frontend/coverage/
@@ -132,7 +132,7 @@ jobs:
grcov . --binary-path ./target/debug/deps/ -s . -t lcov --branch --ignore-not-existing --ignore '../*' --ignore "/*" -o target/coverage/tests.lcov grcov . --binary-path ./target/debug/deps/ -s . -t lcov --branch --ignore-not-existing --ignore '../*' --ignore "/*" -o target/coverage/tests.lcov
- name: Upload to codecov.io - name: Upload to codecov.io
uses: codecov/codecov-action@v5.5.0 uses: codecov/codecov-action@v5.5.1
with: with:
token: ${{ secrets.CODECOV_TOKEN }} token: ${{ secrets.CODECOV_TOKEN }}
files: target/coverage/*.lcov files: target/coverage/*.lcov

View File

@@ -39,9 +39,9 @@ jobs:
tool: mdbook tool: mdbook
- name: Install Node - name: Install Node
uses: actions/setup-node@v4.4.0 uses: actions/setup-node@v6.0.0
with: with:
node-version: 22 node-version: 24
- name: Build the documentation - name: Build the documentation
run: sh misc/build-docs.sh run: sh misc/build-docs.sh

View File

@@ -30,7 +30,7 @@ jobs:
.github/scripts .github/scripts
- name: Push branch and open a PR - name: Push branch and open a PR
uses: actions/github-script@v7.0.1 uses: actions/github-script@v8.0.0
env: env:
SHA: ${{ inputs.sha }} SHA: ${{ inputs.sha }}
with: with:

View File

@@ -64,9 +64,9 @@ jobs:
uses: actions/checkout@v5 uses: actions/checkout@v5
- name: Install Node - name: Install Node
uses: actions/setup-node@v4.4.0 uses: actions/setup-node@v6.0.0
with: with:
node-version: 22 node-version: 24
- name: Install Localazy CLI - name: Install Localazy CLI
run: npm install -g @localazy/cli run: npm install -g @localazy/cli
@@ -112,7 +112,7 @@ jobs:
.github/scripts .github/scripts
- name: Create a new release branch - name: Create a new release branch
uses: actions/github-script@v7.0.1 uses: actions/github-script@v8.0.0
env: env:
BRANCH: release/v${{ needs.compute-version.outputs.short }} BRANCH: release/v${{ needs.compute-version.outputs.short }}
SHA: ${{ needs.tag.outputs.sha }} SHA: ${{ needs.tag.outputs.sha }}

View File

@@ -82,7 +82,7 @@ jobs:
.github/scripts .github/scripts
- name: Update the release branch - name: Update the release branch
uses: actions/github-script@v7.0.1 uses: actions/github-script@v8.0.0
env: env:
BRANCH: "${{ github.ref_name }}" BRANCH: "${{ github.ref_name }}"
SHA: ${{ needs.tag.outputs.sha }} SHA: ${{ needs.tag.outputs.sha }}

View File

@@ -46,7 +46,7 @@ jobs:
run: cargo metadata --format-version 1 run: cargo metadata --format-version 1
- name: Commit and tag using the GitHub API - name: Commit and tag using the GitHub API
uses: actions/github-script@v7.0.1 uses: actions/github-script@v8.0.0
id: commit id: commit
env: env:
VERSION: ${{ inputs.version }} VERSION: ${{ inputs.version }}
@@ -58,7 +58,7 @@ jobs:
return await script({ core, github, context }); return await script({ core, github, context });
- name: Update the refs - name: Update the refs
uses: actions/github-script@v7.0.1 uses: actions/github-script@v8.0.0
env: env:
VERSION: ${{ inputs.version }} VERSION: ${{ inputs.version }}
TAG_SHA: ${{ fromJSON(steps.commit.outputs.result).tag }} TAG_SHA: ${{ fromJSON(steps.commit.outputs.result).tag }}

View File

@@ -22,9 +22,9 @@ jobs:
uses: actions/checkout@v5 uses: actions/checkout@v5
- name: Install Node - name: Install Node
uses: actions/setup-node@v4.4.0 uses: actions/setup-node@v6.0.0
with: with:
node-version: 22 node-version: 24
- name: Install Localazy CLI - name: Install Localazy CLI
run: npm install -g @localazy/cli run: npm install -g @localazy/cli

View File

@@ -21,9 +21,9 @@ jobs:
uses: actions/checkout@v5 uses: actions/checkout@v5
- name: Install Node - name: Install Node
uses: actions/setup-node@v4.4.0 uses: actions/setup-node@v6.0.0
with: with:
node-version: 22 node-version: 24
- name: Install Localazy CLI - name: Install Localazy CLI
run: npm install -g @localazy/cli run: npm install -g @localazy/cli

1315
Cargo.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,7 @@ members = ["crates/*"]
resolver = "2" resolver = "2"
# Updated in the CI with a `sed` command # Updated in the CI with a `sed` command
package.version = "1.2.0-rc.0" package.version = "1.6.0"
package.license = "AGPL-3.0-only OR LicenseRef-Element-Commercial" package.license = "AGPL-3.0-only OR LicenseRef-Element-Commercial"
package.authors = ["Element Backend Team"] package.authors = ["Element Backend Team"]
package.edition = "2024" package.edition = "2024"
@@ -34,40 +34,40 @@ broken_intra_doc_links = "deny"
[workspace.dependencies] [workspace.dependencies]
# Workspace crates # Workspace crates
mas-axum-utils = { path = "./crates/axum-utils/", version = "=1.2.0-rc.0" } mas-axum-utils = { path = "./crates/axum-utils/", version = "=1.6.0" }
mas-cli = { path = "./crates/cli/", version = "=1.2.0-rc.0" } mas-cli = { path = "./crates/cli/", version = "=1.6.0" }
mas-config = { path = "./crates/config/", version = "=1.2.0-rc.0" } mas-config = { path = "./crates/config/", version = "=1.6.0" }
mas-context = { path = "./crates/context/", version = "=1.2.0-rc.0" } mas-context = { path = "./crates/context/", version = "=1.6.0" }
mas-data-model = { path = "./crates/data-model/", version = "=1.2.0-rc.0" } mas-data-model = { path = "./crates/data-model/", version = "=1.6.0" }
mas-email = { path = "./crates/email/", version = "=1.2.0-rc.0" } mas-email = { path = "./crates/email/", version = "=1.6.0" }
mas-graphql = { path = "./crates/graphql/", version = "=1.2.0-rc.0" } mas-graphql = { path = "./crates/graphql/", version = "=1.6.0" }
mas-handlers = { path = "./crates/handlers/", version = "=1.2.0-rc.0" } mas-handlers = { path = "./crates/handlers/", version = "=1.6.0" }
mas-http = { path = "./crates/http/", version = "=1.2.0-rc.0" } mas-http = { path = "./crates/http/", version = "=1.6.0" }
mas-i18n = { path = "./crates/i18n/", version = "=1.2.0-rc.0" } mas-i18n = { path = "./crates/i18n/", version = "=1.6.0" }
mas-i18n-scan = { path = "./crates/i18n-scan/", version = "=1.2.0-rc.0" } mas-i18n-scan = { path = "./crates/i18n-scan/", version = "=1.6.0" }
mas-iana = { path = "./crates/iana/", version = "=1.2.0-rc.0" } mas-iana = { path = "./crates/iana/", version = "=1.6.0" }
mas-iana-codegen = { path = "./crates/iana-codegen/", version = "=1.2.0-rc.0" } mas-iana-codegen = { path = "./crates/iana-codegen/", version = "=1.6.0" }
mas-jose = { path = "./crates/jose/", version = "=1.2.0-rc.0" } mas-jose = { path = "./crates/jose/", version = "=1.6.0" }
mas-keystore = { path = "./crates/keystore/", version = "=1.2.0-rc.0" } mas-keystore = { path = "./crates/keystore/", version = "=1.6.0" }
mas-listener = { path = "./crates/listener/", version = "=1.2.0-rc.0" } mas-listener = { path = "./crates/listener/", version = "=1.6.0" }
mas-matrix = { path = "./crates/matrix/", version = "=1.2.0-rc.0" } mas-matrix = { path = "./crates/matrix/", version = "=1.6.0" }
mas-matrix-synapse = { path = "./crates/matrix-synapse/", version = "=1.2.0-rc.0" } mas-matrix-synapse = { path = "./crates/matrix-synapse/", version = "=1.6.0" }
mas-oidc-client = { path = "./crates/oidc-client/", version = "=1.2.0-rc.0" } mas-oidc-client = { path = "./crates/oidc-client/", version = "=1.6.0" }
mas-policy = { path = "./crates/policy/", version = "=1.2.0-rc.0" } mas-policy = { path = "./crates/policy/", version = "=1.6.0" }
mas-router = { path = "./crates/router/", version = "=1.2.0-rc.0" } mas-router = { path = "./crates/router/", version = "=1.6.0" }
mas-spa = { path = "./crates/spa/", version = "=1.2.0-rc.0" } mas-spa = { path = "./crates/spa/", version = "=1.6.0" }
mas-storage = { path = "./crates/storage/", version = "=1.2.0-rc.0" } mas-storage = { path = "./crates/storage/", version = "=1.6.0" }
mas-storage-pg = { path = "./crates/storage-pg/", version = "=1.2.0-rc.0" } mas-storage-pg = { path = "./crates/storage-pg/", version = "=1.6.0" }
mas-tasks = { path = "./crates/tasks/", version = "=1.2.0-rc.0" } mas-tasks = { path = "./crates/tasks/", version = "=1.6.0" }
mas-templates = { path = "./crates/templates/", version = "=1.2.0-rc.0" } mas-templates = { path = "./crates/templates/", version = "=1.6.0" }
mas-tower = { path = "./crates/tower/", version = "=1.2.0-rc.0" } mas-tower = { path = "./crates/tower/", version = "=1.6.0" }
oauth2-types = { path = "./crates/oauth2-types/", version = "=1.2.0-rc.0" } oauth2-types = { path = "./crates/oauth2-types/", version = "=1.6.0" }
syn2mas = { path = "./crates/syn2mas", version = "=1.2.0-rc.0" } syn2mas = { path = "./crates/syn2mas", version = "=1.6.0" }
# OpenAPI schema generation and validation # OpenAPI schema generation and validation
[workspace.dependencies.aide] [workspace.dependencies.aide]
version = "0.14.2" version = "0.15.1"
features = ["axum", "axum-extra", "axum-json", "axum-query", "macros"] features = ["axum", "axum-extra", "axum-extra-query", "axum-json", "macros"]
# An `Arc` that can be atomically updated # An `Arc` that can be atomically updated
[workspace.dependencies.arc-swap] [workspace.dependencies.arc-swap]
@@ -88,7 +88,7 @@ version = "0.1.89"
# High-level error handling # High-level error handling
[workspace.dependencies.anyhow] [workspace.dependencies.anyhow]
version = "1.0.99" version = "1.0.100"
# Assert that a value matches a pattern # Assert that a value matches a pattern
[workspace.dependencies.assert_matches] [workspace.dependencies.assert_matches]
@@ -96,12 +96,12 @@ version = "1.5.0"
# HTTP router # HTTP router
[workspace.dependencies.axum] [workspace.dependencies.axum]
version = "0.8.4" version = "0.8.6"
# Extra utilities for Axum # Extra utilities for Axum
[workspace.dependencies.axum-extra] [workspace.dependencies.axum-extra]
version = "0.10.1" version = "0.10.3"
features = ["cookie-private", "cookie-key-expansion", "typed-header"] features = ["cookie-private", "cookie-key-expansion", "typed-header", "query"]
# Axum macros # Axum macros
[workspace.dependencies.axum-macros] [workspace.dependencies.axum-macros]
@@ -129,7 +129,7 @@ default-features = true
# Packed bitfields # Packed bitfields
[workspace.dependencies.bitflags] [workspace.dependencies.bitflags]
version = "2.9.3" version = "2.9.4"
# Bytes # Bytes
[workspace.dependencies.bytes] [workspace.dependencies.bytes]
@@ -137,7 +137,7 @@ version = "1.10.1"
# UTF-8 paths # UTF-8 paths
[workspace.dependencies.camino] [workspace.dependencies.camino]
version = "1.1.11" version = "1.2.1"
features = ["serde1"] features = ["serde1"]
# ChaCha20Poly1305 AEAD # ChaCha20Poly1305 AEAD
@@ -161,13 +161,13 @@ features = ["serde_json"]
# Time utilities # Time utilities
[workspace.dependencies.chrono] [workspace.dependencies.chrono]
version = "0.4.41" version = "0.4.42"
default-features = false default-features = false
features = ["serde", "clock"] features = ["serde", "clock"]
# CLI argument parsing # CLI argument parsing
[workspace.dependencies.clap] [workspace.dependencies.clap]
version = "4.5.46" version = "4.5.50"
features = ["derive"] features = ["derive"]
# Object Identifiers (OIDs) as constants # Object Identifiers (OIDs) as constants
@@ -189,7 +189,7 @@ version = "0.15.0"
# CSV parsing and writing # CSV parsing and writing
[workspace.dependencies.csv] [workspace.dependencies.csv]
version = "1.3.1" version = "1.4.0"
# DER encoding # DER encoding
[workspace.dependencies.der] [workspace.dependencies.der]
@@ -274,7 +274,7 @@ features = ["client", "server", "http1", "http2"]
# Additional Hyper utilties # Additional Hyper utilties
[workspace.dependencies.hyper-util] [workspace.dependencies.hyper-util]
version = "0.1.16" version = "0.1.17"
features = [ features = [
"client", "client",
"server", "server",
@@ -321,7 +321,7 @@ features = ["std"]
# HashMap which preserves insertion order # HashMap which preserves insertion order
[workspace.dependencies.indexmap] [workspace.dependencies.indexmap]
version = "2.11.0" version = "2.11.4"
features = ["serde"] features = ["serde"]
# Indented string literals # Indented string literals
@@ -330,13 +330,13 @@ version = "2.0.6"
# Snapshot testing # Snapshot testing
[workspace.dependencies.insta] [workspace.dependencies.insta]
version = "1.43.1" version = "1.43.2"
features = ["yaml", "json"] features = ["yaml", "json"]
# IP network address types # IP network address types
[workspace.dependencies.ipnetwork] [workspace.dependencies.ipnetwork]
version = "0.20.0" version = "0.20.0"
features = ["serde", "schemars"] features = ["serde"]
# Iterator utilities # Iterator utilities
[workspace.dependencies.itertools] [workspace.dependencies.itertools]
@@ -354,7 +354,7 @@ features = ["serde"]
# Email sending # Email sending
[workspace.dependencies.lettre] [workspace.dependencies.lettre]
version = "0.11.18" version = "0.11.19"
default-features = false default-features = false
features = [ features = [
"tokio1-rustls", "tokio1-rustls",
@@ -392,42 +392,40 @@ version = "0.3.0"
# Open Policy Agent support through WASM # Open Policy Agent support through WASM
[workspace.dependencies.opa-wasm] [workspace.dependencies.opa-wasm]
version = "0.1.7" version = "0.1.8"
# OpenTelemetry # OpenTelemetry
[workspace.dependencies.opentelemetry] [workspace.dependencies.opentelemetry]
version = "0.30.0" version = "0.31.0"
features = ["trace", "metrics"] features = ["trace", "metrics"]
[workspace.dependencies.opentelemetry-http] [workspace.dependencies.opentelemetry-http]
version = "0.30.0" version = "0.31.0"
features = ["reqwest"] features = ["reqwest"]
[workspace.dependencies.opentelemetry-jaeger-propagator] [workspace.dependencies.opentelemetry-jaeger-propagator]
version = "0.30.0" version = "0.31.0"
[workspace.dependencies.opentelemetry-otlp] [workspace.dependencies.opentelemetry-otlp]
version = "0.30.0" version = "0.31.0"
default-features = false default-features = false
features = ["trace", "metrics", "http-proto"] features = ["trace", "metrics", "http-proto"]
[workspace.dependencies.opentelemetry-prometheus] [workspace.dependencies.opentelemetry-prometheus-text-exporter]
# https://github.com/open-telemetry/opentelemetry-rust/pull/3076 version = "0.2.1"
git = "https://github.com/sandhose/opentelemetry-rust.git"
branch = "otel-prometheus-0.30"
[workspace.dependencies.opentelemetry-resource-detectors] [workspace.dependencies.opentelemetry-resource-detectors]
version = "0.9.0" version = "0.10.0"
[workspace.dependencies.opentelemetry-semantic-conventions] [workspace.dependencies.opentelemetry-semantic-conventions]
version = "0.30.0" version = "0.31.0"
features = ["semconv_experimental"] features = ["semconv_experimental"]
[workspace.dependencies.opentelemetry-stdout] [workspace.dependencies.opentelemetry-stdout]
version = "0.30.0" version = "0.31.0"
features = ["trace", "metrics"] features = ["trace", "metrics"]
[workspace.dependencies.opentelemetry_sdk] [workspace.dependencies.opentelemetry_sdk]
version = "0.30.0" version = "0.31.0"
features = [ features = [
"experimental_trace_batch_span_processor_with_async_runtime", "experimental_trace_batch_span_processor_with_async_runtime",
"experimental_metrics_periodicreader_with_async_runtime", "experimental_metrics_periodicreader_with_async_runtime",
"rt-tokio", "rt-tokio",
] ]
[workspace.dependencies.tracing-opentelemetry] [workspace.dependencies.tracing-opentelemetry]
version = "0.31.0" version = "0.32.0"
default-features = false default-features = false
# P256 elliptic curve # P256 elliptic curve
@@ -456,11 +454,11 @@ features = ["std"]
# Parser generator # Parser generator
[workspace.dependencies.pest] [workspace.dependencies.pest]
version = "2.8.1" version = "2.8.3"
# Pest derive macros # Pest derive macros
[workspace.dependencies.pest_derive] [workspace.dependencies.pest_derive]
version = "2.8.1" version = "2.8.3"
# Pin projection # Pin projection
[workspace.dependencies.pin-project-lite] [workspace.dependencies.pin-project-lite]
@@ -478,11 +476,7 @@ features = ["std", "pkcs5", "encryption"]
# Public Suffix List # Public Suffix List
[workspace.dependencies.psl] [workspace.dependencies.psl]
version = "2.1.136" version = "2.1.162"
# Prometheus metrics
[workspace.dependencies.prometheus]
version = "0.14.0"
# High-precision clock # High-precision clock
[workspace.dependencies.quanta] [workspace.dependencies.quanta]
@@ -498,11 +492,11 @@ version = "0.6.4"
# Regular expressions # Regular expressions
[workspace.dependencies.regex] [workspace.dependencies.regex]
version = "1.11.2" version = "1.12.2"
# High-level HTTP client # High-level HTTP client
[workspace.dependencies.reqwest] [workspace.dependencies.reqwest]
version = "0.12.23" version = "0.12.24"
default-features = false default-features = false
features = [ features = [
"http2", "http2",
@@ -523,11 +517,11 @@ version = "2.1.1"
# Matrix-related types # Matrix-related types
[workspace.dependencies.ruma-common] [workspace.dependencies.ruma-common]
version = "0.15.4" version = "0.16.0"
# TLS stack # TLS stack
[workspace.dependencies.rustls] [workspace.dependencies.rustls]
version = "0.23.31" version = "0.23.35"
# PEM parsing for rustls # PEM parsing for rustls
[workspace.dependencies.rustls-pemfile] [workspace.dependencies.rustls-pemfile]
@@ -535,7 +529,7 @@ version = "2.2.0"
# PKI types for rustls # PKI types for rustls
[workspace.dependencies.rustls-pki-types] [workspace.dependencies.rustls-pki-types]
version = "1.12.0" version = "1.13.0"
# Use platform-specific verifier for TLS # Use platform-specific verifier for TLS
[workspace.dependencies.rustls-platform-verifier] [workspace.dependencies.rustls-platform-verifier]
@@ -547,8 +541,8 @@ version = "0.4.5"
# JSON Schema generation # JSON Schema generation
[workspace.dependencies.schemars] [workspace.dependencies.schemars]
version = "0.8.22" version = "0.9.0"
features = ["url", "chrono", "preserve_order"] features = ["url2", "chrono04", "preserve_order"]
# SEC1 encoding format # SEC1 encoding format
[workspace.dependencies.sec1] [workspace.dependencies.sec1]
@@ -573,27 +567,27 @@ features = [
# Sentry error tracking # Sentry error tracking
[workspace.dependencies.sentry] [workspace.dependencies.sentry]
version = "0.42.0" version = "0.45.0"
default-features = false default-features = false
features = ["backtrace", "contexts", "panic", "tower", "reqwest"] features = ["backtrace", "contexts", "panic", "tower", "reqwest"]
# Sentry tower layer # Sentry tower layer
[workspace.dependencies.sentry-tower] [workspace.dependencies.sentry-tower]
version = "0.42.0" version = "0.45.0"
features = ["http", "axum-matched-path"] features = ["http", "axum-matched-path"]
# Sentry tracing integration # Sentry tracing integration
[workspace.dependencies.sentry-tracing] [workspace.dependencies.sentry-tracing]
version = "0.42.0" version = "0.45.0"
# Serialization and deserialization # Serialization and deserialization
[workspace.dependencies.serde] [workspace.dependencies.serde]
version = "1.0.219" version = "1.0.228"
features = ["derive"] # Most of the time, if we need serde, we need derive features = ["derive"] # Most of the time, if we need serde, we need derive
# JSON serialization and deserialization # JSON serialization and deserialization
[workspace.dependencies.serde_json] [workspace.dependencies.serde_json]
version = "1.0.143" version = "1.0.145"
features = ["preserve_order"] features = ["preserve_order"]
# URL encoded form serialization # URL encoded form serialization
@@ -620,7 +614,7 @@ version = "2.2.0"
# Low-level socket manipulation # Low-level socket manipulation
[workspace.dependencies.socket2] [workspace.dependencies.socket2]
version = "0.6.0" version = "0.6.1"
# Subject Public Key Info # Subject Public Key Info
[workspace.dependencies.spki] [workspace.dependencies.spki]
@@ -643,14 +637,14 @@ features = [
# Custom error types # Custom error types
[workspace.dependencies.thiserror] [workspace.dependencies.thiserror]
version = "2.0.16" version = "2.0.17"
[workspace.dependencies.thiserror-ext] [workspace.dependencies.thiserror-ext]
version = "0.3.0" version = "0.3.0"
# Async runtime # Async runtime
[workspace.dependencies.tokio] [workspace.dependencies.tokio]
version = "1.47.1" version = "1.48.0"
features = ["full"] features = ["full"]
[workspace.dependencies.tokio-stream] [workspace.dependencies.tokio-stream]
@@ -658,7 +652,7 @@ version = "0.1.17"
# Tokio rustls integration # Tokio rustls integration
[workspace.dependencies.tokio-rustls] [workspace.dependencies.tokio-rustls]
version = "0.26.2" version = "0.26.4"
# Tokio test utilities # Tokio test utilities
[workspace.dependencies.tokio-test] [workspace.dependencies.tokio-test]
@@ -712,7 +706,7 @@ features = ["serde", "uuid"]
# UUID support # UUID support
[workspace.dependencies.uuid] [workspace.dependencies.uuid]
version = "1.18.0" version = "1.18.1"
# HTML escaping # HTML escaping
[workspace.dependencies.v_htmlescape] [workspace.dependencies.v_htmlescape]
@@ -741,7 +735,7 @@ version = "0.5.5"
# Zero memory after use # Zero memory after use
[workspace.dependencies.zeroize] [workspace.dependencies.zeroize]
version = "1.8.1" version = "1.8.2"
# Password strength estimation # Password strength estimation
[workspace.dependencies.zxcvbn] [workspace.dependencies.zxcvbn]

View File

@@ -13,9 +13,10 @@
ARG DEBIAN_VERSION=12 ARG DEBIAN_VERSION=12
ARG DEBIAN_VERSION_NAME=bookworm ARG DEBIAN_VERSION_NAME=bookworm
ARG RUSTC_VERSION=1.89.0 ARG RUSTC_VERSION=1.89.0
ARG NODEJS_VERSION=20.15.0 ARG NODEJS_VERSION=24.11.0
ARG OPA_VERSION=1.1.0 # Keep in sync with .github/actions/build-policies/action.yml and policies/Makefile
ARG CARGO_AUDITABLE_VERSION=0.6.6 ARG OPA_VERSION=1.8.0
ARG CARGO_AUDITABLE_VERSION=0.7.0
########################################## ##########################################
## Build stage that builds the frontend ## ## Build stage that builds the frontend ##
@@ -24,7 +25,7 @@ FROM --platform=${BUILDPLATFORM} docker.io/library/node:${NODEJS_VERSION}-${DEBI
WORKDIR /app/frontend WORKDIR /app/frontend
COPY ./frontend/package.json ./frontend/package-lock.json /app/frontend/ COPY ./frontend/.npmrc ./frontend/package.json ./frontend/package-lock.json /app/frontend/
# Network access: to fetch dependencies # Network access: to fetch dependencies
RUN --network=default \ RUN --network=default \
npm ci npm ci

View File

@@ -1,5 +1,5 @@
{ {
"$schema": "https://biomejs.dev/schemas/2.0.6/schema.json", "$schema": "https://biomejs.dev/schemas/2.2.4/schema.json",
"assist": { "actions": { "source": { "organizeImports": "on" } } }, "assist": { "actions": { "source": { "organizeImports": "on" } } },
"vcs": { "vcs": {
"enabled": true, "enabled": true,
@@ -32,6 +32,12 @@
"enabled": true, "enabled": true,
"rules": { "rules": {
"recommended": true, "recommended": true,
"complexity": {
"noImportantStyles": "off"
},
"suspicious": {
"noUnknownAtRules": "off"
},
"correctness": { "correctness": {
"noUnusedImports": "warn", "noUnusedImports": "warn",
"noUnusedVariables": "warn" "noUnusedVariables": "warn"

View File

@@ -17,4 +17,6 @@ disallowed-methods = [
disallowed-types = [ disallowed-types = [
{ path = "std::path::PathBuf", reason = "use camino::Utf8PathBuf instead" }, { path = "std::path::PathBuf", reason = "use camino::Utf8PathBuf instead" },
{ path = "std::path::Path", reason = "use camino::Utf8Path instead" }, { path = "std::path::Path", reason = "use camino::Utf8Path instead" },
{ path = "axum::extract::Query", reason = "use axum_extra::extract::Query instead. The built-in version doesn't deserialise lists."},
{ path = "axum::extract::rejection::QueryRejection", reason = "use axum_extra::extract::QueryRejection instead"}
] ]

View File

@@ -59,12 +59,11 @@ opentelemetry.workspace = true
opentelemetry-http.workspace = true opentelemetry-http.workspace = true
opentelemetry-jaeger-propagator.workspace = true opentelemetry-jaeger-propagator.workspace = true
opentelemetry-otlp.workspace = true opentelemetry-otlp.workspace = true
opentelemetry-prometheus.workspace = true opentelemetry-prometheus-text-exporter.workspace = true
opentelemetry-resource-detectors.workspace = true opentelemetry-resource-detectors.workspace = true
opentelemetry-semantic-conventions.workspace = true opentelemetry-semantic-conventions.workspace = true
opentelemetry-stdout.workspace = true opentelemetry-stdout.workspace = true
opentelemetry_sdk.workspace = true opentelemetry_sdk.workspace = true
prometheus.workspace = true
sentry.workspace = true sentry.workspace = true
sentry-tracing.workspace = true sentry-tracing.workspace = true
sentry-tower.workspace = true sentry-tower.workspace = true

View File

@@ -9,7 +9,7 @@ use std::{convert::Infallible, net::IpAddr, sync::Arc};
use axum::extract::{FromRef, FromRequestParts}; use axum::extract::{FromRef, FromRequestParts};
use ipnetwork::IpNetwork; use ipnetwork::IpNetwork;
use mas_context::LogContext; use mas_context::LogContext;
use mas_data_model::{BoxClock, BoxRng, SiteConfig, SystemClock}; use mas_data_model::{AppVersion, BoxClock, BoxRng, SiteConfig, SystemClock};
use mas_handlers::{ use mas_handlers::{
ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper, GraphQLSchema, Limiter, ActivityTracker, BoundActivityTracker, CookieManager, ErrorWrapper, GraphQLSchema, Limiter,
MetadataCache, RequesterFingerprint, passwords::PasswordManager, MetadataCache, RequesterFingerprint, passwords::PasswordManager,
@@ -27,7 +27,7 @@ use rand::SeedableRng;
use sqlx::PgPool; use sqlx::PgPool;
use tracing::Instrument; use tracing::Instrument;
use crate::telemetry::METER; use crate::{VERSION, telemetry::METER};
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
@@ -214,6 +214,12 @@ impl FromRef<AppState> for Arc<dyn HomeserverConnection> {
} }
} }
impl FromRef<AppState> for AppVersion {
fn from_ref(_input: &AppState) -> Self {
AppVersion(VERSION)
}
}
impl FromRequestParts<AppState> for BoxClock { impl FromRequestParts<AppState> for BoxClock {
type Rejection = Infallible; type Rejection = Infallible;

View File

@@ -19,14 +19,17 @@ use mas_data_model::{Clock, Device, SystemClock, TokenType, Ulid, UpstreamOAuthP
use mas_email::Address; use mas_email::Address;
use mas_matrix::HomeserverConnection; use mas_matrix::HomeserverConnection;
use mas_storage::{ use mas_storage::{
RepositoryAccess, Pagination, RepositoryAccess,
compat::{CompatAccessTokenRepository, CompatSessionFilter, CompatSessionRepository}, compat::{CompatAccessTokenRepository, CompatSessionFilter, CompatSessionRepository},
oauth2::OAuth2SessionFilter, oauth2::OAuth2SessionFilter,
queue::{ queue::{
DeactivateUserJob, ProvisionUserJob, QueueJobRepositoryExt as _, ReactivateUserJob, DeactivateUserJob, ProvisionUserJob, QueueJobRepositoryExt as _, ReactivateUserJob,
SyncDevicesJob, SyncDevicesJob,
}, },
user::{BrowserSessionFilter, UserEmailRepository, UserPasswordRepository, UserRepository}, user::{
BrowserSessionFilter, UserEmailRepository, UserFilter, UserPasswordRepository,
UserRepository,
},
}; };
use mas_storage_pg::{DatabaseError, PgRepository}; use mas_storage_pg::{DatabaseError, PgRepository};
use rand::{ use rand::{
@@ -85,6 +88,15 @@ enum Subcommand {
ignore_complexity: bool, ignore_complexity: bool,
}, },
/// Make a user admin
PromoteAdmin { username: String },
/// Make a user non-admin
DemoteAdmin { username: String },
/// List all users with admin privileges
ListAdminUsers,
/// Issue a compatibility token /// Issue a compatibility token
IssueCompatibilityToken { IssueCompatibilityToken {
/// User for which to issue the token /// User for which to issue the token
@@ -315,6 +327,83 @@ impl Options {
Ok(ExitCode::SUCCESS) Ok(ExitCode::SUCCESS)
} }
SC::PromoteAdmin { username } => {
let _span =
info_span!("cli.manage.promote_admin", user.username = username,).entered();
let database_config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
let user = repo
.user()
.find_by_username(&username)
.await?
.context("User not found")?;
let user = repo.user().set_can_request_admin(user, true).await?;
repo.into_inner().commit().await?;
info!(%user.id, %user.username, "User promoted to admin");
Ok(ExitCode::SUCCESS)
}
SC::DemoteAdmin { username } => {
let _span =
info_span!("cli.manage.demote_admin", user.username = username,).entered();
let database_config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
let user = repo
.user()
.find_by_username(&username)
.await?
.context("User not found")?;
let user = repo.user().set_can_request_admin(user, false).await?;
repo.into_inner().commit().await?;
info!(%user.id, %user.username, "User is no longer admin");
Ok(ExitCode::SUCCESS)
}
SC::ListAdminUsers => {
let _span = info_span!("cli.manage.list_admins").entered();
let database_config = DatabaseConfig::extract_or_default(figment)
.map_err(anyhow::Error::from_boxed)?;
let mut conn = database_connection_from_config(&database_config).await?;
let txn = conn.begin().await?;
let mut repo = PgRepository::from_conn(txn);
let mut cursor = Pagination::first(1000);
let filter = UserFilter::new().can_request_admin_only();
let total = repo.user().count(filter).await?;
info!("The following users can request admin privileges ({total} total):");
loop {
let page = repo.user().list(filter, cursor).await?;
for edge in page.edges {
let user = edge.node;
info!(%user.id, username = %user.username);
cursor = cursor.after(edge.cursor);
}
if !page.has_next_page {
break;
}
}
Ok(ExitCode::SUCCESS)
}
SC::IssueCompatibilityToken { SC::IssueCompatibilityToken {
username, username,
admin, admin,

View File

@@ -160,8 +160,14 @@ impl Options {
)?; )?;
// Load and compile the templates // Load and compile the templates
let templates = let templates = templates_from_config(
templates_from_config(&config.templates, &site_config, &url_builder).await?; &config.templates,
&site_config,
&url_builder,
// Don't use strict mode in production yet
false,
)
.await?;
shutdown.register_reloadable(&templates); shutdown.register_reloadable(&templates);
let http_client = mas_http::reqwest_client(); let http_client = mas_http::reqwest_client();

View File

@@ -4,8 +4,10 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use std::process::ExitCode; use std::{fmt::Write, process::ExitCode};
use anyhow::{Context as _, bail};
use camino::Utf8PathBuf;
use clap::Parser; use clap::Parser;
use figment::Figment; use figment::Figment;
use mas_config::{ use mas_config::{
@@ -27,14 +29,19 @@ pub(super) struct Options {
#[derive(Parser, Debug)] #[derive(Parser, Debug)]
enum Subcommand { enum Subcommand {
/// Check that the templates specified in the config are valid /// Check that the templates specified in the config are valid
Check, Check {
/// If set, templates will be rendered to this directory.
/// The directory must either not exist or be empty.
#[arg(long = "out-dir")]
out_dir: Option<Utf8PathBuf>,
},
} }
impl Options { impl Options {
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> { pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
use Subcommand as SC; use Subcommand as SC;
match self.subcommand { match self.subcommand {
SC::Check => { SC::Check { out_dir } => {
let _span = info_span!("cli.templates.check").entered(); let _span = info_span!("cli.templates.check").entered();
let template_config = TemplatesConfig::extract_or_default(figment) let template_config = TemplatesConfig::extract_or_default(figment)
@@ -65,9 +72,54 @@ impl Options {
&account_config, &account_config,
&captcha_config, &captcha_config,
)?; )?;
let templates = let templates = templates_from_config(
templates_from_config(&template_config, &site_config, &url_builder).await?; &template_config,
templates.check_render(clock.now(), &mut rng)?; &site_config,
&url_builder, // Use strict mode in template checks
true,
)
.await?;
let all_renders = templates.check_render(clock.now(), &mut rng)?;
if let Some(out_dir) = out_dir {
// Save renders to disk.
if out_dir.exists() {
let mut read_dir =
tokio::fs::read_dir(&out_dir).await.with_context(|| {
format!("could not read {out_dir} to check it's empty")
})?;
if read_dir.next_entry().await?.is_some() {
bail!("Render directory {out_dir} is not empty, refusing to write.");
}
} else {
tokio::fs::create_dir(&out_dir)
.await
.with_context(|| format!("could not create {out_dir}"))?;
}
for ((template, sample_identifier), template_render) in &all_renders {
let (template_filename_base, template_ext) =
template.rsplit_once('.').unwrap_or((template, "txt"));
let template_filename_base = template_filename_base.replace('/', "_");
// Make a string like `-index=0-browser-session=0-locale=fr`
let sample_suffix = {
let mut s = String::new();
for (k, v) in &sample_identifier.components {
write!(s, "-{k}={v}")?;
}
s
};
let render_path = out_dir.join(format!(
"{template_filename_base}{sample_suffix}.{template_ext}"
));
tokio::fs::write(&render_path, template_render.as_bytes())
.await
.with_context(|| format!("could not write render to {render_path}"))?;
}
}
Ok(ExitCode::SUCCESS) Ok(ExitCode::SUCCESS)
} }

View File

@@ -52,8 +52,14 @@ impl Options {
)?; )?;
// Load and compile the templates // Load and compile the templates
let templates = let templates = templates_from_config(
templates_from_config(&config.templates, &site_config, &url_builder).await?; &config.templates,
&site_config,
&url_builder,
// Don't use strict mode on task workers for now
false,
)
.await?;
let mailer = mailer_from_config(&config.email, &templates)?; let mailer = mailer_from_config(&config.email, &templates)?;
test_mailer_in_background(&mailer, Duration::from_secs(30)); test_mailer_in_background(&mailer, Duration::from_secs(30));

View File

@@ -149,12 +149,14 @@ async fn try_main() -> anyhow::Result<ExitCode> {
// Setup OpenTelemetry tracing and metrics // Setup OpenTelemetry tracing and metrics
self::telemetry::setup(&telemetry_config).context("failed to setup OpenTelemetry")?; self::telemetry::setup(&telemetry_config).context("failed to setup OpenTelemetry")?;
let telemetry_layer = self::telemetry::TRACER.get().map(|tracer| { let tracer = self::telemetry::TRACER
tracing_opentelemetry::layer() .get()
.with_tracer(tracer.clone()) .context("TRACER was not set")?;
.with_tracked_inactivity(false)
.with_filter(LevelFilter::INFO) let telemetry_layer = tracing_opentelemetry::layer()
}); .with_tracer(tracer.clone())
.with_tracked_inactivity(false)
.with_filter(LevelFilter::INFO);
let subscriber = Registry::default() let subscriber = Registry::default()
.with(suppress_layer) .with(suppress_layer)

View File

@@ -136,14 +136,24 @@ fn make_http_span<B>(req: &Request<B>) -> Span {
span.record(USER_AGENT_ORIGINAL, user_agent); span.record(USER_AGENT_ORIGINAL, user_agent);
} }
// Extract the parent span context from the request headers // In case the span is disabled by any of tracing layers, e.g. if `RUST_LOG`
let parent_context = opentelemetry::global::get_text_map_propagator(|propagator| { // is set to `warn`, `set_parent` will fail. So we only try to set the
let extractor = HeaderExtractor(req.headers()); // parent context if the span is not disabled.
let context = opentelemetry::Context::new(); if !span.is_disabled() {
propagator.extract_with_context(&context, &extractor) // Extract the parent span context from the request headers
}); let parent_context = opentelemetry::global::get_text_map_propagator(|propagator| {
let extractor = HeaderExtractor(req.headers());
let context = opentelemetry::Context::new();
propagator.extract_with_context(&context, &extractor)
});
span.set_parent(parent_context); if let Err(err) = span.set_parent(parent_context) {
tracing::error!(
error = &err as &dyn std::error::Error,
"Failed to set parent context on span"
);
}
}
span span
} }

View File

@@ -132,7 +132,8 @@ pub async fn config_sync(
let mut existing_enabled_ids = BTreeSet::new(); let mut existing_enabled_ids = BTreeSet::new();
let mut existing_disabled = BTreeMap::new(); let mut existing_disabled = BTreeMap::new();
// Process the existing providers // Process the existing providers
for provider in page.edges { for edge in page.edges {
let provider = edge.node;
if provider.enabled() { if provider.enabled() {
if config_ids.contains(&provider.id) { if config_ids.contains(&provider.id) {
existing_enabled_ids.insert(provider.id); existing_enabled_ids.insert(provider.id);
@@ -201,25 +202,24 @@ pub async fn config_sync(
continue; continue;
} }
let encrypted_client_secret = let encrypted_client_secret = if let Some(client_secret) = provider.client_secret {
if let Some(client_secret) = provider.client_secret.as_deref() { Some(encrypter.encrypt_to_string(client_secret.value().await?.as_bytes())?)
Some(encrypter.encrypt_to_string(client_secret.as_bytes())?) } else if let Some(mut siwa) = provider.sign_in_with_apple.clone() {
} else if let Some(mut siwa) = provider.sign_in_with_apple.clone() { // if private key file is defined and not private key (raw), we populate the
// if private key file is defined and not private key (raw), we populate the // private key to hold the content of the private key file.
// private key to hold the content of the private key file. // private key (raw) takes precedence so both can be defined
// private key (raw) takes precedence so both can be defined // without issues
// without issues if siwa.private_key.is_none()
if siwa.private_key.is_none() && let Some(private_key_file) = siwa.private_key_file.take()
&& let Some(private_key_file) = siwa.private_key_file.take() {
{ let key = tokio::fs::read_to_string(private_key_file).await?;
let key = tokio::fs::read_to_string(private_key_file).await?; siwa.private_key = Some(key);
siwa.private_key = Some(key); }
} let encoded = serde_json::to_vec(&siwa)?;
let encoded = serde_json::to_vec(&siwa)?; Some(encrypter.encrypt_to_string(&encoded)?)
Some(encrypter.encrypt_to_string(&encoded)?) } else {
} else { None
None };
};
let discovery_mode = match provider.discovery_mode { let discovery_mode = match provider.discovery_mode {
mas_config::UpstreamOAuth2DiscoveryMode::Oidc => { mas_config::UpstreamOAuth2DiscoveryMode::Oidc => {

View File

@@ -23,18 +23,17 @@ use opentelemetry::{
trace::TracerProvider as _, trace::TracerProvider as _,
}; };
use opentelemetry_otlp::{WithExportConfig, WithHttpConfig}; use opentelemetry_otlp::{WithExportConfig, WithHttpConfig};
use opentelemetry_prometheus::PrometheusExporter; use opentelemetry_prometheus_text_exporter::PrometheusExporter;
use opentelemetry_sdk::{ use opentelemetry_sdk::{
Resource, Resource,
metrics::{ManualReader, SdkMeterProvider, periodic_reader_with_async_runtime::PeriodicReader}, metrics::{ManualReader, SdkMeterProvider, periodic_reader_with_async_runtime::PeriodicReader},
propagation::{BaggagePropagator, TraceContextPropagator}, propagation::{BaggagePropagator, TraceContextPropagator},
trace::{ trace::{
Sampler, SdkTracerProvider, Tracer, span_processor_with_async_runtime::BatchSpanProcessor, IdGenerator, Sampler, SdkTracerProvider, Tracer,
span_processor_with_async_runtime::BatchSpanProcessor,
}, },
}; };
use opentelemetry_semantic_conventions as semcov; use opentelemetry_semantic_conventions as semcov;
use prometheus::Registry;
use url::Url;
static SCOPE: LazyLock<InstrumentationScope> = LazyLock::new(|| { static SCOPE: LazyLock<InstrumentationScope> = LazyLock::new(|| {
InstrumentationScope::builder(env!("CARGO_PKG_NAME")) InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
@@ -49,7 +48,7 @@ pub static METER: LazyLock<Meter> =
pub static TRACER: OnceLock<Tracer> = OnceLock::new(); pub static TRACER: OnceLock<Tracer> = OnceLock::new();
static METER_PROVIDER: OnceLock<SdkMeterProvider> = OnceLock::new(); static METER_PROVIDER: OnceLock<SdkMeterProvider> = OnceLock::new();
static TRACER_PROVIDER: OnceLock<SdkTracerProvider> = OnceLock::new(); static TRACER_PROVIDER: OnceLock<SdkTracerProvider> = OnceLock::new();
static PROMETHEUS_REGISTRY: OnceLock<Registry> = OnceLock::new(); static PROMETHEUS_EXPORTER: OnceLock<PrometheusExporter> = OnceLock::new();
pub fn setup(config: &TelemetryConfig) -> anyhow::Result<()> { pub fn setup(config: &TelemetryConfig) -> anyhow::Result<()> {
let propagator = propagator(&config.tracing.propagators); let propagator = propagator(&config.tracing.propagators);
@@ -95,50 +94,65 @@ fn propagator(propagators: &[Propagator]) -> TextMapCompositePropagator {
TextMapCompositePropagator::new(propagators) TextMapCompositePropagator::new(propagators)
} }
fn stdout_tracer_provider() -> SdkTracerProvider { /// An [`IdGenerator`] which always returns an invalid trace ID and span ID
let exporter = opentelemetry_stdout::SpanExporter::default(); ///
SdkTracerProvider::builder() /// This is used when no exporter is being used, so that we don't log the trace
.with_simple_exporter(exporter) /// ID when we're not tracing.
.build() #[derive(Debug, Clone, Copy)]
struct InvalidIdGenerator;
impl IdGenerator for InvalidIdGenerator {
fn new_trace_id(&self) -> opentelemetry::TraceId {
opentelemetry::TraceId::INVALID
}
fn new_span_id(&self) -> opentelemetry::SpanId {
opentelemetry::SpanId::INVALID
}
} }
fn otlp_tracer_provider( fn init_tracer(config: &TracingConfig) -> anyhow::Result<()> {
endpoint: Option<&Url>, let sample_rate = config.sample_rate.unwrap_or(1.0);
sample_rate: f64,
) -> anyhow::Result<SdkTracerProvider> {
let mut exporter = opentelemetry_otlp::SpanExporter::builder()
.with_http()
.with_http_client(mas_http::reqwest_client());
if let Some(endpoint) = endpoint {
exporter = exporter.with_endpoint(endpoint.to_string());
}
let exporter = exporter
.build()
.context("Failed to configure OTLP trace exporter")?;
let batch_processor =
BatchSpanProcessor::builder(exporter, opentelemetry_sdk::runtime::Tokio).build();
// We sample traces based on the parent if we have one, and if not, we // We sample traces based on the parent if we have one, and if not, we
// sample a ratio based on the configured sample rate // sample a ratio based on the configured sample rate
let sampler = Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased(sample_rate))); let sampler = Sampler::ParentBased(Box::new(Sampler::TraceIdRatioBased(sample_rate)));
let tracer_provider = SdkTracerProvider::builder() let tracer_provider_builder = SdkTracerProvider::builder()
.with_span_processor(batch_processor)
.with_resource(resource()) .with_resource(resource())
.with_sampler(sampler) .with_sampler(sampler);
.build();
Ok(tracer_provider)
}
fn init_tracer(config: &TracingConfig) -> anyhow::Result<()> {
let sample_rate = config.sample_rate.unwrap_or(1.0);
let tracer_provider = match config.exporter { let tracer_provider = match config.exporter {
TracingExporterKind::None => return Ok(()), TracingExporterKind::None => tracer_provider_builder
TracingExporterKind::Stdout => stdout_tracer_provider(), .with_id_generator(InvalidIdGenerator)
TracingExporterKind::Otlp => otlp_tracer_provider(config.endpoint.as_ref(), sample_rate)?, .with_sampler(Sampler::AlwaysOff)
.build(),
TracingExporterKind::Stdout => {
let exporter = opentelemetry_stdout::SpanExporter::default();
tracer_provider_builder
.with_simple_exporter(exporter)
.build()
}
TracingExporterKind::Otlp => {
let mut exporter = opentelemetry_otlp::SpanExporter::builder()
.with_http()
.with_http_client(mas_http::reqwest_client());
if let Some(endpoint) = &config.endpoint {
exporter = exporter.with_endpoint(endpoint.as_str());
}
let exporter = exporter
.build()
.context("Failed to configure OTLP trace exporter")?;
let batch_processor =
BatchSpanProcessor::builder(exporter, opentelemetry_sdk::runtime::Tokio).build();
tracer_provider_builder
.with_span_processor(batch_processor)
.build()
}
}; };
TRACER_PROVIDER TRACER_PROVIDER
.set(tracer_provider.clone()) .set(tracer_provider.clone())
.map_err(|_| anyhow::anyhow!("TRACER_PROVIDER was set twice"))?; .map_err(|_| anyhow::anyhow!("TRACER_PROVIDER was set twice"))?;
@@ -180,21 +194,30 @@ type PromServiceFuture =
#[allow(clippy::needless_pass_by_value)] #[allow(clippy::needless_pass_by_value)]
fn prometheus_service_fn<T>(_req: T) -> PromServiceFuture { fn prometheus_service_fn<T>(_req: T) -> PromServiceFuture {
use prometheus::{Encoder, TextEncoder}; let response = if let Some(exporter) = PROMETHEUS_EXPORTER.get() {
// We'll need some space for this, so we preallocate a bit
let mut buffer = Vec::with_capacity(1024);
let response = if let Some(registry) = PROMETHEUS_REGISTRY.get() { if let Err(err) = exporter.export(&mut buffer) {
let mut buffer = Vec::new(); tracing::error!(
let encoder = TextEncoder::new(); error = &err as &dyn std::error::Error,
let metric_families = registry.gather(); "Failed to export Prometheus metrics"
);
// That shouldn't panic, unless we're constructing invalid labels Response::builder()
encoder.encode(&metric_families, &mut buffer).unwrap(); .status(500)
.header(CONTENT_TYPE, "text/plain")
Response::builder() .body(Full::new(Bytes::from_static(
.status(200) b"Failed to export Prometheus metrics, see logs for details",
.header(CONTENT_TYPE, encoder.format_type()) )))
.body(Full::new(Bytes::from(buffer))) .unwrap()
.unwrap() } else {
Response::builder()
.status(200)
.header(CONTENT_TYPE, "text/plain;version=1.0.0")
.body(Full::new(Bytes::from(buffer)))
.unwrap()
}
} else { } else {
Response::builder() Response::builder()
.status(500) .status(500)
@@ -209,7 +232,7 @@ fn prometheus_service_fn<T>(_req: T) -> PromServiceFuture {
} }
pub fn prometheus_service<T>() -> tower::util::ServiceFn<fn(T) -> PromServiceFuture> { pub fn prometheus_service<T>() -> tower::util::ServiceFn<fn(T) -> PromServiceFuture> {
if PROMETHEUS_REGISTRY.get().is_none() { if PROMETHEUS_EXPORTER.get().is_none() {
tracing::warn!( tracing::warn!(
"A Prometheus resource was mounted on a listener, but the Prometheus exporter was not setup in the config" "A Prometheus resource was mounted on a listener, but the Prometheus exporter was not setup in the config"
); );
@@ -219,16 +242,11 @@ pub fn prometheus_service<T>() -> tower::util::ServiceFn<fn(T) -> PromServiceFut
} }
fn prometheus_metric_reader() -> anyhow::Result<PrometheusExporter> { fn prometheus_metric_reader() -> anyhow::Result<PrometheusExporter> {
let registry = Registry::new(); let exporter = PrometheusExporter::builder().without_scope_info().build();
PROMETHEUS_REGISTRY PROMETHEUS_EXPORTER
.set(registry.clone()) .set(exporter.clone())
.map_err(|_| anyhow::anyhow!("PROMETHEUS_REGISTRY was set twice"))?; .map_err(|_| anyhow::anyhow!("PROMETHEUS_EXPORTER was set twice"))?;
let exporter = opentelemetry_prometheus::exporter()
.with_registry(registry)
.without_scope_info()
.build()?;
Ok(exporter) Ok(exporter)
} }

View File

@@ -211,6 +211,7 @@ pub fn site_config_from_config(
password_login_enabled: password_config.enabled(), password_login_enabled: password_config.enabled(),
password_registration_enabled: password_config.enabled() password_registration_enabled: password_config.enabled()
&& account_config.password_registration_enabled, && account_config.password_registration_enabled,
password_registration_email_required: account_config.password_registration_email_required,
registration_token_required: account_config.registration_token_required, registration_token_required: account_config.registration_token_required,
email_change_allowed: account_config.email_change_allowed, email_change_allowed: account_config.email_change_allowed,
displayname_change_allowed: account_config.displayname_change_allowed, displayname_change_allowed: account_config.displayname_change_allowed,
@@ -231,6 +232,7 @@ pub async fn templates_from_config(
config: &TemplatesConfig, config: &TemplatesConfig,
site_config: &SiteConfig, site_config: &SiteConfig,
url_builder: &UrlBuilder, url_builder: &UrlBuilder,
strict: bool,
) -> Result<Templates, anyhow::Error> { ) -> Result<Templates, anyhow::Error> {
Templates::load( Templates::load(
config.path.clone(), config.path.clone(),
@@ -239,6 +241,7 @@ pub async fn templates_from_config(
config.translations_path.clone(), config.translations_path.clone(),
site_config.templates_branding(), site_config.templates_branding(),
site_config.templates_features(), site_config.templates_features(),
strict,
) )
.await .await
.with_context(|| format!("Failed to load the templates at {}", config.path)) .with_context(|| format!("Failed to load the templates at {}", config.path))

View File

@@ -4,14 +4,10 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use schemars::r#gen::SchemaSettings; use schemars::generate::SchemaSettings;
fn main() { fn main() {
let settings = SchemaSettings::draft07().with(|s| { let generator = SchemaSettings::draft07().into_generator();
s.option_nullable = false;
s.option_add_null_type = false;
});
let generator = settings.into_generator();
let schema = generator.into_root_schema_for::<mas_config::RootConfig>(); let schema = generator.into_root_schema_for::<mas_config::RootConfig>();
serde_json::to_writer_pretty(std::io::stdout(), &schema).expect("Failed to serialize schema"); serde_json::to_writer_pretty(std::io::stdout(), &schema).expect("Failed to serialize schema");

View File

@@ -6,29 +6,22 @@
//! Useful JSON Schema definitions //! Useful JSON Schema definitions
use schemars::{ use std::borrow::Cow;
JsonSchema,
r#gen::SchemaGenerator, use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
schema::{InstanceType, Schema, SchemaObject},
};
/// A network hostname /// A network hostname
pub struct Hostname; pub struct Hostname;
impl JsonSchema for Hostname { impl JsonSchema for Hostname {
fn schema_name() -> String { fn schema_name() -> Cow<'static, str> {
"Hostname".to_string() Cow::Borrowed("Hostname")
} }
fn json_schema(generator: &mut SchemaGenerator) -> Schema { fn json_schema(_generator: &mut SchemaGenerator) -> Schema {
hostname(generator) json_schema!({
"type": "string",
"format": "hostname",
})
} }
} }
fn hostname(_gen: &mut SchemaGenerator) -> Schema {
Schema::Object(SchemaObject {
instance_type: Some(InstanceType::String.into()),
format: Some("hostname".to_owned()),
..SchemaObject::default()
})
}

View File

@@ -50,6 +50,13 @@ pub struct AccountConfig {
#[serde(default = "default_false", skip_serializing_if = "is_default_false")] #[serde(default = "default_false", skip_serializing_if = "is_default_false")]
pub password_registration_enabled: bool, pub password_registration_enabled: bool,
/// Whether self-service password registrations require a valid email.
/// Defaults to `true`.
///
/// This has no effect if password registration is disabled.
#[serde(default = "default_true", skip_serializing_if = "is_default_true")]
pub password_registration_email_required: bool,
/// Whether users are allowed to change their passwords. Defaults to `true`. /// Whether users are allowed to change their passwords. Defaults to `true`.
/// ///
/// This has no effect if password login is disabled. /// This has no effect if password login is disabled.
@@ -89,6 +96,7 @@ impl Default for AccountConfig {
email_change_allowed: default_true(), email_change_allowed: default_true(),
displayname_change_allowed: default_true(), displayname_change_allowed: default_true(),
password_registration_enabled: default_false(), password_registration_enabled: default_false(),
password_registration_email_required: default_true(),
password_change_allowed: default_true(), password_change_allowed: default_true(),
password_recovery_enabled: default_false(), password_recovery_enabled: default_false(),
account_deactivation_allowed: default_true(), account_deactivation_allowed: default_true(),

View File

@@ -6,8 +6,6 @@
use std::ops::Deref; use std::ops::Deref;
use anyhow::bail;
use camino::Utf8PathBuf;
use mas_iana::oauth::OAuthClientAuthenticationMethod; use mas_iana::oauth::OAuthClientAuthenticationMethod;
use mas_jose::jwk::PublicJsonWebKeySet; use mas_jose::jwk::PublicJsonWebKeySet;
use schemars::JsonSchema; use schemars::JsonSchema;
@@ -16,7 +14,7 @@ use serde_with::serde_as;
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
use super::ConfigurationSection; use super::{ClientSecret, ClientSecretRaw, ConfigurationSection};
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)] #[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
@@ -31,66 +29,6 @@ impl From<PublicJsonWebKeySet> for JwksOrJwksUri {
} }
} }
/// Client secret config option.
///
/// It either holds the client secret value directly or references a file where
/// the client secret is stored.
#[derive(Clone, Debug)]
pub enum ClientSecret {
File(Utf8PathBuf),
Value(String),
}
/// Client secret fields as serialized in JSON.
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
struct ClientSecretRaw {
/// Path to the file containing the client secret. The client secret is used
/// by the `client_secret_basic`, `client_secret_post` and
/// `client_secret_jwt` authentication methods.
#[schemars(with = "Option<String>")]
#[serde(skip_serializing_if = "Option::is_none")]
client_secret_file: Option<Utf8PathBuf>,
/// Alternative to `client_secret_file`: Reads the client secret directly
/// from the config.
#[serde(skip_serializing_if = "Option::is_none")]
client_secret: Option<String>,
}
impl TryFrom<ClientSecretRaw> for Option<ClientSecret> {
type Error = anyhow::Error;
fn try_from(value: ClientSecretRaw) -> Result<Self, Self::Error> {
match (value.client_secret, value.client_secret_file) {
(None, None) => Ok(None),
(None, Some(path)) => Ok(Some(ClientSecret::File(path))),
(Some(client_secret), None) => Ok(Some(ClientSecret::Value(client_secret))),
(Some(_), Some(_)) => {
bail!("Cannot specify both `client_secret` and `client_secret_file`")
}
}
}
}
impl From<Option<ClientSecret>> for ClientSecretRaw {
fn from(value: Option<ClientSecret>) -> Self {
match value {
Some(ClientSecret::File(path)) => ClientSecretRaw {
client_secret_file: Some(path),
client_secret: None,
},
Some(ClientSecret::Value(client_secret)) => ClientSecretRaw {
client_secret_file: None,
client_secret: Some(client_secret),
},
None => ClientSecretRaw {
client_secret_file: None,
client_secret: None,
},
}
}
}
/// Authentication method used by clients /// Authentication method used by clients
#[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)] #[derive(JsonSchema, Serialize, Deserialize, Copy, Clone, Debug)]
#[serde(rename_all = "snake_case")] #[serde(rename_all = "snake_case")]
@@ -273,8 +211,7 @@ impl ClientConfig {
/// Returns an error when the client secret could not be read from file. /// Returns an error when the client secret could not be read from file.
pub async fn client_secret(&self) -> anyhow::Result<Option<String>> { pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
Ok(match &self.client_secret { Ok(match &self.client_secret {
Some(ClientSecret::File(path)) => Some(tokio::fs::read_to_string(path).await?), Some(client_secret) => Some(client_secret.value().await?),
Some(ClientSecret::Value(client_secret)) => Some(client_secret.clone()),
None => None, None => None,
}) })
} }

View File

@@ -23,19 +23,6 @@ fn default_public_base() -> Url {
"http://[::]:8080".parse().unwrap() "http://[::]:8080".parse().unwrap()
} }
fn http_address_example_1() -> &'static str {
"[::1]:8080"
}
fn http_address_example_2() -> &'static str {
"[::]:8080"
}
fn http_address_example_3() -> &'static str {
"127.0.0.1:8080"
}
fn http_address_example_4() -> &'static str {
"0.0.0.0:8080"
}
#[cfg(not(any(feature = "docker", feature = "dist")))] #[cfg(not(any(feature = "docker", feature = "dist")))]
fn http_listener_assets_path_default() -> Utf8PathBuf { fn http_listener_assets_path_default() -> Utf8PathBuf {
"./frontend/dist/".into() "./frontend/dist/".into()
@@ -111,10 +98,10 @@ pub enum BindConfig {
Address { Address {
/// Host and port on which to listen /// Host and port on which to listen
#[schemars( #[schemars(
example = "http_address_example_1", example = &"[::1]:8080",
example = "http_address_example_2", example = &"[::]:8080",
example = "http_address_example_3", example = &"127.0.0.1:8080",
example = "http_address_example_4" example = &"0.0.0.0:8080",
)] )]
address: String, address: String,
}, },
@@ -354,6 +341,7 @@ pub struct HttpConfig {
/// List of trusted reverse proxies that can set the `X-Forwarded-For` /// List of trusted reverse proxies that can set the `X-Forwarded-For`
/// header /// header
#[serde(default = "default_trusted_proxies")] #[serde(default = "default_trusted_proxies")]
#[schemars(with = "Vec<String>", inner(ip))]
pub trusted_proxies: Vec<IpNetwork>, pub trusted_proxies: Vec<IpNetwork>,
/// Public URL base from where the authentication service is reachable /// Public URL base from where the authentication service is reachable

View File

@@ -131,7 +131,11 @@ impl MatrixConfig {
/// Returns an error when the shared secret could not be read from file. /// Returns an error when the shared secret could not be read from file.
pub async fn secret(&self) -> anyhow::Result<String> { pub async fn secret(&self) -> anyhow::Result<String> {
Ok(match &self.secret { Ok(match &self.secret {
Secret::File(path) => tokio::fs::read_to_string(path).await?, Secret::File(path) => {
let raw = tokio::fs::read_to_string(path).await?;
// Trim the secret when read from file to match Synapse's behaviour
raw.trim().to_string()
}
Secret::Value(secret) => secret.clone(), Secret::Value(secret) => secret.clone(),
}) })
} }

View File

@@ -4,6 +4,8 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use anyhow::bail;
use camino::Utf8PathBuf;
use rand::Rng; use rand::Rng;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize}; use serde::{Deserialize, Serialize};
@@ -303,3 +305,82 @@ impl ConfigurationSection for SyncConfig {
Ok(()) Ok(())
} }
} }
/// Client secret config option.
///
/// It either holds the client secret value directly or references a file where
/// the client secret is stored.
#[derive(Clone, Debug)]
pub enum ClientSecret {
/// Path to the file containing the client secret.
File(Utf8PathBuf),
/// Client secret value.
Value(String),
}
/// Client secret fields as serialized in JSON.
#[derive(JsonSchema, Serialize, Deserialize, Clone, Debug)]
pub struct ClientSecretRaw {
/// Path to the file containing the client secret. The client secret is used
/// by the `client_secret_basic`, `client_secret_post` and
/// `client_secret_jwt` authentication methods.
#[schemars(with = "Option<String>")]
#[serde(skip_serializing_if = "Option::is_none")]
client_secret_file: Option<Utf8PathBuf>,
/// Alternative to `client_secret_file`: Reads the client secret directly
/// from the config.
#[serde(skip_serializing_if = "Option::is_none")]
client_secret: Option<String>,
}
impl ClientSecret {
/// Returns the client secret.
///
/// If `client_secret_file` was given, the secret is read from that file.
///
/// # Errors
///
/// Returns an error when the client secret could not be read from file.
pub async fn value(&self) -> anyhow::Result<String> {
Ok(match self {
ClientSecret::File(path) => tokio::fs::read_to_string(path).await?,
ClientSecret::Value(client_secret) => client_secret.clone(),
})
}
}
impl TryFrom<ClientSecretRaw> for Option<ClientSecret> {
type Error = anyhow::Error;
fn try_from(value: ClientSecretRaw) -> Result<Self, Self::Error> {
match (value.client_secret, value.client_secret_file) {
(None, None) => Ok(None),
(None, Some(path)) => Ok(Some(ClientSecret::File(path))),
(Some(client_secret), None) => Ok(Some(ClientSecret::Value(client_secret))),
(Some(_), Some(_)) => {
bail!("Cannot specify both `client_secret` and `client_secret_file`")
}
}
}
}
impl From<Option<ClientSecret>> for ClientSecretRaw {
fn from(value: Option<ClientSecret>) -> Self {
match value {
Some(ClientSecret::File(path)) => ClientSecretRaw {
client_secret_file: Some(path),
client_secret: None,
},
Some(ClientSecret::Value(client_secret)) => ClientSecretRaw {
client_secret_file: None,
client_secret: Some(client_secret),
},
None => ClientSecretRaw {
client_secret_file: None,
client_secret: None,
},
}
}
}

View File

@@ -20,10 +20,6 @@ use tracing::info;
use super::ConfigurationSection; use super::ConfigurationSection;
fn example_secret() -> &'static str {
"0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff"
}
/// Password config option. /// Password config option.
/// ///
/// It either holds the password value directly or references a file where the /// It either holds the password value directly or references a file where the
@@ -209,7 +205,7 @@ struct EncryptionRaw {
#[schemars( #[schemars(
with = "Option<String>", with = "Option<String>",
regex(pattern = r"[0-9a-fA-F]{64}"), regex(pattern = r"[0-9a-fA-F]{64}"),
example = "example_secret" example = &"0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff"
)] )]
#[serde_as(as = "Option<serde_with::hex::Hex>")] #[serde_as(as = "Option<serde_with::hex::Hex>")]
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
@@ -534,7 +530,10 @@ mod tests {
keys_dir: keys keys_dir: keys
"}, "},
)?; )?;
jail.create_file("encryption", example_secret())?; jail.create_file(
"encryption",
"0000111122223333444455556666777788889999aaaabbbbccccddddeeeeffff",
)?;
jail.create_dir("keys")?; jail.create_dir("keys")?;
jail.create_file( jail.create_file(
"keys/key1", "keys/key1",

View File

@@ -11,10 +11,6 @@ use url::Url;
use super::ConfigurationSection; use super::ConfigurationSection;
fn sample_rate_example() -> f64 {
0.5
}
/// Propagation format for incoming and outgoing requests /// Propagation format for incoming and outgoing requests
#[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)] #[derive(Clone, Copy, Debug, Serialize, Deserialize, PartialEq, Eq, JsonSchema)]
#[serde(rename_all = "lowercase")] #[serde(rename_all = "lowercase")]
@@ -70,7 +66,7 @@ pub struct TracingConfig {
/// ///
/// Defaults to `1.0` if not set. /// Defaults to `1.0` if not set.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[schemars(example = "sample_rate_example", range(min = 0.0, max = 1.0))] #[schemars(example = 0.5, range(min = 0.0, max = 1.0))]
pub sample_rate: Option<f64>, pub sample_rate: Option<f64>,
} }
@@ -123,26 +119,18 @@ impl MetricsConfig {
} }
} }
fn sentry_dsn_example() -> &'static str {
"https://public@host:port/1"
}
fn sentry_environment_example() -> &'static str {
"production"
}
/// Configuration related to the Sentry integration /// Configuration related to the Sentry integration
#[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)] #[derive(Clone, Debug, Default, Serialize, Deserialize, JsonSchema)]
pub struct SentryConfig { pub struct SentryConfig {
/// Sentry DSN /// Sentry DSN
#[schemars(url, example = "sentry_dsn_example")] #[schemars(url, example = &"https://public@host:port/1")]
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub dsn: Option<String>, pub dsn: Option<String>,
/// Environment to use when sending events to Sentry /// Environment to use when sending events to Sentry
/// ///
/// Defaults to `production` if not set. /// Defaults to `production` if not set.
#[schemars(example = "sentry_environment_example")] #[schemars(example = &"production")]
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
pub environment: Option<String>, pub environment: Option<String>,
@@ -150,14 +138,14 @@ pub struct SentryConfig {
/// ///
/// Defaults to `1.0` if not set. /// Defaults to `1.0` if not set.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[schemars(example = "sample_rate_example", range(min = 0.0, max = 1.0))] #[schemars(example = 0.5, range(min = 0.0, max = 1.0))]
pub sample_rate: Option<f32>, pub sample_rate: Option<f32>,
/// Sample rate for tracing transactions /// Sample rate for tracing transactions
/// ///
/// Defaults to `0.0` if not set. /// Defaults to `0.0` if not set.
#[serde(skip_serializing_if = "Option::is_none")] #[serde(skip_serializing_if = "Option::is_none")]
#[schemars(example = "sample_rate_example", range(min = 0.0, max = 1.0))] #[schemars(example = 0.5, range(min = 0.0, max = 1.0))]
pub traces_sample_rate: Option<f32>, pub traces_sample_rate: Option<f32>,
} }

View File

@@ -10,11 +10,11 @@ use camino::Utf8PathBuf;
use mas_iana::jose::JsonWebSignatureAlg; use mas_iana::jose::JsonWebSignatureAlg;
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::{Deserialize, Serialize, de::Error}; use serde::{Deserialize, Serialize, de::Error};
use serde_with::skip_serializing_none; use serde_with::{serde_as, skip_serializing_none};
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
use crate::ConfigurationSection; use crate::{ClientSecret, ClientSecretRaw, ConfigurationSection};
/// Upstream OAuth 2.0 providers configuration /// Upstream OAuth 2.0 providers configuration
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema, Default)]
@@ -475,6 +475,7 @@ impl OnBackchannelLogout {
} }
/// Configuration for one upstream OAuth 2 provider. /// Configuration for one upstream OAuth 2 provider.
#[serde_as]
#[skip_serializing_none] #[skip_serializing_none]
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)] #[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct Provider { pub struct Provider {
@@ -541,8 +542,10 @@ pub struct Provider {
/// ///
/// Used by the `client_secret_basic`, `client_secret_post`, and /// Used by the `client_secret_basic`, `client_secret_post`, and
/// `client_secret_jwt` methods /// `client_secret_jwt` methods
#[serde(skip_serializing_if = "Option::is_none")] #[schemars(with = "ClientSecretRaw")]
pub client_secret: Option<String>, #[serde_as(as = "serde_with::TryFromInto<ClientSecretRaw>")]
#[serde(flatten)]
pub client_secret: Option<ClientSecret>,
/// The method to authenticate the client with the provider /// The method to authenticate the client with the provider
pub token_endpoint_auth_method: TokenAuthMethod, pub token_endpoint_auth_method: TokenAuthMethod,
@@ -656,3 +659,110 @@ pub struct Provider {
#[serde(default, skip_serializing_if = "OnBackchannelLogout::is_default")] #[serde(default, skip_serializing_if = "OnBackchannelLogout::is_default")]
pub on_backchannel_logout: OnBackchannelLogout, pub on_backchannel_logout: OnBackchannelLogout,
} }
impl Provider {
/// Returns the client secret.
///
/// If `client_secret_file` was given, the secret is read from that file.
///
/// # Errors
///
/// Returns an error when the client secret could not be read from file.
pub async fn client_secret(&self) -> anyhow::Result<Option<String>> {
Ok(match &self.client_secret {
Some(client_secret) => Some(client_secret.value().await?),
None => None,
})
}
}
#[cfg(test)]
mod tests {
use std::str::FromStr;
use figment::{
Figment, Jail,
providers::{Format, Yaml},
};
use tokio::{runtime::Handle, task};
use super::*;
#[tokio::test]
async fn load_config() {
task::spawn_blocking(|| {
Jail::expect_with(|jail| {
jail.create_file(
"config.yaml",
r#"
upstream_oauth2:
providers:
- id: 01GFWR28C4KNE04WG3HKXB7C9R
client_id: upstream-oauth2
token_endpoint_auth_method: none
- id: 01GFWR32NCQ12B8Z0J8CPXRRB6
client_id: upstream-oauth2
client_secret_file: secret
token_endpoint_auth_method: client_secret_basic
- id: 01GFWR3WHR93Y5HK389H28VHZ9
client_id: upstream-oauth2
client_secret: c1!3n753c237
token_endpoint_auth_method: client_secret_post
- id: 01GFWR43R2ZZ8HX9CVBNW9TJWG
client_id: upstream-oauth2
client_secret_file: secret
token_endpoint_auth_method: client_secret_jwt
- id: 01GFWR4BNFDCC4QDG6AMSP1VRR
client_id: upstream-oauth2
token_endpoint_auth_method: private_key_jwt
jwks:
keys:
- kid: "03e84aed4ef4431014e8617567864c4efaaaede9"
kty: "RSA"
alg: "RS256"
use: "sig"
e: "AQAB"
n: "ma2uRyBeSEOatGuDpCiV9oIxlDWix_KypDYuhQfEzqi_BiF4fV266OWfyjcABbam59aJMNvOnKW3u_eZM-PhMCBij5MZ-vcBJ4GfxDJeKSn-GP_dJ09rpDcILh8HaWAnPmMoi4DC0nrfE241wPISvZaaZnGHkOrfN_EnA5DligLgVUbrA5rJhQ1aSEQO_gf1raEOW3DZ_ACU3qhtgO0ZBG3a5h7BPiRs2sXqb2UCmBBgwyvYLDebnpE7AotF6_xBIlR-Cykdap3GHVMXhrIpvU195HF30ZoBU4dMd-AeG6HgRt4Cqy1moGoDgMQfbmQ48Hlunv9_Vi2e2CLvYECcBw"
- kid: "d01c1abe249269f72ef7ca2613a86c9f05e59567"
kty: "RSA"
alg: "RS256"
use: "sig"
e: "AQAB"
n: "0hukqytPwrj1RbMYhYoepCi3CN5k7DwYkTe_Cmb7cP9_qv4ok78KdvFXt5AnQxCRwBD7-qTNkkfMWO2RxUMBdQD0ED6tsSb1n5dp0XY8dSWiBDCX8f6Hr-KolOpvMLZKRy01HdAWcM6RoL9ikbjYHUEW1C8IJnw3MzVHkpKFDL354aptdNLaAdTCBvKzU9WpXo10g-5ctzSlWWjQuecLMQ4G1mNdsR1LHhUENEnOvgT8cDkX0fJzLbEbyBYkdMgKggyVPEB1bg6evG4fTKawgnf0IDSPxIU-wdS9wdSP9ZCJJPLi5CEp-6t6rE_sb2dGcnzjCGlembC57VwpkUvyMw"
"#,
)?;
jail.create_file("secret", r"c1!3n753c237")?;
let config = Figment::new()
.merge(Yaml::file("config.yaml"))
.extract_inner::<UpstreamOAuth2Config>("upstream_oauth2")?;
assert_eq!(config.providers.len(), 5);
assert_eq!(
config.providers[1].id,
Ulid::from_str("01GFWR32NCQ12B8Z0J8CPXRRB6").unwrap()
);
assert!(config.providers[0].client_secret.is_none());
assert!(matches!(config.providers[1].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
assert!(matches!(config.providers[2].client_secret, Some(ClientSecret::Value(ref v)) if v == "c1!3n753c237"));
assert!(matches!(config.providers[3].client_secret, Some(ClientSecret::File(ref p)) if p == "secret"));
assert!(config.providers[4].client_secret.is_none());
Handle::current().block_on(async move {
assert_eq!(config.providers[1].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
assert_eq!(config.providers[2].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
assert_eq!(config.providers[3].client_secret().await.unwrap().unwrap(), "c1!3n753c237");
});
Ok(())
});
}).await.unwrap();
}
}

View File

@@ -4,10 +4,7 @@
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use console::{Color, Style}; use console::{Color, Style};
use opentelemetry::{ use opentelemetry::TraceId;
TraceId,
trace::{SamplingDecision, TraceContextExt},
};
use tracing::{Level, Subscriber}; use tracing::{Level, Subscriber};
use tracing_opentelemetry::OtelData; use tracing_opentelemetry::OtelData;
use tracing_subscriber::{ use tracing_subscriber::{
@@ -21,7 +18,7 @@ use tracing_subscriber::{
use crate::LogContext; use crate::LogContext;
/// An event formatter usable by the [`tracing-subscriber`] crate, which /// An event formatter usable by the [`tracing_subscriber`] crate, which
/// includes the log context and the OTEL trace ID. /// includes the log context and the OTEL trace ID.
#[derive(Debug, Default)] #[derive(Debug, Default)]
pub struct EventFormatter; pub struct EventFormatter;
@@ -131,31 +128,14 @@ where
// If we have a OTEL span, we can add the trace ID to the end of the log line // If we have a OTEL span, we can add the trace ID to the end of the log line
if let Some(span) = ctx.lookup_current() if let Some(span) = ctx.lookup_current()
&& let Some(otel) = span.extensions().get::<OtelData>() && let Some(otel) = span.extensions().get::<OtelData>()
&& let Some(trace_id) = otel.trace_id()
&& trace_id != TraceId::INVALID
{ {
let parent_cx_span = otel.parent_cx.span(); let label = Style::new()
let sc = parent_cx_span.span_context(); .italic()
.force_styling(ansi)
// Check if the span is sampled, first from the span builder, .apply_to("trace.id");
// then from the parent context if nothing is set there write!(&mut writer, " {label}={trace_id}")?;
if otel
.builder
.sampling_result
.as_ref()
.map_or(sc.is_sampled(), |r| {
r.decision == SamplingDecision::RecordAndSample
})
{
// If it is the root span, the trace ID will be in the span builder. Else, it
// will be in the parent OTEL context
let trace_id = otel.builder.trace_id.unwrap_or(sc.trace_id());
if trace_id != TraceId::INVALID {
let label = Style::new()
.italic()
.force_styling(ansi)
.apply_to("trace.id");
write!(&mut writer, " {label}={trace_id}")?;
}
}
} }
writeln!(&mut writer) writeln!(&mut writer)

View File

@@ -11,6 +11,7 @@ use thiserror::Error;
pub mod clock; pub mod clock;
pub(crate) mod compat; pub(crate) mod compat;
pub mod oauth2; pub mod oauth2;
pub mod personal;
pub(crate) mod policy_data; pub(crate) mod policy_data;
mod site_config; mod site_config;
pub(crate) mod tokens; pub(crate) mod tokens;
@@ -18,6 +19,7 @@ pub(crate) mod upstream_oauth2;
pub(crate) mod user_agent; pub(crate) mod user_agent;
pub(crate) mod users; pub(crate) mod users;
mod utils; mod utils;
mod version;
/// Error when an invalid state transition is attempted. /// Error when an invalid state transition is attempted.
#[derive(Debug, Error)] #[derive(Debug, Error)]
@@ -57,4 +59,5 @@ pub use self::{
UserRecoveryTicket, UserRegistration, UserRegistrationPassword, UserRegistrationToken, UserRecoveryTicket, UserRegistration, UserRegistrationPassword, UserRegistrationToken,
}, },
utils::{BoxClock, BoxRng}, utils::{BoxClock, BoxRng},
version::AppVersion,
}; };

View File

@@ -0,0 +1,32 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
pub mod session;
use chrono::{DateTime, Utc};
use ulid::Ulid;
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PersonalAccessToken {
pub id: Ulid,
pub session_id: Ulid,
pub created_at: DateTime<Utc>,
pub expires_at: Option<DateTime<Utc>>,
pub revoked_at: Option<DateTime<Utc>>,
}
impl PersonalAccessToken {
#[must_use]
pub fn is_valid(&self, now: DateTime<Utc>) -> bool {
if self.revoked_at.is_some() {
return false;
}
if let Some(expires_at) = self.expires_at {
expires_at > now
} else {
true
}
}
}

View File

@@ -0,0 +1,141 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::net::IpAddr;
use chrono::{DateTime, Utc};
use oauth2_types::scope::Scope;
use serde::Serialize;
use ulid::Ulid;
use crate::{Client, Device, InvalidTransitionError, User};
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize)]
pub enum SessionState {
#[default]
Valid,
Revoked {
revoked_at: DateTime<Utc>,
},
}
impl SessionState {
/// Returns `true` if the session state is [`Valid`].
///
/// [`Valid`]: SessionState::Valid
#[must_use]
pub fn is_valid(&self) -> bool {
matches!(self, Self::Valid)
}
/// Returns `true` if the session state is [`Revoked`].
///
/// [`Revoked`]: SessionState::Revoked
#[must_use]
pub fn is_revoked(&self) -> bool {
matches!(self, Self::Revoked { .. })
}
/// Transitions the session state to [`Revoked`].
///
/// # Parameters
///
/// * `revoked_at` - The time at which the session was revoked.
///
/// # Errors
///
/// Returns an error if the session state is already [`Revoked`].
///
/// [`Revoked`]: SessionState::Revoked
pub fn revoke(self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
match self {
Self::Valid => Ok(Self::Revoked { revoked_at }),
Self::Revoked { .. } => Err(InvalidTransitionError),
}
}
/// Returns the time the session was revoked, if any
///
/// Returns `None` if the session is still [`Valid`].
///
/// [`Valid`]: SessionState::Valid
#[must_use]
pub fn revoked_at(&self) -> Option<DateTime<Utc>> {
match self {
Self::Valid => None,
Self::Revoked { revoked_at } => Some(*revoked_at),
}
}
}
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct PersonalSession {
pub id: Ulid,
pub state: SessionState,
pub owner: PersonalSessionOwner,
pub actor_user_id: Ulid,
pub human_name: String,
/// The scope for the session, identical to OAuth 2 sessions.
/// May or may not include a device scope
/// (personal sessions can be deviceless).
pub scope: Scope,
pub created_at: DateTime<Utc>,
pub last_active_at: Option<DateTime<Utc>>,
pub last_active_ip: Option<IpAddr>,
}
#[derive(Debug, Copy, Clone, PartialEq, Eq, Serialize)]
pub enum PersonalSessionOwner {
/// The personal session is owned by the user with the given `user_id`.
User(Ulid),
/// The personal session is owned by the OAuth 2 Client with the given
/// `oauth2_client_id`.
OAuth2Client(Ulid),
}
impl<'a> From<&'a User> for PersonalSessionOwner {
fn from(value: &'a User) -> Self {
PersonalSessionOwner::User(value.id)
}
}
impl<'a> From<&'a Client> for PersonalSessionOwner {
fn from(value: &'a Client) -> Self {
PersonalSessionOwner::OAuth2Client(value.id)
}
}
impl std::ops::Deref for PersonalSession {
type Target = SessionState;
fn deref(&self) -> &Self::Target {
&self.state
}
}
impl PersonalSession {
/// Marks the session as revoked.
///
/// # Parameters
///
/// * `revoked_at` - The time at which the session was finished.
///
/// # Errors
///
/// Returns an error if the session is already finished.
pub fn finish(mut self, revoked_at: DateTime<Utc>) -> Result<Self, InvalidTransitionError> {
self.state = self.state.revoke(revoked_at)?;
Ok(self)
}
/// Returns whether the scope of this session contains a device scope;
/// in other words: whether this session has a device.
#[must_use]
pub fn has_device(&self) -> bool {
self.scope
.iter()
.any(|scope_token| Device::from_scope_token(scope_token).is_some())
}
}

View File

@@ -64,6 +64,9 @@ pub struct SiteConfig {
/// Whether password registration is enabled. /// Whether password registration is enabled.
pub password_registration_enabled: bool, pub password_registration_enabled: bool,
/// Whether a valid email address is required for password registrations.
pub password_registration_email_required: bool,
/// Whether registration tokens are required for password registrations. /// Whether registration tokens are required for password registrations.
pub registration_token_required: bool, pub registration_token_required: bool,

View File

@@ -240,6 +240,9 @@ pub enum TokenType {
/// A legacy refresh token /// A legacy refresh token
CompatRefreshToken, CompatRefreshToken,
/// A personal access token.
PersonalAccessToken,
} }
impl std::fmt::Display for TokenType { impl std::fmt::Display for TokenType {
@@ -249,6 +252,7 @@ impl std::fmt::Display for TokenType {
TokenType::RefreshToken => write!(f, "refresh token"), TokenType::RefreshToken => write!(f, "refresh token"),
TokenType::CompatAccessToken => write!(f, "compat access token"), TokenType::CompatAccessToken => write!(f, "compat access token"),
TokenType::CompatRefreshToken => write!(f, "compat refresh token"), TokenType::CompatRefreshToken => write!(f, "compat refresh token"),
TokenType::PersonalAccessToken => write!(f, "personal access token"),
} }
} }
} }
@@ -260,6 +264,7 @@ impl TokenType {
TokenType::RefreshToken => "mar", TokenType::RefreshToken => "mar",
TokenType::CompatAccessToken => "mct", TokenType::CompatAccessToken => "mct",
TokenType::CompatRefreshToken => "mcr", TokenType::CompatRefreshToken => "mcr",
TokenType::PersonalAccessToken => "mpt",
} }
} }
@@ -269,6 +274,7 @@ impl TokenType {
"mar" => Some(TokenType::RefreshToken), "mar" => Some(TokenType::RefreshToken),
"mct" | "syt" => Some(TokenType::CompatAccessToken), "mct" | "syt" => Some(TokenType::CompatAccessToken),
"mcr" | "syr" => Some(TokenType::CompatRefreshToken), "mcr" | "syr" => Some(TokenType::CompatRefreshToken),
"mpt" => Some(TokenType::PersonalAccessToken),
_ => None, _ => None,
} }
} }
@@ -335,7 +341,9 @@ impl PartialEq<OAuthTokenTypeHint> for TokenType {
matches!( matches!(
(self, other), (self, other),
( (
TokenType::AccessToken | TokenType::CompatAccessToken, TokenType::AccessToken
| TokenType::CompatAccessToken
| TokenType::PersonalAccessToken,
OAuthTokenTypeHint::AccessToken OAuthTokenTypeHint::AccessToken
) | ( ) | (
TokenType::RefreshToken | TokenType::CompatRefreshToken, TokenType::RefreshToken | TokenType::CompatRefreshToken,

View File

@@ -21,6 +21,7 @@ pub struct User {
pub locked_at: Option<DateTime<Utc>>, pub locked_at: Option<DateTime<Utc>>,
pub deactivated_at: Option<DateTime<Utc>>, pub deactivated_at: Option<DateTime<Utc>>,
pub can_request_admin: bool, pub can_request_admin: bool,
pub is_guest: bool,
} }
impl User { impl User {
@@ -29,6 +30,20 @@ impl User {
pub fn is_valid(&self) -> bool { pub fn is_valid(&self) -> bool {
self.locked_at.is_none() && self.deactivated_at.is_none() self.locked_at.is_none() && self.deactivated_at.is_none()
} }
/// Returns `true` if the user is a valid actor, for example
/// of a personal session.
///
/// Currently: this is `true` unless the user is deactivated.
///
/// This is a weaker form of validity: `is_valid` always implies
/// `is_valid_actor`, but some users (currently: locked users)
/// can be valid actors for personal sessions but aren't valid
/// except through administrative access.
#[must_use]
pub fn is_valid_actor(&self) -> bool {
self.deactivated_at.is_none()
}
} }
impl User { impl User {
@@ -43,6 +58,7 @@ impl User {
locked_at: None, locked_at: None,
deactivated_at: None, deactivated_at: None,
can_request_admin: false, can_request_admin: false,
is_guest: false,
}] }]
} }
} }

View File

@@ -0,0 +1,8 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
/// A structure which holds information about the running version of the app
#[derive(Debug, Clone, Copy)]
pub struct AppVersion(pub &'static str);

View File

@@ -36,7 +36,9 @@ pub struct Transport {
inner: Arc<TransportInner>, inner: Arc<TransportInner>,
} }
#[derive(Default)]
enum TransportInner { enum TransportInner {
#[default]
Blackhole, Blackhole,
Smtp(AsyncSmtpTransport<Tokio1Executor>), Smtp(AsyncSmtpTransport<Tokio1Executor>),
Sendmail(AsyncSendmailTransport<Tokio1Executor>), Sendmail(AsyncSendmailTransport<Tokio1Executor>),
@@ -113,12 +115,6 @@ impl Transport {
} }
} }
impl Default for TransportInner {
fn default() -> Self {
Self::Blackhole
}
}
#[derive(Debug, Error)] #[derive(Debug, Error)]
#[error(transparent)] #[error(transparent)]
pub enum Error { pub enum Error {

View File

@@ -6,7 +6,9 @@
use std::net::IpAddr; use std::net::IpAddr;
use mas_data_model::{BrowserSession, Clock, CompatSession, Session}; use mas_data_model::{
BrowserSession, Clock, CompatSession, Session, personal::session::PersonalSession,
};
use crate::activity_tracker::ActivityTracker; use crate::activity_tracker::ActivityTracker;
@@ -37,6 +39,13 @@ impl Bound {
.await; .await;
} }
/// Record activity in a personal session.
pub async fn record_personal_session(&self, clock: &dyn Clock, session: &PersonalSession) {
self.tracker
.record_personal_session(clock, session, self.ip)
.await;
}
/// Record activity in a compatibility session. /// Record activity in a compatibility session.
pub async fn record_compat_session(&self, clock: &dyn Clock, session: &CompatSession) { pub async fn record_compat_session(&self, clock: &dyn Clock, session: &CompatSession) {
self.tracker self.tracker

View File

@@ -10,7 +10,9 @@ mod worker;
use std::net::IpAddr; use std::net::IpAddr;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::{BrowserSession, Clock, CompatSession, Session}; use mas_data_model::{
BrowserSession, Clock, CompatSession, Session, personal::session::PersonalSession,
};
use mas_storage::BoxRepositoryFactory; use mas_storage::BoxRepositoryFactory;
use tokio_util::{sync::CancellationToken, task::TaskTracker}; use tokio_util::{sync::CancellationToken, task::TaskTracker};
use ulid::Ulid; use ulid::Ulid;
@@ -24,6 +26,8 @@ static MESSAGE_QUEUE_SIZE: usize = 1000;
enum SessionKind { enum SessionKind {
OAuth2, OAuth2,
Compat, Compat,
/// Session associated with personal access tokens
Personal,
Browser, Browser,
} }
@@ -32,6 +36,7 @@ impl SessionKind {
match self { match self {
SessionKind::OAuth2 => "oauth2", SessionKind::OAuth2 => "oauth2",
SessionKind::Compat => "compat", SessionKind::Compat => "compat",
SessionKind::Personal => "personal",
SessionKind::Browser => "browser", SessionKind::Browser => "browser",
} }
} }
@@ -108,6 +113,28 @@ impl ActivityTracker {
} }
} }
/// Record activity in a personal session.
pub async fn record_personal_session(
&self,
clock: &dyn Clock,
session: &PersonalSession,
ip: Option<IpAddr>,
) {
let res = self
.channel
.send(Message::Record {
kind: SessionKind::Personal,
id: session.id,
date_time: clock.now(),
ip,
})
.await;
if let Err(e) = res {
tracing::error!("Failed to record Personal session: {}", e);
}
}
/// Record activity in a compat session. /// Record activity in a compat session.
pub async fn record_compat_session( pub async fn record_compat_session(
&self, &self,

View File

@@ -224,6 +224,7 @@ impl Worker {
let mut browser_sessions = Vec::new(); let mut browser_sessions = Vec::new();
let mut oauth2_sessions = Vec::new(); let mut oauth2_sessions = Vec::new();
let mut compat_sessions = Vec::new(); let mut compat_sessions = Vec::new();
let mut personal_sessions = Vec::new();
for ((kind, id), record) in pending_records { for ((kind, id), record) in pending_records {
match kind { match kind {
@@ -236,6 +237,9 @@ impl Worker {
SessionKind::Compat => { SessionKind::Compat => {
compat_sessions.push((*id, record.end_time, record.ip)); compat_sessions.push((*id, record.end_time, record.ip));
} }
SessionKind::Personal => {
personal_sessions.push((*id, record.end_time, record.ip));
}
} }
} }
@@ -253,6 +257,9 @@ impl Worker {
repo.compat_session() repo.compat_session()
.record_batch_activity(compat_sessions) .record_batch_activity(compat_sessions)
.await?; .await?;
repo.personal_session()
.record_batch_activity(personal_sessions)
.await?;
repo.save().await?; repo.save().await?;
self.pending_records.clear(); self.pending_records.clear();

View File

@@ -16,8 +16,12 @@ use axum_extra::TypedHeader;
use headers::{Authorization, authorization::Bearer}; use headers::{Authorization, authorization::Bearer};
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::record_error; use mas_axum_utils::record_error;
use mas_data_model::{BoxClock, Session, User}; use mas_data_model::{
BoxClock, Session, TokenFormatError, TokenType, User,
personal::session::{PersonalSession, PersonalSessionOwner},
};
use mas_storage::{BoxRepository, RepositoryError}; use mas_storage::{BoxRepository, RepositoryError};
use oauth2_types::scope::Scope;
use ulid::Ulid; use ulid::Ulid;
use super::response::ErrorResponse; use super::response::ErrorResponse;
@@ -41,6 +45,10 @@ pub enum Rejection {
#[error("Invalid repository operation")] #[error("Invalid repository operation")]
Repository(#[from] RepositoryError), Repository(#[from] RepositoryError),
/// The access token was not of the correct type for the Admin API
#[error("Invalid type of access token")]
InvalidAccessTokenType(#[from] Option<TokenFormatError>),
/// The access token could not be found in the database /// The access token could not be found in the database
#[error("Unknown access token")] #[error("Unknown access token")]
UnknownAccessToken, UnknownAccessToken,
@@ -90,7 +98,8 @@ impl IntoResponse for Rejection {
| Rejection::TokenExpired | Rejection::TokenExpired
| Rejection::SessionRevoked | Rejection::SessionRevoked
| Rejection::UserLocked | Rejection::UserLocked
| Rejection::MissingScope => StatusCode::UNAUTHORIZED, | Rejection::MissingScope
| Rejection::InvalidAccessTokenType(_) => StatusCode::UNAUTHORIZED,
Rejection::RepositorySetup(_) Rejection::RepositorySetup(_)
| Rejection::Repository(_) | Rejection::Repository(_)
@@ -113,7 +122,7 @@ pub struct CallContext {
pub repo: BoxRepository, pub repo: BoxRepository,
pub clock: BoxClock, pub clock: BoxClock,
pub user: Option<User>, pub user: Option<User>,
pub session: Session, pub session: CallerSession,
} }
impl<S> FromRequestParts<S> for CallContext impl<S> FromRequestParts<S> for CallContext
@@ -154,56 +163,126 @@ where
})?; })?;
let token = token.token(); let token = token.token();
let token_type = TokenType::check(token)?;
// Look for the access token in the database let session = match token_type {
let token = repo TokenType::AccessToken => {
.oauth2_access_token() // Look for the access token in the database
.find_by_token(token) let token = repo
.await? .oauth2_access_token()
.ok_or(Rejection::UnknownAccessToken)?; .find_by_token(token)
.await?
.ok_or(Rejection::UnknownAccessToken)?;
// Look for the associated session in the database // Look for the associated session in the database
let session = repo let session = repo
.oauth2_session() .oauth2_session()
.lookup(token.session_id) .lookup(token.session_id)
.await? .await?
.ok_or_else(|| Rejection::LoadSession(token.session_id))?; .ok_or_else(|| Rejection::LoadSession(token.session_id))?;
// Record the activity on the session if !session.is_valid() {
activity_tracker return Err(Rejection::SessionRevoked);
.record_oauth2_session(&clock, &session) }
.await;
if !token.is_valid(clock.now()) {
return Err(Rejection::TokenExpired);
}
// Record the activity on the session
activity_tracker
.record_oauth2_session(&clock, &session)
.await;
CallerSession::OAuth2Session(session)
}
TokenType::PersonalAccessToken => {
// Look for the access token in the database
let token = repo
.personal_access_token()
.find_by_token(token)
.await?
.ok_or(Rejection::UnknownAccessToken)?;
// Look for the associated session in the database
let session = repo
.personal_session()
.lookup(token.session_id)
.await?
.ok_or_else(|| Rejection::LoadSession(token.session_id))?;
if !session.is_valid() {
return Err(Rejection::SessionRevoked);
}
if !token.is_valid(clock.now()) {
return Err(Rejection::TokenExpired);
}
// Check the validity of the owner of the personal session
match session.owner {
PersonalSessionOwner::User(owner_user_id) => {
let owner_user = repo
.user()
.lookup(owner_user_id)
.await?
.ok_or_else(|| Rejection::LoadUser(owner_user_id))?;
if !owner_user.is_valid() {
return Err(Rejection::UserLocked);
}
}
PersonalSessionOwner::OAuth2Client(_) => {
// nop: Client owners are always valid
}
}
// Record the activity on the session
activity_tracker
.record_personal_session(&clock, &session)
.await;
CallerSession::PersonalSession(session)
}
_other => {
return Err(Rejection::InvalidAccessTokenType(None));
}
};
// Load the user if there is one // Load the user if there is one
let user = if let Some(user_id) = session.user_id { let user = if let Some(user_id) = session.user_id() {
let user = repo let user = repo
.user() .user()
.lookup(user_id) .lookup(user_id)
.await? .await?
.ok_or_else(|| Rejection::LoadUser(user_id))?; .ok_or_else(|| Rejection::LoadUser(user_id))?;
match session {
CallerSession::OAuth2Session(_) => {
// For OAuth2 sessions: check that the user is valid enough
// to be a user.
if !user.is_valid() {
return Err(Rejection::UserLocked);
}
}
CallerSession::PersonalSession(_) => {
// For personal sessions: check that the actor is valid enough
// to be an actor.
if !user.is_valid_actor() {
return Err(Rejection::UserLocked);
}
}
}
Some(user) Some(user)
} else { } else {
// Double check we're not using a PersonalSession
assert!(matches!(session, CallerSession::OAuth2Session(_)));
None None
}; };
// If there is a user for this session, check that it is not locked
if let Some(user) = &user
&& !user.is_valid()
{
return Err(Rejection::UserLocked);
}
if !session.is_valid() {
return Err(Rejection::SessionRevoked);
}
if !token.is_valid(clock.now()) {
return Err(Rejection::TokenExpired);
}
// For now, we only check that the session has the admin scope // For now, we only check that the session has the admin scope
// Later we might want to check other route-specific scopes // Later we might want to check other route-specific scopes
if !session.scope.contains("urn:mas:admin") { if !session.scope().contains("urn:mas:admin") {
return Err(Rejection::MissingScope); return Err(Rejection::MissingScope);
} }
@@ -215,3 +294,26 @@ where
}) })
} }
} }
/// The session representing the caller of the Admin API;
/// could either be an OAuth session or a personal session.
pub enum CallerSession {
OAuth2Session(Session),
PersonalSession(PersonalSession),
}
impl CallerSession {
pub fn scope(&self) -> &Scope {
match self {
CallerSession::OAuth2Session(session) => &session.scope,
CallerSession::PersonalSession(session) => &session.scope,
}
}
pub fn user_id(&self) -> Option<Ulid> {
match self {
CallerSession::OAuth2Session(session) => session.user_id,
CallerSession::PersonalSession(session) => Some(session.actor_user_id),
}
}
}

View File

@@ -20,7 +20,7 @@ use axum::{
use hyper::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE}; use hyper::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE};
use indexmap::IndexMap; use indexmap::IndexMap;
use mas_axum_utils::InternalError; use mas_axum_utils::InternalError;
use mas_data_model::BoxRng; use mas_data_model::{AppVersion, BoxRng, SiteConfig};
use mas_http::CorsLayerExt; use mas_http::CorsLayerExt;
use mas_matrix::HomeserverConnection; use mas_matrix::HomeserverConnection;
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
@@ -29,6 +29,7 @@ use mas_router::{
UrlBuilder, UrlBuilder,
}; };
use mas_templates::{ApiDocContext, Templates}; use mas_templates::{ApiDocContext, Templates};
use schemars::transform::AddNullable;
use tower_http::cors::{Any, CorsLayer}; use tower_http::cors::{Any, CorsLayer};
mod call_context; mod call_context;
@@ -43,6 +44,11 @@ use crate::passwords::PasswordManager;
fn finish(t: TransformOpenApi) -> TransformOpenApi { fn finish(t: TransformOpenApi) -> TransformOpenApi {
t.title("Matrix Authentication Service admin API") t.title("Matrix Authentication Service admin API")
.tag(Tag {
name: "server".to_owned(),
description: Some("Information about the server".to_owned()),
..Tag::default()
})
.tag(Tag { .tag(Tag {
name: "compat-session".to_owned(), name: "compat-session".to_owned(),
description: Some("Manage compatibility sessions from legacy clients".to_owned()), description: Some("Manage compatibility sessions from legacy clients".to_owned()),
@@ -86,6 +92,11 @@ fn finish(t: TransformOpenApi) -> TransformOpenApi {
), ),
..Default::default() ..Default::default()
}) })
.tag(Tag {
name: "upstream-oauth-provider".to_owned(),
description: Some("Manage upstream OAuth 2.0 providers".to_owned()),
..Tag::default()
})
.security_scheme("oauth2", oauth_security_scheme(None)) .security_scheme("oauth2", oauth_security_scheme(None))
.security_scheme( .security_scheme(
"token", "token",
@@ -153,14 +164,24 @@ where
Templates: FromRef<S>, Templates: FromRef<S>,
UrlBuilder: FromRef<S>, UrlBuilder: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>, Arc<PolicyFactory>: FromRef<S>,
SiteConfig: FromRef<S>,
AppVersion: FromRef<S>,
{ {
// We *always* want to explicitly set the possible responses, beacuse the // We *always* want to explicitly set the possible responses, beacuse the
// infered ones are not necessarily correct // infered ones are not necessarily correct
aide::generate::infer_responses(false); aide::generate::infer_responses(false);
aide::generate::in_context(|ctx| { aide::generate::in_context(|ctx| {
ctx.schema = ctx.schema = schemars::generate::SchemaGenerator::new(
schemars::r#gen::SchemaGenerator::new(schemars::r#gen::SchemaSettings::openapi3()); schemars::generate::SchemaSettings::openapi3().with(|settings| {
// Remove the transform which adds nullable fields, as it's not
// valid with OpenAPI 3.1. For some reason, aide/schemars output
// an OpenAPI 3.1 schema with this nullable transform.
settings
.transforms
.retain(|transform| !transform.is::<AddNullable>());
}),
);
}); });
let mut api = OpenApi::default(); let mut api = OpenApi::default();

View File

@@ -7,9 +7,16 @@
use std::net::IpAddr; use std::net::IpAddr;
use chrono::{DateTime, Utc}; use chrono::{DateTime, Utc};
use mas_data_model::Device; use mas_data_model::{
Device,
personal::{
PersonalAccessToken as DataModelPersonalAccessToken,
session::{PersonalSession as DataModelPersonalSession, PersonalSessionOwner},
},
};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::Serialize; use serde::Serialize;
use thiserror::Error;
use ulid::Ulid; use ulid::Ulid;
use url::Url; use url::Url;
@@ -52,6 +59,9 @@ pub struct User {
/// Whether the user can request admin privileges. /// Whether the user can request admin privileges.
admin: bool, admin: bool,
/// Whether the user was a guest before migrating to MAS,
legacy_guest: bool,
} }
impl User { impl User {
@@ -65,6 +75,7 @@ impl User {
locked_at: None, locked_at: None,
deactivated_at: None, deactivated_at: None,
admin: false, admin: false,
legacy_guest: false,
}, },
Self { Self {
id: Ulid::from_bytes([0x02; 16]), id: Ulid::from_bytes([0x02; 16]),
@@ -73,6 +84,7 @@ impl User {
locked_at: None, locked_at: None,
deactivated_at: None, deactivated_at: None,
admin: true, admin: true,
legacy_guest: false,
}, },
Self { Self {
id: Ulid::from_bytes([0x03; 16]), id: Ulid::from_bytes([0x03; 16]),
@@ -81,6 +93,7 @@ impl User {
locked_at: Some(DateTime::default()), locked_at: Some(DateTime::default()),
deactivated_at: None, deactivated_at: None,
admin: false, admin: false,
legacy_guest: true,
}, },
] ]
} }
@@ -95,6 +108,7 @@ impl From<mas_data_model::User> for User {
locked_at: user.locked_at, locked_at: user.locked_at,
deactivated_at: user.deactivated_at, deactivated_at: user.deactivated_at,
admin: user.can_request_admin, admin: user.can_request_admin,
legacy_guest: user.is_guest,
} }
} }
} }
@@ -688,3 +702,255 @@ impl UserRegistrationToken {
] ]
} }
} }
/// An upstream OAuth 2.0 provider
#[derive(Serialize, JsonSchema)]
pub struct UpstreamOAuthProvider {
#[serde(skip)]
id: Ulid,
/// The OIDC issuer of the provider
issuer: Option<String>,
/// A human-readable name for the provider
human_name: Option<String>,
/// A brand identifier, e.g. "apple" or "google"
brand_name: Option<String>,
/// When the provider was created
created_at: DateTime<Utc>,
/// When the provider was disabled. If null, the provider is enabled.
disabled_at: Option<DateTime<Utc>>,
}
impl From<mas_data_model::UpstreamOAuthProvider> for UpstreamOAuthProvider {
fn from(provider: mas_data_model::UpstreamOAuthProvider) -> Self {
Self {
id: provider.id,
issuer: provider.issuer,
human_name: provider.human_name,
brand_name: provider.brand_name,
created_at: provider.created_at,
disabled_at: provider.disabled_at,
}
}
}
impl Resource for UpstreamOAuthProvider {
const KIND: &'static str = "upstream-oauth-provider";
const PATH: &'static str = "/api/admin/v1/upstream-oauth-providers";
fn id(&self) -> Ulid {
self.id
}
}
impl UpstreamOAuthProvider {
/// Samples of upstream OAuth 2.0 providers
pub fn samples() -> [Self; 3] {
[
Self {
id: Ulid::from_bytes([0x01; 16]),
issuer: Some("https://accounts.google.com".to_owned()),
human_name: Some("Google".to_owned()),
brand_name: Some("google".to_owned()),
created_at: DateTime::default(),
disabled_at: None,
},
Self {
id: Ulid::from_bytes([0x02; 16]),
issuer: Some("https://appleid.apple.com".to_owned()),
human_name: Some("Apple ID".to_owned()),
brand_name: Some("apple".to_owned()),
created_at: DateTime::default(),
disabled_at: Some(DateTime::default()),
},
Self {
id: Ulid::from_bytes([0x03; 16]),
issuer: None,
human_name: Some("Custom OAuth Provider".to_owned()),
brand_name: None,
created_at: DateTime::default(),
disabled_at: None,
},
]
}
}
/// An error that shouldn't happen in practice, but suggests database
/// inconsistency.
#[derive(Debug, Error)]
#[error(
"personal session {session_id} in inconsistent state: not revoked but no valid access token"
)]
pub struct InconsistentPersonalSession {
pub session_id: Ulid,
}
// Note: we don't expose a separate concept of personal access tokens to the
// admin API; we merge the relevant attributes into the personal session.
/// A personal session (session using personal access tokens)
#[derive(Serialize, JsonSchema)]
pub struct PersonalSession {
#[serde(skip)]
id: Ulid,
/// When the session was created
created_at: DateTime<Utc>,
/// When the session was revoked, if applicable
revoked_at: Option<DateTime<Utc>>,
/// The ID of the user who owns this session (if user-owned)
#[schemars(with = "Option<super::schema::Ulid>")]
owner_user_id: Option<Ulid>,
/// The ID of the `OAuth2` client that owns this session (if client-owned)
#[schemars(with = "Option<super::schema::Ulid>")]
owner_client_id: Option<Ulid>,
/// The ID of the user that the session acts on behalf of
#[schemars(with = "super::schema::Ulid")]
actor_user_id: Ulid,
/// Human-readable name for the session
human_name: String,
/// `OAuth2` scopes for this session
scope: String,
/// When the session was last active
last_active_at: Option<DateTime<Utc>>,
/// IP address of last activity
last_active_ip: Option<IpAddr>,
/// When the current token for this session expires.
/// The session will need to be regenerated, producing a new access token,
/// after this time.
/// None if the current token won't expire or if the session is revoked.
expires_at: Option<DateTime<Utc>>,
/// The actual access token (only returned on creation)
#[serde(skip_serializing_if = "Option::is_none")]
access_token: Option<String>,
}
impl
TryFrom<(
DataModelPersonalSession,
Option<DataModelPersonalAccessToken>,
)> for PersonalSession
{
type Error = InconsistentPersonalSession;
fn try_from(
(session, token): (
DataModelPersonalSession,
Option<DataModelPersonalAccessToken>,
),
) -> Result<Self, InconsistentPersonalSession> {
let expires_at = if let Some(token) = token {
token.expires_at
} else {
if !session.is_revoked() {
// No active token, but the session is not revoked.
return Err(InconsistentPersonalSession {
session_id: session.id,
});
}
None
};
let (owner_user_id, owner_client_id) = match session.owner {
PersonalSessionOwner::User(id) => (Some(id), None),
PersonalSessionOwner::OAuth2Client(id) => (None, Some(id)),
};
Ok(Self {
id: session.id,
created_at: session.created_at,
revoked_at: session.revoked_at(),
owner_user_id,
owner_client_id,
actor_user_id: session.actor_user_id,
human_name: session.human_name,
scope: session.scope.to_string(),
last_active_at: session.last_active_at,
last_active_ip: session.last_active_ip,
expires_at,
// If relevant, the caller will populate using `with_token` afterwards.
access_token: None,
})
}
}
impl Resource for PersonalSession {
const KIND: &'static str = "personal-session";
const PATH: &'static str = "/api/admin/v1/personal-sessions";
fn id(&self) -> Ulid {
self.id
}
}
impl PersonalSession {
/// Sample personal sessions for documentation/testing
pub fn samples() -> [Self; 3] {
[
Self {
id: Ulid::from_string("01FSHN9AG0AJ6AC5HQ9X6H4RP4").unwrap(),
created_at: DateTime::from_timestamp(1_642_338_000, 0).unwrap(), /* 2022-01-16T14:
* 40:00Z */
revoked_at: None,
owner_user_id: Some(Ulid::from_string("01FSHN9AG0MZAA6S4AF7CTV32E").unwrap()),
owner_client_id: None,
actor_user_id: Ulid::from_string("01FSHN9AG0MZAA6S4AF7CTV32E").unwrap(),
human_name: "Alice's Development Token".to_owned(),
scope: "openid urn:matrix:org.matrix.msc2967.client:api:*".to_owned(),
last_active_at: Some(DateTime::from_timestamp(1_642_347_000, 0).unwrap()), /* 2022-01-16T17:10:00Z */
last_active_ip: Some("192.168.1.100".parse().unwrap()),
expires_at: None,
access_token: None,
},
Self {
id: Ulid::from_string("01FSHN9AG0BJ6AC5HQ9X6H4RP5").unwrap(),
created_at: DateTime::from_timestamp(1_642_338_060, 0).unwrap(), /* 2022-01-16T14:
* 41:00Z */
revoked_at: Some(DateTime::from_timestamp(1_642_350_000, 0).unwrap()), /* 2022-01-16T18:00:00Z */
owner_user_id: Some(Ulid::from_string("01FSHN9AG0NZAA6S4AF7CTV32F").unwrap()),
owner_client_id: None,
actor_user_id: Ulid::from_string("01FSHN9AG0NZAA6S4AF7CTV32F").unwrap(),
human_name: "Bob's Mobile App".to_owned(),
scope: "openid".to_owned(),
last_active_at: Some(DateTime::from_timestamp(1_642_349_000, 0).unwrap()), /* 2022-01-16T17:43:20Z */
last_active_ip: Some("10.0.0.50".parse().unwrap()),
expires_at: None,
access_token: None,
},
Self {
id: Ulid::from_string("01FSHN9AG0CJ6AC5HQ9X6H4RP6").unwrap(),
created_at: DateTime::from_timestamp(1_642_338_120, 0).unwrap(), /* 2022-01-16T14:
* 42:00Z */
revoked_at: None,
owner_user_id: None,
owner_client_id: Some(Ulid::from_string("01FSHN9AG0DJ6AC5HQ9X6H4RP7").unwrap()),
actor_user_id: Ulid::from_string("01FSHN9AG0MZAA6S4AF7CTV32E").unwrap(),
human_name: "CI/CD Pipeline Token".to_owned(),
scope: "openid urn:mas:admin".to_owned(),
last_active_at: Some(DateTime::from_timestamp(1_642_348_000, 0).unwrap()), /* 2022-01-16T17:26:40Z */
last_active_ip: Some("203.0.113.10".parse().unwrap()),
expires_at: Some(DateTime::from_timestamp(1_642_999_000, 0).unwrap()),
access_token: None,
},
]
}
/// Add the actual token value (for use in creation responses)
pub fn with_token(mut self, access_token: String) -> Self {
self.access_token = Some(access_token);
self
}
}

View File

@@ -7,17 +7,15 @@
// Generated code from schemars violates this rule // Generated code from schemars violates this rule
#![allow(clippy::str_to_string)] #![allow(clippy::str_to_string)]
use std::num::NonZeroUsize; use std::{borrow::Cow, num::NonZeroUsize};
use aide::OperationIo; use aide::OperationIo;
use axum::{ use axum::{
Json, Json,
extract::{ extract::{FromRequestParts, Path, rejection::PathRejection},
FromRequestParts, Path, Query,
rejection::{PathRejection, QueryRejection},
},
response::IntoResponse, response::IntoResponse,
}; };
use axum_extra::extract::{Query, QueryRejection};
use axum_macros::FromRequestParts; use axum_macros::FromRequestParts;
use hyper::StatusCode; use hyper::StatusCode;
use mas_storage::pagination::PaginationDirection; use mas_storage::pagination::PaginationDirection;
@@ -64,6 +62,34 @@ impl std::ops::Deref for UlidPathParam {
/// The default page size if not specified /// The default page size if not specified
const DEFAULT_PAGE_SIZE: usize = 10; const DEFAULT_PAGE_SIZE: usize = 10;
#[derive(Deserialize, JsonSchema, Clone, Copy, Default, Debug)]
pub enum IncludeCount {
/// Include the total number of items (default)
#[default]
#[serde(rename = "true")]
True,
/// Do not include the total number of items
#[serde(rename = "false")]
False,
/// Only include the total number of items, skip the items themselves
#[serde(rename = "only")]
Only,
}
impl IncludeCount {
pub(crate) fn add_to_base(self, base: &str) -> Cow<'_, str> {
let separator = if base.contains('?') { '&' } else { '?' };
match self {
// This is the default, don't add anything
Self::True => Cow::Borrowed(base),
Self::False => format!("{base}{separator}count=false").into(),
Self::Only => format!("{base}{separator}count=only").into(),
}
}
}
#[derive(Deserialize, JsonSchema, Clone, Copy)] #[derive(Deserialize, JsonSchema, Clone, Copy)]
struct PaginationParams { struct PaginationParams {
/// Retrieve the items before the given ID /// Retrieve the items before the given ID
@@ -83,6 +109,10 @@ struct PaginationParams {
/// Retrieve the last N items /// Retrieve the last N items
#[serde(rename = "page[last]")] #[serde(rename = "page[last]")]
last: Option<NonZeroUsize>, last: Option<NonZeroUsize>,
/// Include the total number of items. Defaults to `true`.
#[serde(rename = "count")]
include_count: Option<IncludeCount>,
} }
#[derive(Debug, thiserror::Error)] #[derive(Debug, thiserror::Error)]
@@ -107,7 +137,7 @@ impl IntoResponse for PaginationRejection {
/// An extractor for pagination parameters in the query string /// An extractor for pagination parameters in the query string
#[derive(OperationIo, Debug, Clone, Copy)] #[derive(OperationIo, Debug, Clone, Copy)]
#[aide(input_with = "Query<PaginationParams>")] #[aide(input_with = "Query<PaginationParams>")]
pub struct Pagination(pub mas_storage::Pagination); pub struct Pagination(pub mas_storage::Pagination, pub IncludeCount);
impl<S: Send + Sync> FromRequestParts<S> for Pagination { impl<S: Send + Sync> FromRequestParts<S> for Pagination {
type Rejection = PaginationRejection; type Rejection = PaginationRejection;
@@ -130,11 +160,14 @@ impl<S: Send + Sync> FromRequestParts<S> for Pagination {
(None, Some(last)) => (PaginationDirection::Backward, last.into()), (None, Some(last)) => (PaginationDirection::Backward, last.into()),
}; };
Ok(Self(mas_storage::Pagination { Ok(Self(
before: params.before, mas_storage::Pagination {
after: params.after, before: params.before,
direction, after: params.after,
count, direction,
})) count,
},
params.include_count.unwrap_or_default(),
))
} }
} }

View File

@@ -6,7 +6,7 @@
#![allow(clippy::module_name_repetitions)] #![allow(clippy::module_name_repetitions)]
use mas_storage::Pagination; use mas_storage::{Pagination, pagination::Edge};
use schemars::JsonSchema; use schemars::JsonSchema;
use serde::Serialize; use serde::Serialize;
use ulid::Ulid; use ulid::Ulid;
@@ -21,10 +21,12 @@ struct PaginationLinks {
self_: String, self_: String,
/// The link to the first page of results /// The link to the first page of results
first: String, #[serde(skip_serializing_if = "Option::is_none")]
first: Option<String>,
/// The link to the last page of results /// The link to the last page of results
last: String, #[serde(skip_serializing_if = "Option::is_none")]
last: Option<String>,
/// The link to the next page of results /// The link to the next page of results
/// ///
@@ -42,17 +44,27 @@ struct PaginationLinks {
#[derive(Serialize, JsonSchema)] #[derive(Serialize, JsonSchema)]
struct PaginationMeta { struct PaginationMeta {
/// The total number of results /// The total number of results
count: usize, #[serde(skip_serializing_if = "Option::is_none")]
count: Option<usize>,
}
impl PaginationMeta {
fn is_empty(&self) -> bool {
self.count.is_none()
}
} }
/// A top-level response with a page of resources /// A top-level response with a page of resources
#[derive(Serialize, JsonSchema)] #[derive(Serialize, JsonSchema)]
pub struct PaginatedResponse<T> { pub struct PaginatedResponse<T> {
/// Response metadata /// Response metadata
#[serde(skip_serializing_if = "PaginationMeta::is_empty")]
#[schemars(with = "Option<PaginationMeta>")]
meta: PaginationMeta, meta: PaginationMeta,
/// The list of resources /// The list of resources
data: Vec<SingleResource<T>>, #[serde(skip_serializing_if = "Option::is_none")]
data: Option<Vec<SingleResource<T>>>,
/// Related links /// Related links
links: PaginationLinks, links: PaginationLinks,
@@ -87,22 +99,28 @@ fn url_with_pagination(base: &str, pagination: Pagination) -> String {
} }
impl<T: Resource> PaginatedResponse<T> { impl<T: Resource> PaginatedResponse<T> {
pub fn new( pub fn for_page(
page: mas_storage::Page<T>, page: mas_storage::Page<T>,
current_pagination: Pagination, current_pagination: Pagination,
count: usize, count: Option<usize>,
base: &str, base: &str,
) -> Self { ) -> Self {
let links = PaginationLinks { let links = PaginationLinks {
self_: url_with_pagination(base, current_pagination), self_: url_with_pagination(base, current_pagination),
first: url_with_pagination(base, Pagination::first(current_pagination.count)), first: Some(url_with_pagination(
last: url_with_pagination(base, Pagination::last(current_pagination.count)), base,
Pagination::first(current_pagination.count),
)),
last: Some(url_with_pagination(
base,
Pagination::last(current_pagination.count),
)),
next: page.has_next_page.then(|| { next: page.has_next_page.then(|| {
url_with_pagination( url_with_pagination(
base, base,
current_pagination current_pagination
.clear_before() .clear_before()
.after(page.edges.last().unwrap().id()), .after(page.edges.last().unwrap().cursor),
) )
}), }),
prev: if page.has_previous_page { prev: if page.has_previous_page {
@@ -110,18 +128,38 @@ impl<T: Resource> PaginatedResponse<T> {
base, base,
current_pagination current_pagination
.clear_after() .clear_after()
.before(page.edges.first().unwrap().id()), .before(page.edges.first().unwrap().cursor),
)) ))
} else { } else {
None None
}, },
}; };
let data = page.edges.into_iter().map(SingleResource::new).collect(); let data = page
.edges
.into_iter()
.map(SingleResource::from_edge)
.collect();
Self { Self {
meta: PaginationMeta { count }, meta: PaginationMeta { count },
data, data: Some(data),
links,
}
}
pub fn for_count_only(count: usize, base: &str) -> Self {
let links = PaginationLinks {
self_: base.to_owned(),
first: None,
last: None,
next: None,
prev: None,
};
Self {
meta: PaginationMeta { count: Some(count) },
data: None,
links, links,
} }
} }
@@ -143,6 +181,32 @@ struct SingleResource<T> {
/// Related links /// Related links
links: SelfLinks, links: SelfLinks,
/// Metadata about the resource
#[serde(skip_serializing_if = "SingleResourceMeta::is_empty")]
#[schemars(with = "Option<SingleResourceMeta>")]
meta: SingleResourceMeta,
}
/// Metadata associated with a resource
#[derive(Serialize, JsonSchema)]
struct SingleResourceMeta {
/// Information about the pagination of the resource
#[serde(skip_serializing_if = "Option::is_none")]
page: Option<SingleResourceMetaPage>,
}
impl SingleResourceMeta {
fn is_empty(&self) -> bool {
self.page.is_none()
}
}
/// Pagination metadata for a resource
#[derive(Serialize, JsonSchema)]
struct SingleResourceMetaPage {
/// The cursor of this resource in the paginated result
cursor: String,
} }
impl<T: Resource> SingleResource<T> { impl<T: Resource> SingleResource<T> {
@@ -153,8 +217,16 @@ impl<T: Resource> SingleResource<T> {
id: resource.id(), id: resource.id(),
attributes: resource, attributes: resource,
links: SelfLinks { self_ }, links: SelfLinks { self_ },
meta: SingleResourceMeta { page: None },
} }
} }
fn from_edge<C: ToString>(edge: Edge<T, C>) -> Self {
let cursor = edge.cursor.to_string();
let mut resource = Self::new(edge.node);
resource.meta.page = Some(SingleResourceMetaPage { cursor });
resource
}
} }
/// Related links /// Related links

View File

@@ -6,11 +6,9 @@
//! Common schema definitions //! Common schema definitions
use schemars::{ use std::borrow::Cow;
JsonSchema,
r#gen::SchemaGenerator, use schemars::{JsonSchema, Schema, SchemaGenerator, json_schema};
schema::{InstanceType, Metadata, Schema, SchemaObject, StringValidation},
};
/// A type to use for schema definitions of ULIDs /// A type to use for schema definitions of ULIDs
/// ///
@@ -18,32 +16,21 @@ use schemars::{
pub struct Ulid; pub struct Ulid;
impl JsonSchema for Ulid { impl JsonSchema for Ulid {
fn schema_name() -> String { fn schema_name() -> Cow<'static, str> {
"ULID".to_owned() Cow::Borrowed("ULID")
} }
fn json_schema(_gen: &mut SchemaGenerator) -> Schema { fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
SchemaObject { json_schema!({
instance_type: Some(InstanceType::String.into()), "type": "string",
"title": "ULID",
metadata: Some(Box::new(Metadata { "description": "A ULID as per https://github.com/ulid/spec",
title: Some("ULID".into()), "examples": [
description: Some("A ULID as per https://github.com/ulid/spec".into()), "01ARZ3NDEKTSV4RRFFQ69G5FAV",
examples: vec![ "01J41912SC8VGAQDD50F6APK91",
"01ARZ3NDEKTSV4RRFFQ69G5FAV".into(), ],
"01J41912SC8VGAQDD50F6APK91".into(), "pattern": "^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$",
], })
..Metadata::default()
})),
string: Some(Box::new(StringValidation {
pattern: Some(r"^[0123456789ABCDEFGHJKMNPQRSTVWXYZ]{26}$".into()),
..StringValidation::default()
})),
..SchemaObject::default()
}
.into()
} }
} }
@@ -53,27 +40,20 @@ impl JsonSchema for Ulid {
pub struct Device; pub struct Device;
impl JsonSchema for Device { impl JsonSchema for Device {
fn schema_name() -> String { fn schema_name() -> Cow<'static, str> {
"DeviceID".to_owned() Cow::Borrowed("DeviceID")
} }
fn json_schema(_gen: &mut SchemaGenerator) -> Schema { fn json_schema(_gen: &mut SchemaGenerator) -> Schema {
SchemaObject { json_schema!({
instance_type: Some(InstanceType::String.into()), "type": "string",
"title": "Device ID",
metadata: Some(Box::new(Metadata { "description": "A device ID as per https://matrix.org/docs/spec/client_server/r0.6.0#device-ids",
title: Some("Device ID".into()), "examples": [
examples: vec!["AABBCCDDEE".into(), "FFGGHHIIJJ".into()], "AABBCCDDEE",
..Metadata::default() "FFGGHHIIJJ",
})), ],
"pattern": "^[A-Za-z0-9._~!$&'()*+,;=:&/-]+$",
string: Some(Box::new(StringValidation { })
pattern: Some(r"^[A-Za-z0-9._~!$&'()*+,;=:&/-]+$".into()),
..StringValidation::default()
})),
..SchemaObject::default()
}
.into()
} }
} }

View File

@@ -0,0 +1,243 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_data_model::BoxRng;
use mas_storage::queue::{QueueJobRepositoryExt as _, SyncDevicesJob};
use ulid::Ulid;
use crate::{
admin::{
call_context::CallContext,
model::{CompatSession, Resource},
params::UlidPathParam,
response::{ErrorResponse, SingleResponse},
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("Compatibility session with ID {0} not found")]
NotFound(Ulid),
#[error("Compatibility session with ID {0} is already finished")]
AlreadyFinished(Ulid),
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
Self::AlreadyFinished(_) => StatusCode::BAD_REQUEST,
};
(status, sentry_event_id, Json(error)).into_response()
}
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("finishCompatSession")
.summary("Finish a compatibility session")
.description(
"Calling this endpoint will finish the compatibility session, preventing any further use. A job will be scheduled to sync the user's devices with the homeserver.",
)
.tag("compat-session")
.response_with::<200, Json<SingleResponse<CompatSession>>, _>(|t| {
// Get the finished session sample
let [_, finished_session, _] = CompatSession::samples();
let id = finished_session.id();
let response = SingleResponse::new(
finished_session,
format!("/api/admin/v1/compat-sessions/{id}/finish"),
);
t.description("Compatibility session was finished").example(response)
})
.response_with::<400, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::AlreadyFinished(Ulid::nil()));
t.description("Session is already finished")
.example(response)
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::NotFound(Ulid::nil()));
t.description("Compatibility session was not found")
.example(response)
})
}
#[tracing::instrument(name = "handler.admin.v1.compat_sessions.finish", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..
}: CallContext,
NoApi(mut rng): NoApi<BoxRng>,
id: UlidPathParam,
) -> Result<Json<SingleResponse<CompatSession>>, RouteError> {
let id = *id;
let session = repo
.compat_session()
.lookup(id)
.await?
.ok_or(RouteError::NotFound(id))?;
// Check if the session is already finished
if session.finished_at().is_some() {
return Err(RouteError::AlreadyFinished(id));
}
// Schedule a job to sync the devices of the user with the homeserver
tracing::info!(user.id = %session.user_id, "Scheduling device sync job for user");
repo.queue_job()
.schedule_job(
&mut rng,
&clock,
SyncDevicesJob::new_for_id(session.user_id),
)
.await?;
// Finish the session
let session = repo.compat_session().finish(&clock, session).await?;
// Get the SSO login info for the response
let sso_login = repo.compat_sso_login().find_for_session(&session).await?;
repo.save().await?;
Ok(Json(SingleResponse::new(
CompatSession::from((session, sso_login)),
format!("/api/admin/v1/compat-sessions/{id}/finish"),
)))
}
#[cfg(test)]
mod tests {
use chrono::Duration;
use hyper::{Request, StatusCode};
use mas_data_model::{Clock as _, Device};
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_finish_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let mut rng = state.rng();
// Provision a user and a compat session
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &state.clock, &user, device, None, false, None)
.await
.unwrap();
repo.save().await.unwrap();
let request = Request::post(format!(
"/api/admin/v1/compat-sessions/{}/finish",
session.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
// The finished_at timestamp should be the same as the current time
assert_eq!(
body["data"]["attributes"]["finished_at"],
serde_json::json!(state.clock.now())
);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_finish_already_finished_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let mut rng = state.rng();
// Provision a user and a compat session
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let device = Device::generate(&mut rng);
let session = repo
.compat_session()
.add(&mut rng, &state.clock, &user, device, None, false, None)
.await
.unwrap();
// Finish the session first
let session = repo
.compat_session()
.finish(&state.clock, session)
.await
.unwrap();
repo.save().await.unwrap();
// Move the clock forward
state.clock.advance(Duration::try_minutes(1).unwrap());
let request = Request::post(format!(
"/api/admin/v1/compat-sessions/{}/finish",
session.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let body: serde_json::Value = response.json();
assert_eq!(
body["errors"][0]["title"],
format!(
"Compatibility session with ID {} is already finished",
session.id
)
);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_finish_unknown_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let request =
Request::post("/api/admin/v1/compat-sessions/01040G2081040G2081040G2081/finish")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::NOT_FOUND);
let body: serde_json::Value = response.json();
assert_eq!(
body["errors"][0]["title"],
"Compatibility session with ID 01040G2081040G2081040G2081 not found"
);
}
}

View File

@@ -4,11 +4,8 @@
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use aide::{OperationIo, transform::TransformOperation}; use aide::{OperationIo, transform::TransformOperation};
use axum::{ use axum::{Json, response::IntoResponse};
Json, use axum_extra::extract::{Query, QueryRejection};
extract::{Query, rejection::QueryRejection},
response::IntoResponse,
};
use axum_macros::FromRequestParts; use axum_macros::FromRequestParts;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::record_error; use mas_axum_utils::record_error;
@@ -21,7 +18,7 @@ use crate::{
admin::{ admin::{
call_context::CallContext, call_context::CallContext,
model::{CompatSession, Resource}, model::{CompatSession, Resource},
params::Pagination, params::{IncludeCount, Pagination},
response::{ErrorResponse, PaginatedResponse}, response::{ErrorResponse, PaginatedResponse},
}, },
impl_from_error_for_route, impl_from_error_for_route,
@@ -137,16 +134,22 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
let sessions = CompatSession::samples(); let sessions = CompatSession::samples();
let pagination = mas_storage::Pagination::first(sessions.len()); let pagination = mas_storage::Pagination::first(sessions.len());
let page = Page { let page = Page {
edges: sessions.into(), edges: sessions
.into_iter()
.map(|node| mas_storage::pagination::Edge {
cursor: node.id(),
node,
})
.collect(),
has_next_page: true, has_next_page: true,
has_previous_page: false, has_previous_page: false,
}; };
t.description("Paginated response of compatibility sessions") t.description("Paginated response of compatibility sessions")
.example(PaginatedResponse::new( .example(PaginatedResponse::for_page(
page, page,
pagination, pagination,
42, Some(42),
CompatSession::PATH, CompatSession::PATH,
)) ))
}) })
@@ -159,10 +162,11 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
#[tracing::instrument(name = "handler.admin.v1.compat_sessions.list", skip_all)] #[tracing::instrument(name = "handler.admin.v1.compat_sessions.list", skip_all)]
pub async fn handler( pub async fn handler(
CallContext { mut repo, .. }: CallContext, CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination, Pagination(pagination, include_count): Pagination,
params: FilterParams, params: FilterParams,
) -> Result<Json<PaginatedResponse<CompatSession>>, RouteError> { ) -> Result<Json<PaginatedResponse<CompatSession>>, RouteError> {
let base = format!("{path}{params}", path = CompatSession::PATH); let base = format!("{path}{params}", path = CompatSession::PATH);
let base = include_count.add_to_base(&base);
let filter = CompatSessionFilter::default(); let filter = CompatSessionFilter::default();
// Load the user from the filter // Load the user from the filter
@@ -206,15 +210,31 @@ pub async fn handler(
None => filter, None => filter,
}; };
let page = repo.compat_session().list(filter, pagination).await?; let response = match include_count {
let count = repo.compat_session().count(filter).await?; IncludeCount::True => {
let page = repo
.compat_session()
.list(filter, pagination)
.await?
.map(CompatSession::from);
let count = repo.compat_session().count(filter).await?;
PaginatedResponse::for_page(page, pagination, Some(count), &base)
}
IncludeCount::False => {
let page = repo
.compat_session()
.list(filter, pagination)
.await?
.map(CompatSession::from);
PaginatedResponse::for_page(page, pagination, None, &base)
}
IncludeCount::Only => {
let count = repo.compat_session().count(filter).await?;
PaginatedResponse::for_count_only(count, &base)
}
};
Ok(Json(PaginatedResponse::new( Ok(Json(response))
page.map(CompatSession::from),
pagination,
count,
&base,
)))
} }
#[cfg(test)] #[cfg(test)]
@@ -299,6 +319,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/compat-sessions/01FSHNB530AAPR7PEV8KNBZD5Y" "self": "/api/admin/v1/compat-sessions/01FSHNB530AAPR7PEV8KNBZD5Y"
},
"meta": {
"page": {
"cursor": "01FSHNB530AAPR7PEV8KNBZD5Y"
}
} }
}, },
{ {
@@ -318,6 +343,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/compat-sessions/01FSHNCZP0PPF7X0EVMJNECPZW" "self": "/api/admin/v1/compat-sessions/01FSHNCZP0PPF7X0EVMJNECPZW"
},
"meta": {
"page": {
"cursor": "01FSHNCZP0PPF7X0EVMJNECPZW"
}
} }
} }
], ],
@@ -362,6 +392,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/compat-sessions/01FSHNB530AAPR7PEV8KNBZD5Y" "self": "/api/admin/v1/compat-sessions/01FSHNB530AAPR7PEV8KNBZD5Y"
},
"meta": {
"page": {
"cursor": "01FSHNB530AAPR7PEV8KNBZD5Y"
}
} }
} }
], ],
@@ -403,6 +438,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/compat-sessions/01FSHNB530AAPR7PEV8KNBZD5Y" "self": "/api/admin/v1/compat-sessions/01FSHNB530AAPR7PEV8KNBZD5Y"
},
"meta": {
"page": {
"cursor": "01FSHNB530AAPR7PEV8KNBZD5Y"
}
} }
} }
], ],
@@ -444,6 +484,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/compat-sessions/01FSHNCZP0PPF7X0EVMJNECPZW" "self": "/api/admin/v1/compat-sessions/01FSHNCZP0PPF7X0EVMJNECPZW"
},
"meta": {
"page": {
"cursor": "01FSHNCZP0PPF7X0EVMJNECPZW"
}
} }
} }
], ],
@@ -454,5 +499,155 @@ mod tests {
} }
} }
"#); "#);
// Test count=false
let request = Request::get("/api/admin/v1/compat-sessions?count=false")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "compat-session",
"id": "01FSHNB530AAPR7PEV8KNBZD5Y",
"attributes": {
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"device_id": "LoieH5Iecx",
"user_session_id": null,
"redirect_uri": null,
"created_at": "2022-01-16T14:41:00Z",
"user_agent": null,
"last_active_at": null,
"last_active_ip": null,
"finished_at": null,
"human_name": null
},
"links": {
"self": "/api/admin/v1/compat-sessions/01FSHNB530AAPR7PEV8KNBZD5Y"
},
"meta": {
"page": {
"cursor": "01FSHNB530AAPR7PEV8KNBZD5Y"
}
}
},
{
"type": "compat-session",
"id": "01FSHNCZP0PPF7X0EVMJNECPZW",
"attributes": {
"user_id": "01FSHNB530AJ6AC5HQ9X6H4RP4",
"device_id": "ZXyvelQWW9",
"user_session_id": null,
"redirect_uri": null,
"created_at": "2022-01-16T14:42:00Z",
"user_agent": null,
"last_active_at": null,
"last_active_ip": null,
"finished_at": "2022-01-16T14:43:00Z",
"human_name": null
},
"links": {
"self": "/api/admin/v1/compat-sessions/01FSHNCZP0PPF7X0EVMJNECPZW"
},
"meta": {
"page": {
"cursor": "01FSHNCZP0PPF7X0EVMJNECPZW"
}
}
}
],
"links": {
"self": "/api/admin/v1/compat-sessions?count=false&page[first]=10",
"first": "/api/admin/v1/compat-sessions?count=false&page[first]=10",
"last": "/api/admin/v1/compat-sessions?count=false&page[last]=10"
}
}
"#);
// Test count=only
let request = Request::get("/api/admin/v1/compat-sessions?count=only")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 2
},
"links": {
"self": "/api/admin/v1/compat-sessions?count=only"
}
}
"#);
// Test count=false with filtering
let request = Request::get(format!(
"/api/admin/v1/compat-sessions?count=false&filter[user]={}",
alice.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "compat-session",
"id": "01FSHNB530AAPR7PEV8KNBZD5Y",
"attributes": {
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"device_id": "LoieH5Iecx",
"user_session_id": null,
"redirect_uri": null,
"created_at": "2022-01-16T14:41:00Z",
"user_agent": null,
"last_active_at": null,
"last_active_ip": null,
"finished_at": null,
"human_name": null
},
"links": {
"self": "/api/admin/v1/compat-sessions/01FSHNB530AAPR7PEV8KNBZD5Y"
},
"meta": {
"page": {
"cursor": "01FSHNB530AAPR7PEV8KNBZD5Y"
}
}
}
],
"links": {
"self": "/api/admin/v1/compat-sessions?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[first]=10",
"first": "/api/admin/v1/compat-sessions?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[first]=10",
"last": "/api/admin/v1/compat-sessions?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[last]=10"
}
}
"#);
// Test count=only with filtering
let request =
Request::get("/api/admin/v1/compat-sessions?count=only&filter[status]=active")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 1
},
"links": {
"self": "/api/admin/v1/compat-sessions?filter[status]=active&count=only"
}
}
"#);
} }
} }

View File

@@ -3,10 +3,12 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
mod finish;
mod get; mod get;
mod list; mod list;
pub use self::{ pub use self::{
finish::{doc as finish_doc, handler as finish},
get::{doc as get_doc, handler as get}, get::{doc as get_doc, handler as get},
list::{doc as list_doc, handler as list}, list::{doc as list_doc, handler as list},
}; };

View File

@@ -11,7 +11,7 @@ use aide::axum::{
routing::{get_with, post_with}, routing::{get_with, post_with},
}; };
use axum::extract::{FromRef, FromRequestParts}; use axum::extract::{FromRef, FromRequestParts};
use mas_data_model::BoxRng; use mas_data_model::{AppVersion, BoxRng, SiteConfig};
use mas_matrix::HomeserverConnection; use mas_matrix::HomeserverConnection;
use mas_policy::PolicyFactory; use mas_policy::PolicyFactory;
@@ -20,23 +20,37 @@ use crate::passwords::PasswordManager;
mod compat_sessions; mod compat_sessions;
mod oauth2_sessions; mod oauth2_sessions;
mod personal_sessions;
mod policy_data; mod policy_data;
mod site_config;
mod upstream_oauth_links; mod upstream_oauth_links;
mod upstream_oauth_providers;
mod user_emails; mod user_emails;
mod user_registration_tokens; mod user_registration_tokens;
mod user_sessions; mod user_sessions;
mod users; mod users;
mod version;
pub fn router<S>() -> ApiRouter<S> pub fn router<S>() -> ApiRouter<S>
where where
S: Clone + Send + Sync + 'static, S: Clone + Send + Sync + 'static,
Arc<dyn HomeserverConnection>: FromRef<S>, Arc<dyn HomeserverConnection>: FromRef<S>,
PasswordManager: FromRef<S>, PasswordManager: FromRef<S>,
SiteConfig: FromRef<S>,
AppVersion: FromRef<S>,
Arc<PolicyFactory>: FromRef<S>, Arc<PolicyFactory>: FromRef<S>,
BoxRng: FromRequestParts<S>, BoxRng: FromRequestParts<S>,
CallContext: FromRequestParts<S>, CallContext: FromRequestParts<S>,
{ {
ApiRouter::<S>::new() ApiRouter::<S>::new()
.api_route(
"/site-config",
get_with(self::site_config::handler, self::site_config::doc),
)
.api_route(
"/version",
get_with(self::version::handler, self::version::doc),
)
.api_route( .api_route(
"/compat-sessions", "/compat-sessions",
get_with(self::compat_sessions::list, self::compat_sessions::list_doc), get_with(self::compat_sessions::list, self::compat_sessions::list_doc),
@@ -45,6 +59,13 @@ where
"/compat-sessions/{id}", "/compat-sessions/{id}",
get_with(self::compat_sessions::get, self::compat_sessions::get_doc), get_with(self::compat_sessions::get, self::compat_sessions::get_doc),
) )
.api_route(
"/compat-sessions/{id}/finish",
post_with(
self::compat_sessions::finish,
self::compat_sessions::finish_doc,
),
)
.api_route( .api_route(
"/oauth2-sessions", "/oauth2-sessions",
get_with(self::oauth2_sessions::list, self::oauth2_sessions::list_doc), get_with(self::oauth2_sessions::list, self::oauth2_sessions::list_doc),
@@ -53,6 +74,45 @@ where
"/oauth2-sessions/{id}", "/oauth2-sessions/{id}",
get_with(self::oauth2_sessions::get, self::oauth2_sessions::get_doc), get_with(self::oauth2_sessions::get, self::oauth2_sessions::get_doc),
) )
.api_route(
"/oauth2-sessions/{id}/finish",
post_with(
self::oauth2_sessions::finish,
self::oauth2_sessions::finish_doc,
),
)
.api_route(
"/personal-sessions",
get_with(
self::personal_sessions::list,
self::personal_sessions::list_doc,
)
.post_with(
self::personal_sessions::add,
self::personal_sessions::add_doc,
),
)
.api_route(
"/personal-sessions/{id}",
get_with(
self::personal_sessions::get,
self::personal_sessions::get_doc,
),
)
.api_route(
"/personal-sessions/{id}/revoke",
post_with(
self::personal_sessions::revoke,
self::personal_sessions::revoke_doc,
),
)
.api_route(
"/personal-sessions/{id}/regenerate",
post_with(
self::personal_sessions::regenerate,
self::personal_sessions::regenerate_doc,
),
)
.api_route( .api_route(
"/policy-data", "/policy-data",
post_with(self::policy_data::set, self::policy_data::set_doc), post_with(self::policy_data::set, self::policy_data::set_doc),
@@ -123,6 +183,10 @@ where
"/user-sessions/{id}", "/user-sessions/{id}",
get_with(self::user_sessions::get, self::user_sessions::get_doc), get_with(self::user_sessions::get, self::user_sessions::get_doc),
) )
.api_route(
"/user-sessions/{id}/finish",
post_with(self::user_sessions::finish, self::user_sessions::finish_doc),
)
.api_route( .api_route(
"/user-registration-tokens", "/user-registration-tokens",
get_with( get_with(
@@ -181,4 +245,18 @@ where
self::upstream_oauth_links::delete_doc, self::upstream_oauth_links::delete_doc,
), ),
) )
.api_route(
"/upstream-oauth-providers",
get_with(
self::upstream_oauth_providers::list,
self::upstream_oauth_providers::list_doc,
),
)
.api_route(
"/upstream-oauth-providers/{id}",
get_with(
self::upstream_oauth_providers::get,
self::upstream_oauth_providers::get_doc,
),
)
} }

View File

@@ -0,0 +1,234 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_data_model::BoxRng;
use mas_storage::queue::{QueueJobRepositoryExt as _, SyncDevicesJob};
use ulid::Ulid;
use crate::{
admin::{
call_context::CallContext,
model::{OAuth2Session, Resource},
params::UlidPathParam,
response::{ErrorResponse, SingleResponse},
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("OAuth 2.0 session with ID {0} not found")]
NotFound(Ulid),
#[error("OAuth 2.0 session with ID {0} is already finished")]
AlreadyFinished(Ulid),
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
Self::AlreadyFinished(_) => StatusCode::BAD_REQUEST,
};
(status, sentry_event_id, Json(error)).into_response()
}
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("finishOAuth2Session")
.summary("Finish an OAuth 2.0 session")
.description(
"Calling this endpoint will finish the OAuth 2.0 session, preventing any further use. If the session has a user associated with it, a job will be scheduled to sync the user's devices with the homeserver.",
)
.tag("oauth2-session")
.response_with::<200, Json<SingleResponse<OAuth2Session>>, _>(|t| {
// Get the finished session sample
let [_, _, finished_session] = OAuth2Session::samples();
let id = finished_session.id();
let response = SingleResponse::new(
finished_session,
format!("/api/admin/v1/oauth2-sessions/{id}/finish"),
);
t.description("OAuth 2.0 session was finished").example(response)
})
.response_with::<400, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::AlreadyFinished(Ulid::nil()));
t.description("Session is already finished")
.example(response)
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::NotFound(Ulid::nil()));
t.description("OAuth 2.0 session was not found")
.example(response)
})
}
#[tracing::instrument(name = "handler.admin.v1.oauth2_sessions.finish", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..
}: CallContext,
NoApi(mut rng): NoApi<BoxRng>,
id: UlidPathParam,
) -> Result<Json<SingleResponse<OAuth2Session>>, RouteError> {
let id = *id;
let session = repo
.oauth2_session()
.lookup(id)
.await?
.ok_or(RouteError::NotFound(id))?;
// Check if the session is already finished
if session.finished_at().is_some() {
return Err(RouteError::AlreadyFinished(id));
}
// If the session has a user associated with it, schedule a job to sync devices
if let Some(user_id) = session.user_id {
tracing::info!(user.id = %user_id, "Scheduling device sync job for user");
let job = SyncDevicesJob::new_for_id(user_id);
repo.queue_job().schedule_job(&mut rng, &clock, job).await?;
}
// Finish the session
let session = repo.oauth2_session().finish(&clock, session).await?;
repo.save().await?;
Ok(Json(SingleResponse::new(
OAuth2Session::from(session),
format!("/api/admin/v1/oauth2-sessions/{id}/finish"),
)))
}
#[cfg(test)]
mod tests {
use chrono::Duration;
use hyper::{Request, StatusCode};
use mas_data_model::{AccessToken, Clock as _};
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_finish_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
// Get the session ID from the token we just created
let mut repo = state.repository().await.unwrap();
let AccessToken { session_id, .. } = repo
.oauth2_access_token()
.find_by_token(&token)
.await
.unwrap()
.unwrap();
repo.save().await.unwrap();
let request = Request::post(format!("/api/admin/v1/oauth2-sessions/{session_id}/finish"))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
// The finished_at timestamp should be the same as the current time
assert_eq!(
body["data"]["attributes"]["finished_at"],
serde_json::json!(state.clock.now())
);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_finish_already_finished_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
// Create first admin token for the API call
let admin_token = state.token_with_scope("urn:mas:admin").await;
// Create a second admin session that we'll finish
let second_admin_token = state.token_with_scope("urn:mas:admin").await;
// Get the second session and finish it first
let mut repo = state.repository().await.unwrap();
let AccessToken { session_id, .. } = repo
.oauth2_access_token()
.find_by_token(&second_admin_token)
.await
.unwrap()
.unwrap();
let session = repo
.oauth2_session()
.lookup(session_id)
.await
.unwrap()
.unwrap();
// Finish the session first
let session = repo
.oauth2_session()
.finish(&state.clock, session)
.await
.unwrap();
repo.save().await.unwrap();
// Move the clock forward
state.clock.advance(Duration::try_minutes(1).unwrap());
let request = Request::post(format!(
"/api/admin/v1/oauth2-sessions/{}/finish",
session.id
))
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let body: serde_json::Value = response.json();
assert_eq!(
body["errors"][0]["title"],
format!(
"OAuth 2.0 session with ID {} is already finished",
session.id
)
);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_finish_unknown_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let request =
Request::post("/api/admin/v1/oauth2-sessions/01040G2081040G2081040G2081/finish")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::NOT_FOUND);
let body: serde_json::Value = response.json();
assert_eq!(
body["errors"][0]["title"],
"OAuth 2.0 session with ID 01040G2081040G2081040G2081 not found"
);
}
}

View File

@@ -7,11 +7,8 @@
use std::str::FromStr; use std::str::FromStr;
use aide::{OperationIo, transform::TransformOperation}; use aide::{OperationIo, transform::TransformOperation};
use axum::{ use axum::{Json, response::IntoResponse};
Json, use axum_extra::extract::{Query, QueryRejection};
extract::{Query, rejection::QueryRejection},
response::IntoResponse,
};
use axum_macros::FromRequestParts; use axum_macros::FromRequestParts;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::record_error; use mas_axum_utils::record_error;
@@ -25,7 +22,7 @@ use crate::{
admin::{ admin::{
call_context::CallContext, call_context::CallContext,
model::{OAuth2Session, Resource}, model::{OAuth2Session, Resource},
params::Pagination, params::{IncludeCount, Pagination},
response::{ErrorResponse, PaginatedResponse}, response::{ErrorResponse, PaginatedResponse},
}, },
impl_from_error_for_route, impl_from_error_for_route,
@@ -192,16 +189,22 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
let sessions = OAuth2Session::samples(); let sessions = OAuth2Session::samples();
let pagination = mas_storage::Pagination::first(sessions.len()); let pagination = mas_storage::Pagination::first(sessions.len());
let page = Page { let page = Page {
edges: sessions.into(), edges: sessions
.into_iter()
.map(|node| mas_storage::pagination::Edge {
cursor: node.id(),
node,
})
.collect(),
has_next_page: true, has_next_page: true,
has_previous_page: false, has_previous_page: false,
}; };
t.description("Paginated response of OAuth 2.0 sessions") t.description("Paginated response of OAuth 2.0 sessions")
.example(PaginatedResponse::new( .example(PaginatedResponse::for_page(
page, page,
pagination, pagination,
42, Some(42),
OAuth2Session::PATH, OAuth2Session::PATH,
)) ))
}) })
@@ -218,10 +221,11 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
#[tracing::instrument(name = "handler.admin.v1.oauth2_sessions.list", skip_all)] #[tracing::instrument(name = "handler.admin.v1.oauth2_sessions.list", skip_all)]
pub async fn handler( pub async fn handler(
CallContext { mut repo, .. }: CallContext, CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination, Pagination(pagination, include_count): Pagination,
params: FilterParams, params: FilterParams,
) -> Result<Json<PaginatedResponse<OAuth2Session>>, RouteError> { ) -> Result<Json<PaginatedResponse<OAuth2Session>>, RouteError> {
let base = format!("{path}{params}", path = OAuth2Session::PATH); let base = format!("{path}{params}", path = OAuth2Session::PATH);
let base = include_count.add_to_base(&base);
let filter = OAuth2SessionFilter::default(); let filter = OAuth2SessionFilter::default();
// Load the user from the filter // Load the user from the filter
@@ -300,15 +304,31 @@ pub async fn handler(
None => filter, None => filter,
}; };
let page = repo.oauth2_session().list(filter, pagination).await?; let response = match include_count {
let count = repo.oauth2_session().count(filter).await?; IncludeCount::True => {
let page = repo
.oauth2_session()
.list(filter, pagination)
.await?
.map(OAuth2Session::from);
let count = repo.oauth2_session().count(filter).await?;
PaginatedResponse::for_page(page, pagination, Some(count), &base)
}
IncludeCount::False => {
let page = repo
.oauth2_session()
.list(filter, pagination)
.await?
.map(OAuth2Session::from);
PaginatedResponse::for_page(page, pagination, None, &base)
}
IncludeCount::Only => {
let count = repo.oauth2_session().count(filter).await?;
PaginatedResponse::for_count_only(count, &base)
}
};
Ok(Json(PaginatedResponse::new( Ok(Json(response))
page.map(OAuth2Session::from),
pagination,
count,
&base,
)))
} }
#[cfg(test)] #[cfg(test)]
@@ -354,6 +374,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/oauth2-sessions/01FSHN9AG0MKGTBNZ16RDR3PVY" "self": "/api/admin/v1/oauth2-sessions/01FSHN9AG0MKGTBNZ16RDR3PVY"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MKGTBNZ16RDR3PVY"
}
} }
} }
], ],
@@ -364,5 +389,66 @@ mod tests {
} }
} }
"#); "#);
// Test count=false
let request = Request::get("/api/admin/v1/oauth2-sessions?count=false")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "oauth2-session",
"id": "01FSHN9AG0MKGTBNZ16RDR3PVY",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"finished_at": null,
"user_id": null,
"user_session_id": null,
"client_id": "01FSHN9AG0FAQ50MT1E9FFRPZR",
"scope": "urn:mas:admin",
"user_agent": null,
"last_active_at": null,
"last_active_ip": null,
"human_name": null
},
"links": {
"self": "/api/admin/v1/oauth2-sessions/01FSHN9AG0MKGTBNZ16RDR3PVY"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MKGTBNZ16RDR3PVY"
}
}
}
],
"links": {
"self": "/api/admin/v1/oauth2-sessions?count=false&page[first]=10",
"first": "/api/admin/v1/oauth2-sessions?count=false&page[first]=10",
"last": "/api/admin/v1/oauth2-sessions?count=false&page[last]=10"
}
}
"#);
// Test count=only
let request = Request::get("/api/admin/v1/oauth2-sessions?count=only")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 1
},
"links": {
"self": "/api/admin/v1/oauth2-sessions?count=only"
}
}
"#);
} }
} }

View File

@@ -4,10 +4,12 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
mod finish;
mod get; mod get;
mod list; mod list;
pub use self::{ pub use self::{
finish::{doc as finish_doc, handler as finish},
get::{doc as get_doc, handler as get}, get::{doc as get_doc, handler as get},
list::{doc as list_doc, handler as list}, list::{doc as list_doc, handler as list},
}; };

View File

@@ -0,0 +1,311 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::sync::Arc;
use aide::{NoApi, OperationIo, transform::TransformOperation};
use anyhow::Context;
use axum::{Json, extract::State, response::IntoResponse};
use chrono::Duration;
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_data_model::{BoxRng, Device, TokenType};
use mas_matrix::HomeserverConnection;
use oauth2_types::scope::Scope;
use schemars::JsonSchema;
use serde::Deserialize;
use ulid::Ulid;
use crate::{
admin::{
call_context::CallContext,
model::{InconsistentPersonalSession, PersonalSession},
response::{ErrorResponse, SingleResponse},
v1::personal_sessions::personal_session_owner_from_caller,
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("User not found")]
UserNotFound,
#[error("User is not active")]
UserDeactivated,
#[error("Invalid scope")]
InvalidScope,
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(InconsistentPersonalSession);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::UserNotFound => StatusCode::NOT_FOUND,
Self::UserDeactivated => StatusCode::GONE,
Self::InvalidScope => StatusCode::BAD_REQUEST,
};
(status, sentry_event_id, Json(error)).into_response()
}
}
/// # JSON payload for the `POST /api/admin/v1/personal-sessions` endpoint
#[derive(Deserialize, JsonSchema)]
#[serde(rename = "CreatePersonalSessionRequest")]
pub struct Request {
/// The user this session will act on behalf of
#[schemars(with = "crate::admin::schema::Ulid")]
actor_user_id: Ulid,
/// Human-readable name for the session
human_name: String,
/// `OAuth2` scopes for this session
scope: String,
/// Token expiry time in seconds.
/// If not set, the token won't expire.
expires_in: Option<u32>,
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("createPersonalSession")
.summary("Create a new personal session with personal access token")
.tag("personal-session")
.response_with::<201, Json<SingleResponse<PersonalSession>>, _>(|t| {
t.description("Personal session and personal access token were created")
})
.response_with::<400, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::InvalidScope);
t.description("Invalid scope provided").example(response)
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::UserNotFound);
t.description("User was not found").example(response)
})
}
#[tracing::instrument(name = "handler.admin.v1.personal_sessions.add", skip_all)]
pub async fn handler(
CallContext {
mut repo,
clock,
session,
..
}: CallContext,
NoApi(mut rng): NoApi<BoxRng>,
NoApi(State(homeserver)): NoApi<State<Arc<dyn HomeserverConnection>>>,
Json(params): Json<Request>,
) -> Result<(StatusCode, Json<SingleResponse<PersonalSession>>), RouteError> {
let owner = personal_session_owner_from_caller(&session);
let actor_user = repo
.user()
.lookup(params.actor_user_id)
.await?
.ok_or(RouteError::UserNotFound)?;
if !actor_user.is_valid_actor() {
return Err(RouteError::UserDeactivated);
}
let scope: Scope = params.scope.parse().map_err(|_| RouteError::InvalidScope)?;
// Create the personal session
let session = repo
.personal_session()
.add(
&mut rng,
&clock,
owner,
&actor_user,
params.human_name,
scope,
)
.await?;
// Create the initial token for the session
let access_token_string = TokenType::PersonalAccessToken.generate(&mut rng);
let access_token = repo
.personal_access_token()
.add(
&mut rng,
&clock,
&session,
&access_token_string,
params
.expires_in
.map(|exp_in| Duration::seconds(i64::from(exp_in))),
)
.await?;
// If the session has a device, we should add those to the homeserver now
if session.has_device() {
// Lock the user sync to make sure we don't get into a race condition
repo.user().acquire_lock_for_sync(&actor_user).await?;
for scope in &*session.scope {
if let Some(device) = Device::from_scope_token(scope) {
// NOTE: We haven't relinquished the repo at this point,
// so we are holding a transaction across the homeserver
// operation.
// This is suboptimal, but simpler.
// Given this is an administrative endpoint, this is a tolerable
// compromise for now.
homeserver
.upsert_device(&actor_user.username, device.as_str(), None)
.await
.context("Failed to provision device")
.map_err(|e| RouteError::Internal(e.into()))?;
}
}
}
repo.save().await?;
Ok((
StatusCode::CREATED,
Json(SingleResponse::new_canonical(
PersonalSession::try_from((session, Some(access_token)))?
.with_token(access_token_string),
)),
))
}
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode};
use insta::assert_json_snapshot;
use serde_json::Value;
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_create_personal_session_with_token(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
// Create a user for testing
let mut repo = state.repository().await.unwrap();
let mut rng = state.rng();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
repo.save().await.unwrap();
let request_body = serde_json::json!({
"actor_user_id": user.id,
"human_name": "Test Session",
"scope": "openid urn:mas:admin",
"expires_in": 3600
});
let request = Request::post("/api/admin/v1/personal-sessions")
.bearer(&token)
.json(&request_body);
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let body: Value = response.json();
assert_json_snapshot!(body, @r#"
{
"data": {
"type": "personal-session",
"id": "01FSHN9AG07HNEZXNQM2KNBNF6",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"revoked_at": null,
"owner_user_id": null,
"owner_client_id": "01FSHN9AG0FAQ50MT1E9FFRPZR",
"actor_user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"human_name": "Test Session",
"scope": "openid urn:mas:admin",
"last_active_at": null,
"last_active_ip": null,
"expires_at": "2022-01-16T15:40:00Z",
"access_token": "mpt_FM44zJN5qePGMLvvMXC4Ds1A3lCWc6_bJ9Wj1"
},
"links": {
"self": "/api/admin/v1/personal-sessions/01FSHN9AG07HNEZXNQM2KNBNF6"
}
},
"links": {
"self": "/api/admin/v1/personal-sessions/01FSHN9AG07HNEZXNQM2KNBNF6"
}
}
"#);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_create_personal_session_invalid_user(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let request_body = serde_json::json!({
"actor_user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"scope": "openid",
"human_name": "Test Session",
"expires_in": 3600
});
let request = Request::post("/api/admin/v1/personal-sessions")
.bearer(&token)
.json(&request_body);
let response = state.request(request).await;
response.assert_status(StatusCode::NOT_FOUND);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_create_personal_session_invalid_scope(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
// Create a user for testing
let mut repo = state.repository().await.unwrap();
let mut rng = state.rng();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
repo.save().await.unwrap();
let request_body = serde_json::json!({
"actor_user_id": user.id,
"human_name": "Test Session",
"scope": "invalid\nscope",
"expires_in": 3600
});
let request = Request::post("/api/admin/v1/personal-sessions")
.bearer(&token)
.json(&request_body);
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
}
}

View File

@@ -0,0 +1,189 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use crate::{
admin::{
call_context::CallContext,
model::{InconsistentPersonalSession, PersonalSession},
params::UlidPathParam,
response::{ErrorResponse, SingleResponse},
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("Personal session not found")]
NotFound,
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(InconsistentPersonalSession);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound => StatusCode::NOT_FOUND,
};
(status, sentry_event_id, Json(error)).into_response()
}
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("getPersonalSession")
.summary("Get a personal session")
.tag("personal-session")
.response_with::<200, Json<SingleResponse<PersonalSession>>, _>(|t| {
let [sample, ..] = PersonalSession::samples();
let response = SingleResponse::new_canonical(sample);
t.description("Personal session details").example(response)
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::NotFound);
t.description("Personal session not found")
.example(response)
})
}
#[tracing::instrument(
name = "handler.admin.v1.personal_sessions.get",
skip_all,
fields(personal_session.id = %*id),
)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,
) -> Result<Json<SingleResponse<PersonalSession>>, RouteError> {
let session_id = *id;
let session = repo
.personal_session()
.lookup(session_id)
.await?
.ok_or(RouteError::NotFound)?;
let token = if session.is_revoked() {
None
} else {
repo.personal_access_token()
.find_active_for_session(&session)
.await?
};
Ok(Json(SingleResponse::new_canonical(
PersonalSession::try_from((session, token))?,
)))
}
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode};
use insta::assert_json_snapshot;
use mas_data_model::personal::session::PersonalSessionOwner;
use oauth2_types::scope::{OPENID, Scope};
use sqlx::PgPool;
use ulid::Ulid;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_get(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
// Create a user and personal session for testing
let mut repo = state.repository().await.unwrap();
let mut rng = state.rng();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let personal_session = repo
.personal_session()
.add(
&mut rng,
&state.clock,
PersonalSessionOwner::from(&user),
&user,
"Test session".to_owned(),
Scope::from_iter([OPENID]),
)
.await
.unwrap();
repo.personal_access_token()
.add(&mut rng, &state.clock, &personal_session, "mpt_hiss", None)
.await
.unwrap();
repo.save().await.unwrap();
let request = Request::get(format!(
"/api/admin/v1/personal-sessions/{}",
personal_session.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_eq!(body["data"]["id"], personal_session.id.to_string());
assert_json_snapshot!(body, @r#"
{
"data": {
"type": "personal-session",
"id": "01FSHN9AG0AJ6AC5HQ9X6H4RP4",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"revoked_at": null,
"owner_user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"owner_client_id": null,
"actor_user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"human_name": "Test session",
"scope": "openid",
"last_active_at": null,
"last_active_ip": null,
"expires_at": null
},
"links": {
"self": "/api/admin/v1/personal-sessions/01FSHN9AG0AJ6AC5HQ9X6H4RP4"
}
},
"links": {
"self": "/api/admin/v1/personal-sessions/01FSHN9AG0AJ6AC5HQ9X6H4RP4"
}
}
"#);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_not_found(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let session_id = Ulid::nil();
let request = Request::get(format!("/api/admin/v1/personal-sessions/{session_id}"))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::NOT_FOUND);
}
}

View File

@@ -0,0 +1,585 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use std::str::FromStr as _;
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use axum_extra::extract::{Query, QueryRejection};
use axum_macros::FromRequestParts;
use chrono::{DateTime, Utc};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::personal::PersonalSessionFilter;
use oauth2_types::scope::{Scope, ScopeToken};
use schemars::JsonSchema;
use serde::Deserialize;
use ulid::Ulid;
use crate::{
admin::{
call_context::CallContext,
model::{InconsistentPersonalSession, PersonalSession, Resource},
params::{IncludeCount, Pagination},
response::{ErrorResponse, PaginatedResponse},
},
impl_from_error_for_route,
};
#[derive(Deserialize, JsonSchema, Clone, Copy)]
#[serde(rename_all = "snake_case")]
enum PersonalSessionStatus {
Active,
Revoked,
}
impl std::fmt::Display for PersonalSessionStatus {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
Self::Active => write!(f, "active"),
Self::Revoked => write!(f, "revoked"),
}
}
}
#[derive(FromRequestParts, Deserialize, JsonSchema, OperationIo)]
#[serde(rename = "PersonalSessionFilter")]
#[aide(input_with = "Query<FilterParams>")]
#[from_request(via(Query), rejection(RouteError))]
pub struct FilterParams {
/// Filter by owner user ID
#[serde(rename = "filter[owner_user]")]
#[schemars(with = "Option<crate::admin::schema::Ulid>")]
owner_user: Option<Ulid>,
/// Filter by owner `OAuth2` client ID
#[serde(rename = "filter[owner_client]")]
#[schemars(with = "Option<crate::admin::schema::Ulid>")]
owner_client: Option<Ulid>,
/// Filter by actor user ID
#[serde(rename = "filter[actor_user]")]
#[schemars(with = "Option<crate::admin::schema::Ulid>")]
actor_user: Option<Ulid>,
/// Retrieve the items with the given scope
#[serde(default, rename = "filter[scope]")]
scope: Vec<String>,
/// Filter by session status
#[serde(rename = "filter[status]")]
status: Option<PersonalSessionStatus>,
/// Filter by access token expiry date
#[serde(rename = "filter[expires_before]")]
expires_before: Option<DateTime<Utc>>,
/// Filter by access token expiry date
#[serde(rename = "filter[expires_after]")]
expires_after: Option<DateTime<Utc>>,
/// Filter by whether the access token has an expiry time
#[serde(rename = "filter[expires]")]
expires: Option<bool>,
}
impl std::fmt::Display for FilterParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut sep = '?';
if let Some(owner_user) = self.owner_user {
write!(f, "{sep}filter[owner_user]={owner_user}")?;
sep = '&';
}
if let Some(owner_client) = self.owner_client {
write!(f, "{sep}filter[owner_client]={owner_client}")?;
sep = '&';
}
if let Some(actor_user) = self.actor_user {
write!(f, "{sep}filter[actor_user]={actor_user}")?;
sep = '&';
}
for scope in &self.scope {
write!(f, "{sep}filter[scope]={scope}")?;
sep = '&';
}
if let Some(status) = self.status {
write!(f, "{sep}filter[status]={status}")?;
sep = '&';
}
if let Some(expires_before) = self.expires_before {
write!(
f,
"{sep}filter[expires_before]={}",
expires_before.format("%Y-%m-%dT%H:%M:%SZ")
)?;
sep = '&';
}
if let Some(expires_after) = self.expires_after {
write!(
f,
"{sep}filter[expires_after]={}",
expires_after.format("%Y-%m-%dT%H:%M:%SZ")
)?;
sep = '&';
}
if let Some(expires) = self.expires {
write!(f, "{sep}filter[expires]={expires}")?;
sep = '&';
}
let _ = sep;
Ok(())
}
}
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("User ID {0} not found")]
UserNotFound(Ulid),
#[error("Client ID {0} not found")]
ClientNotFound(Ulid),
#[error("Invalid filter parameters")]
InvalidFilter(#[from] QueryRejection),
#[error("Invalid scope {0:?} in filter parameters")]
InvalidScope(String),
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(InconsistentPersonalSession);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::UserNotFound(_) | Self::ClientNotFound(_) => StatusCode::NOT_FOUND,
Self::InvalidScope(_) | Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
};
(status, sentry_event_id, Json(error)).into_response()
}
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("listPersonalSessions")
.summary("List personal sessions")
.description("Retrieve a list of personal sessions.
Note that by default, all sessions, including revoked ones are returned, with the oldest first.
Use the `filter[status]` parameter to filter the sessions by their status and `page[last]` parameter to retrieve the last N sessions.")
.tag("personal-session")
.response_with::<200, Json<PaginatedResponse<PersonalSession>>, _>(|t| {
let sessions = PersonalSession::samples();
let pagination = mas_storage::Pagination::first(sessions.len());
let page = mas_storage::Page {
edges: sessions
.into_iter()
.map(|node| mas_storage::pagination::Edge {
cursor: node.id(),
node,
})
.collect(),
has_next_page: true,
has_previous_page: false,
};
t.description("Paginated response of personal sessions")
.example(PaginatedResponse::for_page(
page,
pagination,
Some(3),
PersonalSession::PATH,
))
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::UserNotFound(Ulid::nil()));
t.description("User was not found").example(response)
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::ClientNotFound(Ulid::nil()));
t.description("Client was not found").example(response)
})
}
#[tracing::instrument(name = "handler.admin.v1.personal_sessions.list", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
Pagination(pagination, include_count): Pagination,
params: FilterParams,
) -> Result<Json<PaginatedResponse<PersonalSession>>, RouteError> {
let base = format!("{path}{params}", path = PersonalSession::PATH);
let base = include_count.add_to_base(&base);
let filter = PersonalSessionFilter::new();
let owner_user = if let Some(owner_user_id) = params.owner_user {
let owner_user = repo
.user()
.lookup(owner_user_id)
.await?
.ok_or(RouteError::UserNotFound(owner_user_id))?;
Some(owner_user)
} else {
None
};
let filter = match &owner_user {
Some(user) => filter.for_owner_user(user),
None => filter,
};
let owner_client = if let Some(owner_client_id) = params.owner_client {
let owner_client = repo
.oauth2_client()
.lookup(owner_client_id)
.await?
.ok_or(RouteError::ClientNotFound(owner_client_id))?;
Some(owner_client)
} else {
None
};
let filter = match &owner_client {
Some(client) => filter.for_owner_oauth2_client(client),
None => filter,
};
let actor_user = if let Some(actor_user_id) = params.actor_user {
let user = repo
.user()
.lookup(actor_user_id)
.await?
.ok_or(RouteError::UserNotFound(actor_user_id))?;
Some(user)
} else {
None
};
let filter = match &actor_user {
Some(user) => filter.for_actor_user(user),
None => filter,
};
let scope: Scope = params
.scope
.into_iter()
.map(|s| ScopeToken::from_str(&s).map_err(|_| RouteError::InvalidScope(s)))
.collect::<Result<_, _>>()?;
let filter = if scope.is_empty() {
filter
} else {
filter.with_scope(&scope)
};
let filter = match params.status {
Some(PersonalSessionStatus::Active) => filter.active_only(),
Some(PersonalSessionStatus::Revoked) => filter.finished_only(),
None => filter,
};
let filter = if let Some(expires_after) = params.expires_after {
filter.with_expires_after(expires_after)
} else {
filter
};
let filter = if let Some(expires_before) = params.expires_before {
filter.with_expires_before(expires_before)
} else {
filter
};
let filter = if let Some(expires) = params.expires {
filter.with_expires(expires)
} else {
filter
};
let response = match include_count {
IncludeCount::True => {
let page = repo.personal_session().list(filter, pagination).await?;
let count = repo.personal_session().count(filter).await?;
PaginatedResponse::for_page(
page.try_map(PersonalSession::try_from)?,
pagination,
Some(count),
&base,
)
}
IncludeCount::False => {
let page = repo.personal_session().list(filter, pagination).await?;
PaginatedResponse::for_page(
page.try_map(PersonalSession::try_from)?,
pagination,
None,
&base,
)
}
IncludeCount::Only => {
let count = repo.personal_session().count(filter).await?;
PaginatedResponse::for_count_only(count, &base)
}
};
Ok(Json(response))
}
#[cfg(test)]
mod tests {
use std::collections::BTreeSet;
use chrono::Duration;
use hyper::{Request, StatusCode};
use insta::assert_json_snapshot;
use mas_data_model::personal::session::PersonalSessionOwner;
use oauth2_types::scope::{OPENID, Scope};
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_list(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
// Create a user and personal session for testing
let mut repo = state.repository().await.unwrap();
let mut rng = state.rng();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let personal_session = repo
.personal_session()
.add(
&mut rng,
&state.clock,
PersonalSessionOwner::from(&user),
&user,
"Test session".to_owned(),
Scope::from_iter([OPENID]),
)
.await
.unwrap();
repo.personal_access_token()
.add(
&mut rng,
&state.clock,
&personal_session,
"mpt_hiss",
Some(Duration::days(42)),
)
.await
.unwrap();
state.clock.advance(Duration::days(1));
let personal_session = repo
.personal_session()
.add(
&mut rng,
&state.clock,
PersonalSessionOwner::from(&user),
&user,
"Another test session".to_owned(),
Scope::from_iter([OPENID]),
)
.await
.unwrap();
repo.personal_access_token()
.add(
&mut rng,
&state.clock,
&personal_session,
"mpt_scratch",
Some(Duration::days(21)),
)
.await
.unwrap();
repo.personal_session()
.revoke(&state.clock, personal_session)
.await
.unwrap();
state.clock.advance(Duration::days(1));
let personal_session = repo
.personal_session()
.add(
&mut rng,
&state.clock,
PersonalSessionOwner::from(&user),
&user,
"Another test session".to_owned(),
Scope::from_iter([OPENID, "urn:mas:admin".parse().unwrap()]),
)
.await
.unwrap();
repo.personal_access_token()
.add(
&mut rng,
&state.clock,
&personal_session,
"mpt_meow",
Some(Duration::days(14)),
)
.await
.unwrap();
repo.save().await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let request = Request::get("/api/admin/v1/personal-sessions")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 3
},
"data": [
{
"type": "personal-session",
"id": "01FSHN9AG0YQYAR04VCYTHJ8SK",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"revoked_at": null,
"owner_user_id": "01FSHN9AG09FE39KETP6F390F8",
"owner_client_id": null,
"actor_user_id": "01FSHN9AG09FE39KETP6F390F8",
"human_name": "Test session",
"scope": "openid",
"last_active_at": null,
"last_active_ip": null,
"expires_at": "2022-02-27T14:40:00Z"
},
"links": {
"self": "/api/admin/v1/personal-sessions/01FSHN9AG0YQYAR04VCYTHJ8SK"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0YQYAR04VCYTHJ8SK"
}
}
},
{
"type": "personal-session",
"id": "01FSM7P1G0VBGAMK9D9QMGQ5MY",
"attributes": {
"created_at": "2022-01-17T14:40:00Z",
"revoked_at": "2022-01-17T14:40:00Z",
"owner_user_id": "01FSHN9AG09FE39KETP6F390F8",
"owner_client_id": null,
"actor_user_id": "01FSHN9AG09FE39KETP6F390F8",
"human_name": "Another test session",
"scope": "openid",
"last_active_at": null,
"last_active_ip": null,
"expires_at": null
},
"links": {
"self": "/api/admin/v1/personal-sessions/01FSM7P1G0VBGAMK9D9QMGQ5MY"
},
"meta": {
"page": {
"cursor": "01FSM7P1G0VBGAMK9D9QMGQ5MY"
}
}
},
{
"type": "personal-session",
"id": "01FSPT2RG08Y11Y5BM4VZ4CN8K",
"attributes": {
"created_at": "2022-01-18T14:40:00Z",
"revoked_at": null,
"owner_user_id": "01FSHN9AG09FE39KETP6F390F8",
"owner_client_id": null,
"actor_user_id": "01FSHN9AG09FE39KETP6F390F8",
"human_name": "Another test session",
"scope": "openid urn:mas:admin",
"last_active_at": null,
"last_active_ip": null,
"expires_at": "2022-02-01T14:40:00Z"
},
"links": {
"self": "/api/admin/v1/personal-sessions/01FSPT2RG08Y11Y5BM4VZ4CN8K"
},
"meta": {
"page": {
"cursor": "01FSPT2RG08Y11Y5BM4VZ4CN8K"
}
}
}
],
"links": {
"self": "/api/admin/v1/personal-sessions?page[first]=10",
"first": "/api/admin/v1/personal-sessions?page[first]=10",
"last": "/api/admin/v1/personal-sessions?page[last]=10"
}
}
"#);
// Map of filters to their expected set of returned ULIDs
let filters_and_expected: &[(&str, &[&str])] = &[
(
"filter[expires_before]=2022-02-15T00:00:00Z",
&["01FSPT2RG08Y11Y5BM4VZ4CN8K"],
),
(
"filter[expires_after]=2022-02-15T00:00:00Z",
&["01FSHN9AG0YQYAR04VCYTHJ8SK"],
),
(
"filter[status]=active",
&["01FSHN9AG0YQYAR04VCYTHJ8SK", "01FSPT2RG08Y11Y5BM4VZ4CN8K"],
),
("filter[status]=revoked", &["01FSM7P1G0VBGAMK9D9QMGQ5MY"]),
(
"filter[expires]=true",
&["01FSHN9AG0YQYAR04VCYTHJ8SK", "01FSPT2RG08Y11Y5BM4VZ4CN8K"],
),
("filter[expires]=false", &["01FSM7P1G0VBGAMK9D9QMGQ5MY"]),
(
"filter[scope]=urn:mas:admin",
&["01FSPT2RG08Y11Y5BM4VZ4CN8K"],
),
];
for (filter, expected_ids) in filters_and_expected {
let request = Request::get(format!("/api/admin/v1/personal-sessions?{filter}"))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
let found: BTreeSet<&str> = body["data"]
.as_array()
.unwrap()
.iter()
.map(|item| item["id"].as_str().unwrap())
.collect();
let expected: BTreeSet<&str> = expected_ids.iter().copied().collect();
assert_eq!(
found, expected,
"filter {filter} did not produce expected results"
);
}
}
}

View File

@@ -0,0 +1,39 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
mod add;
mod get;
mod list;
mod regenerate;
mod revoke;
use mas_data_model::personal::session::PersonalSessionOwner;
pub use self::{
add::{doc as add_doc, handler as add},
get::{doc as get_doc, handler as get},
list::{doc as list_doc, handler as list},
regenerate::{doc as regenerate_doc, handler as regenerate},
revoke::{doc as revoke_doc, handler as revoke},
};
use crate::admin::call_context::CallerSession;
/// Given the [`CallerSession`] of a caller of the Admin API,
/// return the [`PersonalSessionOwner`] that should own created personal
/// sessions.
fn personal_session_owner_from_caller(caller: &CallerSession) -> PersonalSessionOwner {
match caller {
CallerSession::OAuth2Session(session) => {
if let Some(user_id) = session.user_id {
PersonalSessionOwner::User(user_id)
} else {
PersonalSessionOwner::OAuth2Client(session.client_id)
}
}
CallerSession::PersonalSession(session) => {
PersonalSessionOwner::User(session.actor_user_id)
}
}
}

View File

@@ -0,0 +1,246 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use chrono::Duration;
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_data_model::{BoxRng, TokenType};
use schemars::JsonSchema;
use serde::Deserialize;
use tracing::error;
use crate::{
admin::{
call_context::CallContext,
model::{InconsistentPersonalSession, PersonalSession},
params::UlidPathParam,
response::{ErrorResponse, SingleResponse},
v1::personal_sessions::personal_session_owner_from_caller,
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("User not found")]
UserNotFound,
#[error("Session not found")]
SessionNotFound,
#[error("Session not valid")]
SessionNotValid,
#[error("Session does not belong to you")]
SessionNotYours,
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(InconsistentPersonalSession);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::UserNotFound | Self::SessionNotFound => StatusCode::NOT_FOUND,
Self::SessionNotValid => StatusCode::UNPROCESSABLE_ENTITY,
Self::SessionNotYours => StatusCode::FORBIDDEN,
};
(status, sentry_event_id, Json(error)).into_response()
}
}
/// # JSON payload for the `POST /api/admin/v1/personal-sessions/{id}/regenerate` endpoint
#[derive(Deserialize, JsonSchema)]
#[serde(rename = "RegeneratePersonalSessionRequest")]
pub struct Request {
/// Token expiry time in seconds.
/// If not set, the token won't expire.
expires_in: Option<u32>,
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("regeneratePersonalSession")
.summary("Regenerate a personal session by replacing its personal access token")
.tag("personal-session")
.response_with::<201, Json<SingleResponse<PersonalSession>>, _>(|t| {
t.description(
"Personal session was regenerated and a personal access token was created",
)
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::UserNotFound);
t.description("User was not found").example(response)
})
}
#[tracing::instrument(name = "handler.admin.v1.personal_sessions.add", skip_all)]
pub async fn handler(
CallContext {
mut repo,
clock,
session: caller_session,
..
}: CallContext,
NoApi(mut rng): NoApi<BoxRng>,
id: UlidPathParam,
Json(params): Json<Request>,
) -> Result<(StatusCode, Json<SingleResponse<PersonalSession>>), RouteError> {
let session_id = *id;
let session = repo
.personal_session()
.lookup(session_id)
.await?
.ok_or(RouteError::SessionNotFound)?;
if !session.is_valid() {
// We don't revive revoked sessions through regeneration
return Err(RouteError::SessionNotValid);
}
// If the owner is not the current caller, then currently we reject the
// regeneration.
let caller = personal_session_owner_from_caller(&caller_session);
if session.owner != caller {
return Err(RouteError::SessionNotYours);
}
// Revoke the existing active token for the session.
let old_token_opt = repo
.personal_access_token()
.find_active_for_session(&session)
.await?;
let Some(old_token) = old_token_opt else {
// This shouldn't happen
error!("session is supposedly valid but had no access token");
return Err(RouteError::SessionNotValid);
};
repo.personal_access_token()
.revoke(&clock, old_token)
.await?;
// Create the regenerated token for the session
let access_token_string = TokenType::PersonalAccessToken.generate(&mut rng);
let access_token = repo
.personal_access_token()
.add(
&mut rng,
&clock,
&session,
&access_token_string,
params
.expires_in
.map(|exp_in| Duration::seconds(i64::from(exp_in))),
)
.await?;
repo.save().await?;
Ok((
StatusCode::CREATED,
Json(SingleResponse::new_canonical(
PersonalSession::try_from((session, Some(access_token)))?
.with_token(access_token_string),
)),
))
}
#[cfg(test)]
mod tests {
use chrono::Duration;
use hyper::{Request, StatusCode};
use insta::assert_json_snapshot;
use serde_json::{Value, json};
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_regenerate_personal_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
// Create a user for testing
let mut repo = state.repository().await.unwrap();
let mut rng = state.rng();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
repo.save().await.unwrap();
let request = Request::post("/api/admin/v1/personal-sessions")
.bearer(&token)
.json(json!({
"actor_user_id": user.id,
"human_name": "SuperDuperAdminCLITool Token",
"scope": "openid urn:mas:admin",
"expires_in": 3600
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let created: Value = response.json();
let session_id = created["data"]["id"].as_str().unwrap();
state.clock.advance(Duration::minutes(3));
let request = Request::post(format!(
"/api/admin/v1/personal-sessions/{session_id}/regenerate"
))
.bearer(&token)
.json(json!({
"expires_in": 86400
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let body: Value = response.json();
assert_json_snapshot!(body, @r#"
{
"data": {
"type": "personal-session",
"id": "01FSHN9AG07HNEZXNQM2KNBNF6",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"revoked_at": null,
"owner_user_id": null,
"owner_client_id": "01FSHN9AG0FAQ50MT1E9FFRPZR",
"actor_user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"human_name": "SuperDuperAdminCLITool Token",
"scope": "openid urn:mas:admin",
"last_active_at": null,
"last_active_ip": null,
"expires_at": "2022-01-17T14:43:00Z",
"access_token": "mpt_6cq7FqNSYoosbXl3bbpfh9yNy9NzuR_0vOV2O"
},
"links": {
"self": "/api/admin/v1/personal-sessions/01FSHN9AG07HNEZXNQM2KNBNF6"
}
},
"links": {
"self": "/api/admin/v1/personal-sessions/01FSHN9AG07HNEZXNQM2KNBNF6"
}
}
"#);
}
}

View File

@@ -0,0 +1,250 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use aide::{NoApi, OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_data_model::BoxRng;
use mas_storage::queue::{QueueJobRepositoryExt as _, SyncDevicesJob};
use ulid::Ulid;
use crate::{
admin::{
call_context::CallContext,
model::{InconsistentPersonalSession, PersonalSession},
params::UlidPathParam,
response::{ErrorResponse, SingleResponse},
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("Personal session with ID {0} not found")]
NotFound(Ulid),
#[error("Personal session with ID {0} is already revoked")]
AlreadyRevoked(Ulid),
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(InconsistentPersonalSession);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
Self::AlreadyRevoked(_) => StatusCode::CONFLICT,
};
(status, sentry_event_id, Json(error)).into_response()
}
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("revokePersonalSession")
.summary("Revoke a personal session")
.tag("personal-session")
.response_with::<200, Json<SingleResponse<PersonalSession>>, _>(|t| {
let [sample, ..] = PersonalSession::samples();
let response = SingleResponse::new_canonical(sample);
t.description("Personal session was revoked")
.example(response)
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::NotFound(Ulid::nil()));
t.description("Personal session not found")
.example(response)
})
.response_with::<409, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::AlreadyRevoked(Ulid::nil()));
t.description("Personal session already revoked")
.example(response)
})
}
#[tracing::instrument(
name = "handler.admin.v1.personal_sessions.revoke",
skip_all,
fields(personal_session.id = %*session_id),
)]
pub async fn handler(
CallContext {
mut repo, clock, ..
}: CallContext,
NoApi(mut rng): NoApi<BoxRng>,
session_id: UlidPathParam,
) -> Result<Json<SingleResponse<PersonalSession>>, RouteError> {
let session_id = *session_id;
let session = repo
.personal_session()
.lookup(session_id)
.await?
.ok_or(RouteError::NotFound(session_id))?;
if session.is_revoked() {
return Err(RouteError::AlreadyRevoked(session_id));
}
let session = repo.personal_session().revoke(&clock, session).await?;
if session.has_device() {
// If the session has a device, then we are now
// deleting a device and should schedule a device sync to clean up.
repo.queue_job()
.schedule_job(
&mut rng,
&clock,
SyncDevicesJob::new_for_id(session.actor_user_id),
)
.await?;
}
repo.save().await?;
Ok(Json(SingleResponse::new_canonical(
PersonalSession::try_from((session, None))?,
)))
}
#[cfg(test)]
mod tests {
use chrono::Duration;
use hyper::{Request, StatusCode};
use mas_data_model::{Clock, personal::session::PersonalSessionOwner};
use oauth2_types::scope::Scope;
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_revoke_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
// Create a user and personal session for testing
let mut repo = state.repository().await.unwrap();
let mut rng = state.rng();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let personal_session = repo
.personal_session()
.add(
&mut rng,
&state.clock,
PersonalSessionOwner::from(&user),
&user,
"Test session".to_owned(),
Scope::from_iter([]),
)
.await
.unwrap();
repo.save().await.unwrap();
let request = Request::post(format!(
"/api/admin/v1/personal-sessions/{}/revoke",
personal_session.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
// The revoked_at timestamp should be the same as the current time
assert_eq!(
body["data"]["attributes"]["revoked_at"],
serde_json::json!(Clock::now(&state.clock))
);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_revoke_already_revoked_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
// Create a user and personal session for testing
let mut repo = state.repository().await.unwrap();
let mut rng = state.rng();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let personal_session = repo
.personal_session()
.add(
&mut rng,
&state.clock,
PersonalSessionOwner::from(&user),
&user,
"Test session".to_owned(),
Scope::from_iter([]),
)
.await
.unwrap();
// Revoke the session first
let session = repo
.personal_session()
.revoke(&state.clock, personal_session)
.await
.unwrap();
repo.save().await.unwrap();
// Move the clock forward
state.clock.advance(Duration::try_minutes(1).unwrap());
let request = Request::post(format!(
"/api/admin/v1/personal-sessions/{}/revoke",
session.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::CONFLICT);
let body: serde_json::Value = response.json();
assert_eq!(
body["errors"][0]["title"],
format!("Personal session with ID {} is already revoked", session.id)
);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_revoke_unknown_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let request =
Request::post("/api/admin/v1/personal-sessions/01040G2081040G2081040G2081/revoke")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::NOT_FOUND);
let body: serde_json::Value = response.json();
assert_eq!(
body["errors"][0]["title"],
"Personal session with ID 01040G2081040G2081040G2081 not found"
);
}
}

View File

@@ -59,7 +59,7 @@ fn data_example() -> serde_json::Value {
#[derive(Deserialize, JsonSchema)] #[derive(Deserialize, JsonSchema)]
#[serde(rename = "SetPolicyDataRequest")] #[serde(rename = "SetPolicyDataRequest")]
pub struct SetPolicyDataRequest { pub struct SetPolicyDataRequest {
#[schemars(example = "data_example")] #[schemars(example = data_example())]
pub data: serde_json::Value, pub data: serde_json::Value,
} }

View File

@@ -0,0 +1,97 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use aide::transform::TransformOperation;
use axum::{Json, extract::State};
use schemars::JsonSchema;
use serde::Serialize;
use crate::admin::call_context::CallContext;
#[allow(clippy::struct_excessive_bools)]
#[derive(Serialize, JsonSchema)]
pub struct SiteConfig {
/// The Matrix server name for which this instance is configured
server_name: String,
/// Whether password login is enabled.
pub password_login_enabled: bool,
/// Whether password registration is enabled.
pub password_registration_enabled: bool,
/// Whether a valid email address is required for password registrations.
pub password_registration_email_required: bool,
/// Whether registration tokens are required for password registrations.
pub registration_token_required: bool,
/// Whether users can change their email.
pub email_change_allowed: bool,
/// Whether users can change their display name.
pub displayname_change_allowed: bool,
/// Whether users can change their password.
pub password_change_allowed: bool,
/// Whether users can recover their account via email.
pub account_recovery_allowed: bool,
/// Whether users can delete their own account.
pub account_deactivation_allowed: bool,
/// Whether CAPTCHA during registration is enabled.
pub captcha_enabled: bool,
/// Minimum password complexity, between 0 and 4.
/// This is a score from zxcvbn.
#[schemars(range(min = 0, max = 4))]
pub minimum_password_complexity: u8,
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("siteConfig")
.tag("server")
.summary("Get informations about the configuration of this MAS instance")
.response_with::<200, Json<SiteConfig>, _>(|t| {
t.example(SiteConfig {
server_name: "example.com".to_owned(),
password_login_enabled: true,
password_registration_enabled: true,
password_registration_email_required: true,
registration_token_required: true,
email_change_allowed: true,
displayname_change_allowed: true,
password_change_allowed: true,
account_recovery_allowed: true,
account_deactivation_allowed: true,
captcha_enabled: true,
minimum_password_complexity: 3,
})
})
}
#[tracing::instrument(name = "handler.admin.v1.site_config", skip_all)]
pub async fn handler(
_: CallContext,
State(site_config): State<mas_data_model::SiteConfig>,
) -> Json<SiteConfig> {
Json(SiteConfig {
server_name: site_config.server_name,
password_login_enabled: site_config.password_login_enabled,
password_registration_enabled: site_config.password_registration_enabled,
password_registration_email_required: site_config.password_registration_email_required,
registration_token_required: site_config.registration_token_required,
email_change_allowed: site_config.email_change_allowed,
displayname_change_allowed: site_config.displayname_change_allowed,
password_change_allowed: site_config.password_change_allowed,
account_recovery_allowed: site_config.account_recovery_allowed,
account_deactivation_allowed: site_config.account_deactivation_allowed,
captcha_enabled: site_config.captcha.is_some(),
minimum_password_complexity: site_config.minimum_password_complexity,
})
}

View File

@@ -4,11 +4,8 @@
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use aide::{OperationIo, transform::TransformOperation}; use aide::{OperationIo, transform::TransformOperation};
use axum::{ use axum::{Json, response::IntoResponse};
Json, use axum_extra::extract::{Query, QueryRejection};
extract::{Query, rejection::QueryRejection},
response::IntoResponse,
};
use axum_macros::FromRequestParts; use axum_macros::FromRequestParts;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::record_error; use mas_axum_utils::record_error;
@@ -21,7 +18,7 @@ use crate::{
admin::{ admin::{
call_context::CallContext, call_context::CallContext,
model::{Resource, UpstreamOAuthLink}, model::{Resource, UpstreamOAuthLink},
params::Pagination, params::{IncludeCount, Pagination},
response::{ErrorResponse, PaginatedResponse}, response::{ErrorResponse, PaginatedResponse},
}, },
impl_from_error_for_route, impl_from_error_for_route,
@@ -112,16 +109,22 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
let links = UpstreamOAuthLink::samples(); let links = UpstreamOAuthLink::samples();
let pagination = mas_storage::Pagination::first(links.len()); let pagination = mas_storage::Pagination::first(links.len());
let page = Page { let page = Page {
edges: links.into(), edges: links
.into_iter()
.map(|node| mas_storage::pagination::Edge {
cursor: node.id(),
node,
})
.collect(),
has_next_page: true, has_next_page: true,
has_previous_page: false, has_previous_page: false,
}; };
t.description("Paginated response of upstream OAuth 2.0 links") t.description("Paginated response of upstream OAuth 2.0 links")
.example(PaginatedResponse::new( .example(PaginatedResponse::for_page(
page, page,
pagination, pagination,
42, Some(42),
UpstreamOAuthLink::PATH, UpstreamOAuthLink::PATH,
)) ))
}) })
@@ -135,10 +138,11 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.list", skip_all)] #[tracing::instrument(name = "handler.admin.v1.upstream_oauth_links.list", skip_all)]
pub async fn handler( pub async fn handler(
CallContext { mut repo, .. }: CallContext, CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination, Pagination(pagination, include_count): Pagination,
params: FilterParams, params: FilterParams,
) -> Result<Json<PaginatedResponse<UpstreamOAuthLink>>, RouteError> { ) -> Result<Json<PaginatedResponse<UpstreamOAuthLink>>, RouteError> {
let base = format!("{path}{params}", path = UpstreamOAuthLink::PATH); let base = format!("{path}{params}", path = UpstreamOAuthLink::PATH);
let base = include_count.add_to_base(&base);
let filter = UpstreamOAuthLinkFilter::default(); let filter = UpstreamOAuthLinkFilter::default();
// Load the user from the filter // Load the user from the filter
@@ -183,15 +187,31 @@ pub async fn handler(
filter filter
}; };
let page = repo.upstream_oauth_link().list(filter, pagination).await?; let response = match include_count {
let count = repo.upstream_oauth_link().count(filter).await?; IncludeCount::True => {
let page = repo
.upstream_oauth_link()
.list(filter, pagination)
.await?
.map(UpstreamOAuthLink::from);
let count = repo.upstream_oauth_link().count(filter).await?;
PaginatedResponse::for_page(page, pagination, Some(count), &base)
}
IncludeCount::False => {
let page = repo
.upstream_oauth_link()
.list(filter, pagination)
.await?
.map(UpstreamOAuthLink::from);
PaginatedResponse::for_page(page, pagination, None, &base)
}
IncludeCount::Only => {
let count = repo.upstream_oauth_link().count(filter).await?;
PaginatedResponse::for_count_only(count, &base)
}
};
Ok(Json(PaginatedResponse::new( Ok(Json(response))
page.map(UpstreamOAuthLink::from),
pagination,
count,
&base,
)))
} }
#[cfg(test)] #[cfg(test)]
@@ -296,7 +316,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###" assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 3 "count": 3
@@ -314,6 +334,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0AQZQP8DX40GD59PW" "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0AQZQP8DX40GD59PW"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0AQZQP8DX40GD59PW"
}
} }
}, },
{ {
@@ -328,6 +353,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0PJZ6DZNTAA1XKPT4" "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0PJZ6DZNTAA1XKPT4"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0PJZ6DZNTAA1XKPT4"
}
} }
}, },
{ {
@@ -342,6 +372,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0QHEHKX2JNQ2A2D07" "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0QHEHKX2JNQ2A2D07"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0QHEHKX2JNQ2A2D07"
}
} }
} }
], ],
@@ -351,7 +386,7 @@ mod tests {
"last": "/api/admin/v1/upstream-oauth-links?page[last]=10" "last": "/api/admin/v1/upstream-oauth-links?page[last]=10"
} }
} }
"###); "#);
// Filter by user ID // Filter by user ID
let request = Request::get(format!( let request = Request::get(format!(
@@ -364,7 +399,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###" assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 2 "count": 2
@@ -382,6 +417,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0AQZQP8DX40GD59PW" "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0AQZQP8DX40GD59PW"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0AQZQP8DX40GD59PW"
}
} }
}, },
{ {
@@ -396,6 +436,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0QHEHKX2JNQ2A2D07" "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0QHEHKX2JNQ2A2D07"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0QHEHKX2JNQ2A2D07"
}
} }
} }
], ],
@@ -405,7 +450,7 @@ mod tests {
"last": "/api/admin/v1/upstream-oauth-links?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&page[last]=10" "last": "/api/admin/v1/upstream-oauth-links?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&page[last]=10"
} }
} }
"###); "#);
// Filter by provider // Filter by provider
let request = Request::get(format!( let request = Request::get(format!(
@@ -418,7 +463,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###" assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 2 "count": 2
@@ -436,6 +481,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0AQZQP8DX40GD59PW" "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0AQZQP8DX40GD59PW"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0AQZQP8DX40GD59PW"
}
} }
}, },
{ {
@@ -450,6 +500,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0PJZ6DZNTAA1XKPT4" "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0PJZ6DZNTAA1XKPT4"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0PJZ6DZNTAA1XKPT4"
}
} }
} }
], ],
@@ -459,7 +514,7 @@ mod tests {
"last": "/api/admin/v1/upstream-oauth-links?filter[provider]=01FSHN9AG09NMZYX8MFYH578R9&page[last]=10" "last": "/api/admin/v1/upstream-oauth-links?filter[provider]=01FSHN9AG09NMZYX8MFYH578R9&page[last]=10"
} }
} }
"###); "#);
// Filter by subject // Filter by subject
let request = Request::get(format!( let request = Request::get(format!(
@@ -472,7 +527,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###" assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 1 "count": 1
@@ -490,6 +545,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0AQZQP8DX40GD59PW" "self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0AQZQP8DX40GD59PW"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0AQZQP8DX40GD59PW"
}
} }
} }
], ],
@@ -499,6 +559,181 @@ mod tests {
"last": "/api/admin/v1/upstream-oauth-links?filter[subject]=subject1&page[last]=10" "last": "/api/admin/v1/upstream-oauth-links?filter[subject]=subject1&page[last]=10"
} }
} }
"#);
// Test count=false
let request = Request::get("/api/admin/v1/upstream-oauth-links?count=false")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "upstream-oauth-link",
"id": "01FSHN9AG0AQZQP8DX40GD59PW",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"provider_id": "01FSHN9AG09NMZYX8MFYH578R9",
"subject": "subject1",
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"human_account_name": "alice@acme"
},
"links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0AQZQP8DX40GD59PW"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0AQZQP8DX40GD59PW"
}
}
},
{
"type": "upstream-oauth-link",
"id": "01FSHN9AG0PJZ6DZNTAA1XKPT4",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"provider_id": "01FSHN9AG09NMZYX8MFYH578R9",
"subject": "subject3",
"user_id": "01FSHN9AG0AJ6AC5HQ9X6H4RP4",
"human_account_name": "bob@acme"
},
"links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0PJZ6DZNTAA1XKPT4"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0PJZ6DZNTAA1XKPT4"
}
}
},
{
"type": "upstream-oauth-link",
"id": "01FSHN9AG0QHEHKX2JNQ2A2D07",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"provider_id": "01FSHN9AG0KEPHYQQXW9XPTX6Z",
"subject": "subject2",
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"human_account_name": "alice@example"
},
"links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0QHEHKX2JNQ2A2D07"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0QHEHKX2JNQ2A2D07"
}
}
}
],
"links": {
"self": "/api/admin/v1/upstream-oauth-links?count=false&page[first]=10",
"first": "/api/admin/v1/upstream-oauth-links?count=false&page[first]=10",
"last": "/api/admin/v1/upstream-oauth-links?count=false&page[last]=10"
}
}
"#);
// Test count=only
let request = Request::get("/api/admin/v1/upstream-oauth-links?count=only")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###"
{
"meta": {
"count": 3
},
"links": {
"self": "/api/admin/v1/upstream-oauth-links?count=only"
}
}
"###); "###);
// Test count=false with filtering
let request = Request::get(format!(
"/api/admin/v1/upstream-oauth-links?count=false&filter[user]={}",
alice.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "upstream-oauth-link",
"id": "01FSHN9AG0AQZQP8DX40GD59PW",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"provider_id": "01FSHN9AG09NMZYX8MFYH578R9",
"subject": "subject1",
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"human_account_name": "alice@acme"
},
"links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0AQZQP8DX40GD59PW"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0AQZQP8DX40GD59PW"
}
}
},
{
"type": "upstream-oauth-link",
"id": "01FSHN9AG0QHEHKX2JNQ2A2D07",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"provider_id": "01FSHN9AG0KEPHYQQXW9XPTX6Z",
"subject": "subject2",
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"human_account_name": "alice@example"
},
"links": {
"self": "/api/admin/v1/upstream-oauth-links/01FSHN9AG0QHEHKX2JNQ2A2D07"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0QHEHKX2JNQ2A2D07"
}
}
}
],
"links": {
"self": "/api/admin/v1/upstream-oauth-links?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[first]=10",
"first": "/api/admin/v1/upstream-oauth-links?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[first]=10",
"last": "/api/admin/v1/upstream-oauth-links?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[last]=10"
}
}
"#);
// Test count=only with filtering
let request = Request::get(format!(
"/api/admin/v1/upstream-oauth-links?count=only&filter[provider]={}",
provider1.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 2
},
"links": {
"self": "/api/admin/v1/upstream-oauth-links?filter[provider]=01FSHN9AG09NMZYX8MFYH578R9&count=only"
}
}
"#);
} }
} }

View File

@@ -0,0 +1,196 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{RepositoryAccess, upstream_oauth2::UpstreamOAuthProviderRepository};
use crate::{
admin::{
call_context::CallContext,
model::UpstreamOAuthProvider,
params::UlidPathParam,
response::{ErrorResponse, SingleResponse},
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("Provider not found")]
NotFound,
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound => StatusCode::NOT_FOUND,
};
(status, sentry_event_id, Json(error)).into_response()
}
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("getUpstreamOAuthProvider")
.summary("Get upstream OAuth provider")
.tag("upstream-oauth-provider")
.response_with::<200, Json<SingleResponse<UpstreamOAuthProvider>>, _>(|t| {
let [sample, ..] = UpstreamOAuthProvider::samples();
t.description("The upstream OAuth provider")
.example(SingleResponse::new_canonical(sample))
})
.response_with::<404, Json<ErrorResponse>, _>(|t| t.description("Provider not found"))
}
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_providers.get", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
id: UlidPathParam,
) -> Result<Json<SingleResponse<UpstreamOAuthProvider>>, RouteError> {
let provider = repo
.upstream_oauth_provider()
.lookup(*id)
.await?
.ok_or(RouteError::NotFound)?;
Ok(Json(SingleResponse::new_canonical(
UpstreamOAuthProvider::from(provider),
)))
}
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode};
use mas_data_model::{
UpstreamOAuthProvider, UpstreamOAuthProviderClaimsImports,
UpstreamOAuthProviderDiscoveryMode, UpstreamOAuthProviderOnBackchannelLogout,
UpstreamOAuthProviderPkceMode, UpstreamOAuthProviderTokenAuthMethod,
};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_storage::{
RepositoryAccess,
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
};
use oauth2_types::scope::{OPENID, Scope};
use sqlx::PgPool;
use ulid::Ulid;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
async fn create_test_provider(state: &mut TestState) -> UpstreamOAuthProvider {
let mut repo = state.repository().await.unwrap();
let params = UpstreamOAuthProviderParams {
issuer: Some("https://accounts.google.com".to_owned()),
human_name: Some("Google".to_owned()),
brand_name: Some("google".to_owned()),
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
jwks_uri_override: None,
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: true,
userinfo_signed_response_alg: None,
client_id: "google-client-id".to_owned(),
encrypted_client_secret: Some("encrypted-secret".to_owned()),
token_endpoint_signing_alg: None,
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::ClientSecretPost,
id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
response_mode: None,
scope: Scope::from_iter([OPENID]),
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
additional_authorization_parameters: vec![],
forward_login_hint: false,
on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
ui_order: 0,
};
let provider = repo
.upstream_oauth_provider()
.add(&mut state.rng(), &state.clock, params)
.await
.unwrap();
Box::new(repo).save().await.unwrap();
provider
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_get_provider(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let admin_token = state.token_with_scope("urn:mas:admin").await;
let provider = create_test_provider(&mut state).await;
let request = Request::get(format!(
"/api/admin/v1/upstream-oauth-providers/{}",
provider.id
))
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json::<serde_json::Value>();
assert_eq!(body["data"]["type"], "upstream-oauth-provider");
assert_eq!(body["data"]["id"], provider.id.to_string());
assert_eq!(body["data"]["attributes"]["human_name"], "Google");
insta::assert_json_snapshot!(body, @r###"
{
"data": {
"type": "upstream-oauth-provider",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"issuer": "https://accounts.google.com",
"human_name": "Google",
"brand_name": "google",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG0MZAA6S4AF7CTV32E"
}
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
"###);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_not_found(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let admin_token = state.token_with_scope("urn:mas:admin").await;
let provider_id = Ulid::nil();
let request = Request::get(format!(
"/api/admin/v1/upstream-oauth-providers/{provider_id}"
))
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::NOT_FOUND);
}
}

View File

@@ -0,0 +1,799 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use axum_extra::extract::{Query, QueryRejection};
use axum_macros::FromRequestParts;
use hyper::StatusCode;
use mas_axum_utils::record_error;
use mas_storage::{Page, upstream_oauth2::UpstreamOAuthProviderFilter};
use schemars::JsonSchema;
use serde::Deserialize;
use crate::{
admin::{
call_context::CallContext,
model::{Resource, UpstreamOAuthProvider},
params::{IncludeCount, Pagination},
response::{ErrorResponse, PaginatedResponse},
},
impl_from_error_for_route,
};
#[derive(FromRequestParts, Deserialize, JsonSchema, OperationIo)]
#[serde(rename = "UpstreamOAuthProviderFilter")]
#[aide(input_with = "Query<FilterParams>")]
#[from_request(via(Query), rejection(RouteError))]
pub struct FilterParams {
/// Retrieve providers that are (or are not) enabled
#[serde(rename = "filter[enabled]")]
enabled: Option<bool>,
}
impl std::fmt::Display for FilterParams {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
let mut sep = '?';
if let Some(enabled) = self.enabled {
write!(f, "{sep}filter[enabled]={enabled}")?;
sep = '&';
}
let _ = sep;
Ok(())
}
}
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("Invalid filter parameters")]
InvalidFilter(#[from] QueryRejection),
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::InvalidFilter(_) => StatusCode::BAD_REQUEST,
};
(status, sentry_event_id, Json(error)).into_response()
}
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("listUpstreamOAuthProviders")
.summary("List upstream OAuth 2.0 providers")
.tag("upstream-oauth-provider")
.response_with::<200, Json<PaginatedResponse<UpstreamOAuthProvider>>, _>(|t| {
let providers = UpstreamOAuthProvider::samples();
let pagination = mas_storage::Pagination::first(providers.len());
let page = Page {
edges: providers
.into_iter()
.map(|node| mas_storage::pagination::Edge {
cursor: node.id(),
node,
})
.collect(),
has_next_page: true,
has_previous_page: false,
};
t.description("Paginated response of upstream OAuth 2.0 providers")
.example(PaginatedResponse::for_page(
page,
pagination,
Some(42),
UpstreamOAuthProvider::PATH,
))
})
}
#[tracing::instrument(name = "handler.admin.v1.upstream_oauth_providers.list", skip_all)]
pub async fn handler(
CallContext { mut repo, .. }: CallContext,
Pagination(pagination, include_count): Pagination,
params: FilterParams,
) -> Result<Json<PaginatedResponse<UpstreamOAuthProvider>>, RouteError> {
let base = format!("{path}{params}", path = UpstreamOAuthProvider::PATH);
let base = include_count.add_to_base(&base);
let filter = UpstreamOAuthProviderFilter::new();
let filter = match params.enabled {
Some(true) => filter.enabled_only(),
Some(false) => filter.disabled_only(),
None => filter,
};
let response = match include_count {
IncludeCount::True => {
let page = repo
.upstream_oauth_provider()
.list(filter, pagination)
.await?
.map(UpstreamOAuthProvider::from);
let count = repo.upstream_oauth_provider().count(filter).await?;
PaginatedResponse::for_page(page, pagination, Some(count), &base)
}
IncludeCount::False => {
let page = repo
.upstream_oauth_provider()
.list(filter, pagination)
.await?
.map(UpstreamOAuthProvider::from);
PaginatedResponse::for_page(page, pagination, None, &base)
}
IncludeCount::Only => {
let count = repo.upstream_oauth_provider().count(filter).await?;
PaginatedResponse::for_count_only(count, &base)
}
};
Ok(Json(response))
}
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode};
use mas_data_model::{
UpstreamOAuthProviderClaimsImports, UpstreamOAuthProviderDiscoveryMode,
UpstreamOAuthProviderOnBackchannelLogout, UpstreamOAuthProviderPkceMode,
UpstreamOAuthProviderTokenAuthMethod,
};
use mas_iana::jose::JsonWebSignatureAlg;
use mas_storage::{
RepositoryAccess,
upstream_oauth2::{UpstreamOAuthProviderParams, UpstreamOAuthProviderRepository},
};
use oauth2_types::scope::{OPENID, Scope};
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
async fn create_test_providers(state: &mut TestState) {
let mut repo = state.repository().await.unwrap();
// Create an enabled provider
let enabled_params = UpstreamOAuthProviderParams {
issuer: Some("https://accounts.google.com".to_owned()),
human_name: Some("Google".to_owned()),
brand_name: Some("google".to_owned()),
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
jwks_uri_override: None,
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: true,
userinfo_signed_response_alg: None,
client_id: "google-client-id".to_owned(),
encrypted_client_secret: Some("encrypted-secret".to_owned()),
token_endpoint_signing_alg: None,
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::ClientSecretPost,
id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
response_mode: None,
scope: Scope::from_iter([OPENID]),
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
additional_authorization_parameters: vec![],
forward_login_hint: false,
on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
ui_order: 0,
};
repo.upstream_oauth_provider()
.add(&mut state.rng(), &state.clock, enabled_params)
.await
.unwrap();
// Create a disabled provider
let disabled_params = UpstreamOAuthProviderParams {
issuer: Some("https://appleid.apple.com".to_owned()),
human_name: Some("Apple ID".to_owned()),
brand_name: Some("apple".to_owned()),
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: UpstreamOAuthProviderPkceMode::S256,
jwks_uri_override: None,
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: true,
userinfo_signed_response_alg: None,
client_id: "apple-client-id".to_owned(),
encrypted_client_secret: Some("encrypted-secret".to_owned()),
token_endpoint_signing_alg: None,
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::ClientSecretPost,
id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
response_mode: None,
scope: Scope::from_iter([OPENID]),
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
additional_authorization_parameters: vec![],
forward_login_hint: false,
on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
ui_order: 1,
};
let disabled_provider = repo
.upstream_oauth_provider()
.add(&mut state.rng(), &state.clock, disabled_params)
.await
.unwrap();
// Disable the provider
repo.upstream_oauth_provider()
.disable(&state.clock, disabled_provider)
.await
.unwrap();
// Create another enabled provider
let another_enabled_params = UpstreamOAuthProviderParams {
issuer: Some("https://login.microsoftonline.com/common/v2.0".to_owned()),
human_name: Some("Microsoft".to_owned()),
brand_name: Some("microsoft".to_owned()),
discovery_mode: UpstreamOAuthProviderDiscoveryMode::Oidc,
pkce_mode: UpstreamOAuthProviderPkceMode::Auto,
jwks_uri_override: None,
authorization_endpoint_override: None,
token_endpoint_override: None,
userinfo_endpoint_override: None,
fetch_userinfo: true,
userinfo_signed_response_alg: None,
client_id: "microsoft-client-id".to_owned(),
encrypted_client_secret: Some("encrypted-secret".to_owned()),
token_endpoint_signing_alg: None,
token_endpoint_auth_method: UpstreamOAuthProviderTokenAuthMethod::ClientSecretPost,
id_token_signed_response_alg: JsonWebSignatureAlg::Rs256,
response_mode: None,
scope: Scope::from_iter([OPENID]),
claims_imports: UpstreamOAuthProviderClaimsImports::default(),
additional_authorization_parameters: vec![],
forward_login_hint: false,
on_backchannel_logout: UpstreamOAuthProviderOnBackchannelLogout::DoNothing,
ui_order: 2,
};
repo.upstream_oauth_provider()
.add(&mut state.rng(), &state.clock, another_enabled_params)
.await
.unwrap();
Box::new(repo).save().await.unwrap();
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_list_all_providers(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let admin_token = state.token_with_scope("urn:mas:admin").await;
create_test_providers(&mut state).await;
let request = Request::get("/api/admin/v1/upstream-oauth-providers")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json::<serde_json::Value>();
// Should return all providers
assert_eq!(body["data"].as_array().unwrap().len(), 3);
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 3
},
"data": [
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG07HNEZXNQM2KNBNF6",
"attributes": {
"issuer": "https://appleid.apple.com",
"human_name": "Apple ID",
"brand_name": "apple",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": "2022-01-16T14:40:00Z"
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
}
},
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG09AVTNSQFMSR34AJC",
"attributes": {
"issuer": "https://login.microsoftonline.com/common/v2.0",
"human_name": "Microsoft",
"brand_name": "microsoft",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
}
},
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"issuer": "https://accounts.google.com",
"human_name": "Google",
"brand_name": "google",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
}
],
"links": {
"self": "/api/admin/v1/upstream-oauth-providers?page[first]=10",
"first": "/api/admin/v1/upstream-oauth-providers?page[first]=10",
"last": "/api/admin/v1/upstream-oauth-providers?page[last]=10"
}
}
"#);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_filter_by_enabled_true(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let admin_token = state.token_with_scope("urn:mas:admin").await;
create_test_providers(&mut state).await;
let request = Request::get("/api/admin/v1/upstream-oauth-providers?filter[enabled]=true")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json::<serde_json::Value>();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 2
},
"data": [
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG09AVTNSQFMSR34AJC",
"attributes": {
"issuer": "https://login.microsoftonline.com/common/v2.0",
"human_name": "Microsoft",
"brand_name": "microsoft",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
}
},
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"issuer": "https://accounts.google.com",
"human_name": "Google",
"brand_name": "google",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
}
],
"links": {
"self": "/api/admin/v1/upstream-oauth-providers?filter[enabled]=true&page[first]=10",
"first": "/api/admin/v1/upstream-oauth-providers?filter[enabled]=true&page[first]=10",
"last": "/api/admin/v1/upstream-oauth-providers?filter[enabled]=true&page[last]=10"
}
}
"#);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_filter_by_enabled_false(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let admin_token = state.token_with_scope("urn:mas:admin").await;
create_test_providers(&mut state).await;
let request = Request::get("/api/admin/v1/upstream-oauth-providers?filter[enabled]=false")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json::<serde_json::Value>();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 1
},
"data": [
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG07HNEZXNQM2KNBNF6",
"attributes": {
"issuer": "https://appleid.apple.com",
"human_name": "Apple ID",
"brand_name": "apple",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": "2022-01-16T14:40:00Z"
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
}
}
],
"links": {
"self": "/api/admin/v1/upstream-oauth-providers?filter[enabled]=false&page[first]=10",
"first": "/api/admin/v1/upstream-oauth-providers?filter[enabled]=false&page[first]=10",
"last": "/api/admin/v1/upstream-oauth-providers?filter[enabled]=false&page[last]=10"
}
}
"#);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_pagination(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let admin_token = state.token_with_scope("urn:mas:admin").await;
create_test_providers(&mut state).await;
// Test first page with limit of 2
let request = Request::get("/api/admin/v1/upstream-oauth-providers?page[first]=2")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json::<serde_json::Value>();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 3
},
"data": [
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG07HNEZXNQM2KNBNF6",
"attributes": {
"issuer": "https://appleid.apple.com",
"human_name": "Apple ID",
"brand_name": "apple",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": "2022-01-16T14:40:00Z"
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
}
},
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG09AVTNSQFMSR34AJC",
"attributes": {
"issuer": "https://login.microsoftonline.com/common/v2.0",
"human_name": "Microsoft",
"brand_name": "microsoft",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
}
}
],
"links": {
"self": "/api/admin/v1/upstream-oauth-providers?page[first]=2",
"first": "/api/admin/v1/upstream-oauth-providers?page[first]=2",
"last": "/api/admin/v1/upstream-oauth-providers?page[last]=2",
"next": "/api/admin/v1/upstream-oauth-providers?page[after]=01FSHN9AG09AVTNSQFMSR34AJC&page[first]=2"
}
}
"#);
// Extract the ID of the last item for pagination
let last_item_id = body["data"][1]["id"].as_str().unwrap();
let request = Request::get(format!(
"/api/admin/v1/upstream-oauth-providers?page[first]=2&page[after]={last_item_id}",
))
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json::<serde_json::Value>();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 3
},
"data": [
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"issuer": "https://accounts.google.com",
"human_name": "Google",
"brand_name": "google",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
}
],
"links": {
"self": "/api/admin/v1/upstream-oauth-providers?page[after]=01FSHN9AG09AVTNSQFMSR34AJC&page[first]=2",
"first": "/api/admin/v1/upstream-oauth-providers?page[first]=2",
"last": "/api/admin/v1/upstream-oauth-providers?page[last]=2"
}
}
"#);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_invalid_filter(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let admin_token = state.token_with_scope("urn:mas:admin").await;
let request =
Request::get("/api/admin/v1/upstream-oauth-providers?filter[enabled]=invalid")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_count_parameter(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let admin_token = state.token_with_scope("urn:mas:admin").await;
create_test_providers(&mut state).await;
// Test count=false
let request = Request::get("/api/admin/v1/upstream-oauth-providers?count=false")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json::<serde_json::Value>();
insta::assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG07HNEZXNQM2KNBNF6",
"attributes": {
"issuer": "https://appleid.apple.com",
"human_name": "Apple ID",
"brand_name": "apple",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": "2022-01-16T14:40:00Z"
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
}
},
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG09AVTNSQFMSR34AJC",
"attributes": {
"issuer": "https://login.microsoftonline.com/common/v2.0",
"human_name": "Microsoft",
"brand_name": "microsoft",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
}
},
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"issuer": "https://accounts.google.com",
"human_name": "Google",
"brand_name": "google",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
}
],
"links": {
"self": "/api/admin/v1/upstream-oauth-providers?count=false&page[first]=10",
"first": "/api/admin/v1/upstream-oauth-providers?count=false&page[first]=10",
"last": "/api/admin/v1/upstream-oauth-providers?count=false&page[last]=10"
}
}
"#);
// Test count=only
let request = Request::get("/api/admin/v1/upstream-oauth-providers?count=only")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json::<serde_json::Value>();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 3
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers?count=only"
}
}
"#);
// Test count=false with filtering
let request =
Request::get("/api/admin/v1/upstream-oauth-providers?count=false&filter[enabled]=true")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json::<serde_json::Value>();
insta::assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG09AVTNSQFMSR34AJC",
"attributes": {
"issuer": "https://login.microsoftonline.com/common/v2.0",
"human_name": "Microsoft",
"brand_name": "microsoft",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
}
},
{
"type": "upstream-oauth-provider",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"issuer": "https://accounts.google.com",
"human_name": "Google",
"brand_name": "google",
"created_at": "2022-01-16T14:40:00Z",
"disabled_at": null
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
}
],
"links": {
"self": "/api/admin/v1/upstream-oauth-providers?filter[enabled]=true&count=false&page[first]=10",
"first": "/api/admin/v1/upstream-oauth-providers?filter[enabled]=true&count=false&page[first]=10",
"last": "/api/admin/v1/upstream-oauth-providers?filter[enabled]=true&count=false&page[last]=10"
}
}
"#);
// Test count=only with filtering
let request =
Request::get("/api/admin/v1/upstream-oauth-providers?count=only&filter[enabled]=false")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json::<serde_json::Value>();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 1
},
"links": {
"self": "/api/admin/v1/upstream-oauth-providers?filter[enabled]=false&count=only"
}
}
"#);
}
}

View File

@@ -0,0 +1,12 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
mod get;
mod list;
pub use self::{
get::{doc as get_doc, handler as get},
list::{doc as list_doc, handler as list},
};

View File

@@ -4,11 +4,8 @@
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use aide::{OperationIo, transform::TransformOperation}; use aide::{OperationIo, transform::TransformOperation};
use axum::{ use axum::{Json, response::IntoResponse};
Json, use axum_extra::extract::{Query, QueryRejection};
extract::{Query, rejection::QueryRejection},
response::IntoResponse,
};
use axum_macros::FromRequestParts; use axum_macros::FromRequestParts;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::record_error; use mas_axum_utils::record_error;
@@ -21,7 +18,7 @@ use crate::{
admin::{ admin::{
call_context::CallContext, call_context::CallContext,
model::{Resource, UserEmail}, model::{Resource, UserEmail},
params::Pagination, params::{IncludeCount, Pagination},
response::{ErrorResponse, PaginatedResponse}, response::{ErrorResponse, PaginatedResponse},
}, },
impl_from_error_for_route, impl_from_error_for_route,
@@ -99,16 +96,22 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
let emails = UserEmail::samples(); let emails = UserEmail::samples();
let pagination = mas_storage::Pagination::first(emails.len()); let pagination = mas_storage::Pagination::first(emails.len());
let page = Page { let page = Page {
edges: emails.into(), edges: emails
.into_iter()
.map(|node| mas_storage::pagination::Edge {
cursor: node.id(),
node,
})
.collect(),
has_next_page: true, has_next_page: true,
has_previous_page: false, has_previous_page: false,
}; };
t.description("Paginated response of user emails") t.description("Paginated response of user emails")
.example(PaginatedResponse::new( .example(PaginatedResponse::for_page(
page, page,
pagination, pagination,
42, Some(42),
UserEmail::PATH, UserEmail::PATH,
)) ))
}) })
@@ -121,10 +124,11 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
#[tracing::instrument(name = "handler.admin.v1.user_emails.list", skip_all)] #[tracing::instrument(name = "handler.admin.v1.user_emails.list", skip_all)]
pub async fn handler( pub async fn handler(
CallContext { mut repo, .. }: CallContext, CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination, Pagination(pagination, include_count): Pagination,
params: FilterParams, params: FilterParams,
) -> Result<Json<PaginatedResponse<UserEmail>>, RouteError> { ) -> Result<Json<PaginatedResponse<UserEmail>>, RouteError> {
let base = format!("{path}{params}", path = UserEmail::PATH); let base = format!("{path}{params}", path = UserEmail::PATH);
let base = include_count.add_to_base(&base);
let filter = UserEmailFilter::default(); let filter = UserEmailFilter::default();
// Load the user from the filter // Load the user from the filter
@@ -150,15 +154,31 @@ pub async fn handler(
None => filter, None => filter,
}; };
let page = repo.user_email().list(filter, pagination).await?; let response = match include_count {
let count = repo.user_email().count(filter).await?; IncludeCount::True => {
let page = repo
.user_email()
.list(filter, pagination)
.await?
.map(UserEmail::from);
let count = repo.user_email().count(filter).await?;
PaginatedResponse::for_page(page, pagination, Some(count), &base)
}
IncludeCount::False => {
let page = repo
.user_email()
.list(filter, pagination)
.await?
.map(UserEmail::from);
PaginatedResponse::for_page(page, pagination, None, &base)
}
IncludeCount::Only => {
let count = repo.user_email().count(filter).await?;
PaginatedResponse::for_count_only(count, &base)
}
};
Ok(Json(PaginatedResponse::new( Ok(Json(response))
page.map(UserEmail::from),
pagination,
count,
&base,
)))
} }
#[cfg(test)] #[cfg(test)]
@@ -209,7 +229,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r###" insta::assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 2 "count": 2
@@ -225,6 +245,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-emails/01FSHN9AG09NMZYX8MFYH578R9" "self": "/api/admin/v1/user-emails/01FSHN9AG09NMZYX8MFYH578R9"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09NMZYX8MFYH578R9"
}
} }
}, },
{ {
@@ -237,6 +262,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-emails/01FSHN9AG0KEPHYQQXW9XPTX6Z" "self": "/api/admin/v1/user-emails/01FSHN9AG0KEPHYQQXW9XPTX6Z"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0KEPHYQQXW9XPTX6Z"
}
} }
} }
], ],
@@ -246,7 +276,7 @@ mod tests {
"last": "/api/admin/v1/user-emails?page[last]=10" "last": "/api/admin/v1/user-emails?page[last]=10"
} }
} }
"###); "#);
// Filter by user // Filter by user
let request = Request::get(format!( let request = Request::get(format!(
@@ -258,7 +288,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r###" insta::assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 1 "count": 1
@@ -274,6 +304,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-emails/01FSHN9AG09NMZYX8MFYH578R9" "self": "/api/admin/v1/user-emails/01FSHN9AG09NMZYX8MFYH578R9"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09NMZYX8MFYH578R9"
}
} }
} }
], ],
@@ -283,7 +318,7 @@ mod tests {
"last": "/api/admin/v1/user-emails?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&page[last]=10" "last": "/api/admin/v1/user-emails?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&page[last]=10"
} }
} }
"###); "#);
// Filter by email // Filter by email
let request = Request::get("/api/admin/v1/user-emails?filter[email]=alice@example.com") let request = Request::get("/api/admin/v1/user-emails?filter[email]=alice@example.com")
@@ -292,7 +327,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r###" insta::assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 1 "count": 1
@@ -308,6 +343,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-emails/01FSHN9AG09NMZYX8MFYH578R9" "self": "/api/admin/v1/user-emails/01FSHN9AG09NMZYX8MFYH578R9"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09NMZYX8MFYH578R9"
}
} }
} }
], ],
@@ -317,6 +357,137 @@ mod tests {
"last": "/api/admin/v1/user-emails?filter[email]=alice@example.com&page[last]=10" "last": "/api/admin/v1/user-emails?filter[email]=alice@example.com&page[last]=10"
} }
} }
"#);
// Test count=false
let request = Request::get("/api/admin/v1/user-emails?count=false")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "user-email",
"id": "01FSHN9AG09NMZYX8MFYH578R9",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"email": "alice@example.com"
},
"links": {
"self": "/api/admin/v1/user-emails/01FSHN9AG09NMZYX8MFYH578R9"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09NMZYX8MFYH578R9"
}
}
},
{
"type": "user-email",
"id": "01FSHN9AG0KEPHYQQXW9XPTX6Z",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"user_id": "01FSHN9AG0AJ6AC5HQ9X6H4RP4",
"email": "bob@example.com"
},
"links": {
"self": "/api/admin/v1/user-emails/01FSHN9AG0KEPHYQQXW9XPTX6Z"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0KEPHYQQXW9XPTX6Z"
}
}
}
],
"links": {
"self": "/api/admin/v1/user-emails?count=false&page[first]=10",
"first": "/api/admin/v1/user-emails?count=false&page[first]=10",
"last": "/api/admin/v1/user-emails?count=false&page[last]=10"
}
}
"#);
// Test count=only
let request = Request::get("/api/admin/v1/user-emails?count=only")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r###"
{
"meta": {
"count": 2
},
"links": {
"self": "/api/admin/v1/user-emails?count=only"
}
}
"###); "###);
// Test count=false with filtering
let request = Request::get(format!(
"/api/admin/v1/user-emails?count=false&filter[user]={}",
alice.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "user-email",
"id": "01FSHN9AG09NMZYX8MFYH578R9",
"attributes": {
"created_at": "2022-01-16T14:40:00Z",
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"email": "alice@example.com"
},
"links": {
"self": "/api/admin/v1/user-emails/01FSHN9AG09NMZYX8MFYH578R9"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09NMZYX8MFYH578R9"
}
}
}
],
"links": {
"self": "/api/admin/v1/user-emails?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[first]=10",
"first": "/api/admin/v1/user-emails?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[first]=10",
"last": "/api/admin/v1/user-emails?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[last]=10"
}
}
"#);
// Test count=only with filtering
let request = Request::get(format!(
"/api/admin/v1/user-emails?count=only&filter[user]={}",
alice.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 1
},
"links": {
"self": "/api/admin/v1/user-emails?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=only"
}
}
"#);
} }
} }

View File

@@ -5,11 +5,8 @@
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use aide::{OperationIo, transform::TransformOperation}; use aide::{OperationIo, transform::TransformOperation};
use axum::{ use axum::{Json, response::IntoResponse};
Json, use axum_extra::extract::{Query, QueryRejection};
extract::{Query, rejection::QueryRejection},
response::IntoResponse,
};
use axum_macros::FromRequestParts; use axum_macros::FromRequestParts;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::record_error; use mas_axum_utils::record_error;
@@ -21,7 +18,7 @@ use crate::{
admin::{ admin::{
call_context::CallContext, call_context::CallContext,
model::{Resource, UserRegistrationToken}, model::{Resource, UserRegistrationToken},
params::Pagination, params::{IncludeCount, Pagination},
response::{ErrorResponse, PaginatedResponse}, response::{ErrorResponse, PaginatedResponse},
}, },
impl_from_error_for_route, impl_from_error_for_route,
@@ -112,16 +109,22 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
let tokens = UserRegistrationToken::samples(); let tokens = UserRegistrationToken::samples();
let pagination = mas_storage::Pagination::first(tokens.len()); let pagination = mas_storage::Pagination::first(tokens.len());
let page = Page { let page = Page {
edges: tokens.into(), edges: tokens
.into_iter()
.map(|node| mas_storage::pagination::Edge {
cursor: node.id(),
node,
})
.collect(),
has_next_page: true, has_next_page: true,
has_previous_page: false, has_previous_page: false,
}; };
t.description("Paginated response of registration tokens") t.description("Paginated response of registration tokens")
.example(PaginatedResponse::new( .example(PaginatedResponse::for_page(
page, page,
pagination, pagination,
42, Some(42),
UserRegistrationToken::PATH, UserRegistrationToken::PATH,
)) ))
}) })
@@ -132,10 +135,11 @@ pub async fn handler(
CallContext { CallContext {
mut repo, clock, .. mut repo, clock, ..
}: CallContext, }: CallContext,
Pagination(pagination): Pagination, Pagination(pagination, include_count): Pagination,
params: FilterParams, params: FilterParams,
) -> Result<Json<PaginatedResponse<UserRegistrationToken>>, RouteError> { ) -> Result<Json<PaginatedResponse<UserRegistrationToken>>, RouteError> {
let base = format!("{path}{params}", path = UserRegistrationToken::PATH); let base = format!("{path}{params}", path = UserRegistrationToken::PATH);
let base = include_count.add_to_base(&base);
let now = clock.now(); let now = clock.now();
let mut filter = UserRegistrationTokenFilter::new(now); let mut filter = UserRegistrationTokenFilter::new(now);
@@ -155,18 +159,31 @@ pub async fn handler(
filter = filter.with_valid(valid); filter = filter.with_valid(valid);
} }
let page = repo let response = match include_count {
.user_registration_token() IncludeCount::True => {
.list(filter, pagination) let page = repo
.await?; .user_registration_token()
let count = repo.user_registration_token().count(filter).await?; .list(filter, pagination)
.await?
.map(|token| UserRegistrationToken::new(token, now));
let count = repo.user_registration_token().count(filter).await?;
PaginatedResponse::for_page(page, pagination, Some(count), &base)
}
IncludeCount::False => {
let page = repo
.user_registration_token()
.list(filter, pagination)
.await?
.map(|token| UserRegistrationToken::new(token, now));
PaginatedResponse::for_page(page, pagination, None, &base)
}
IncludeCount::Only => {
let count = repo.user_registration_token().count(filter).await?;
PaginatedResponse::for_count_only(count, &base)
}
};
Ok(Json(PaginatedResponse::new( Ok(Json(response))
page.map(|token| UserRegistrationToken::new(token, now)),
pagination,
count,
&base,
)))
} }
#[cfg(test)] #[cfg(test)]
@@ -300,6 +317,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z"
},
"meta": {
"page": {
"cursor": "01FSHN9AG064K8BYZXSY5G511Z"
}
} }
}, },
{ {
@@ -317,6 +339,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
} }
}, },
{ {
@@ -334,6 +361,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
} }
}, },
{ {
@@ -351,6 +383,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
} }
}, },
{ {
@@ -368,6 +405,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0S3ZJD8CXQ7F11KXN"
}
} }
} }
], ],
@@ -416,6 +458,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
} }
}, },
{ {
@@ -433,6 +480,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0S3ZJD8CXQ7F11KXN"
}
} }
} }
], ],
@@ -473,6 +525,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z"
},
"meta": {
"page": {
"cursor": "01FSHN9AG064K8BYZXSY5G511Z"
}
} }
}, },
{ {
@@ -490,6 +547,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
} }
}, },
{ {
@@ -507,6 +569,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
} }
} }
], ],
@@ -555,6 +622,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
} }
}, },
{ {
@@ -572,6 +644,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0S3ZJD8CXQ7F11KXN"
}
} }
} }
], ],
@@ -612,6 +689,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z"
},
"meta": {
"page": {
"cursor": "01FSHN9AG064K8BYZXSY5G511Z"
}
} }
}, },
{ {
@@ -629,6 +711,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
} }
}, },
{ {
@@ -646,6 +733,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
} }
} }
], ],
@@ -694,6 +786,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z"
},
"meta": {
"page": {
"cursor": "01FSHN9AG064K8BYZXSY5G511Z"
}
} }
} }
], ],
@@ -734,6 +831,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
} }
}, },
{ {
@@ -751,6 +853,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
} }
}, },
{ {
@@ -768,6 +875,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
} }
}, },
{ {
@@ -785,6 +897,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0S3ZJD8CXQ7F11KXN"
}
} }
} }
], ],
@@ -833,6 +950,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
} }
}, },
{ {
@@ -850,6 +972,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
} }
} }
], ],
@@ -890,6 +1017,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z"
},
"meta": {
"page": {
"cursor": "01FSHN9AG064K8BYZXSY5G511Z"
}
} }
}, },
{ {
@@ -907,6 +1039,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
} }
}, },
{ {
@@ -924,6 +1061,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0S3ZJD8CXQ7F11KXN"
}
} }
} }
], ],
@@ -974,6 +1116,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0S3ZJD8CXQ7F11KXN"
}
} }
} }
], ],
@@ -1022,6 +1169,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z"
},
"meta": {
"page": {
"cursor": "01FSHN9AG064K8BYZXSY5G511Z"
}
} }
}, },
{ {
@@ -1039,6 +1191,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
} }
} }
], ],
@@ -1080,6 +1237,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
} }
}, },
{ {
@@ -1097,6 +1259,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
} }
} }
], ],
@@ -1138,6 +1305,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN" "self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0S3ZJD8CXQ7F11KXN"
}
} }
} }
], ],
@@ -1172,4 +1344,242 @@ mod tests {
.contains("Invalid filter parameters") .contains("Invalid filter parameters")
); );
} }
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_count_parameter(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let admin_token = state.token_with_scope("urn:mas:admin").await;
create_test_tokens(&mut state).await;
// Test count=false
let request = Request::get("/api/admin/v1/user-registration-tokens?count=false")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "user-registration_token",
"id": "01FSHN9AG064K8BYZXSY5G511Z",
"attributes": {
"token": "token_expired",
"valid": false,
"usage_limit": 5,
"times_used": 0,
"created_at": "2022-01-16T14:40:00Z",
"last_used_at": null,
"expires_at": "2022-01-15T14:40:00Z",
"revoked_at": null
},
"links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG064K8BYZXSY5G511Z"
},
"meta": {
"page": {
"cursor": "01FSHN9AG064K8BYZXSY5G511Z"
}
}
},
{
"type": "user-registration_token",
"id": "01FSHN9AG07HNEZXNQM2KNBNF6",
"attributes": {
"token": "token_used",
"valid": true,
"usage_limit": 10,
"times_used": 1,
"created_at": "2022-01-16T14:40:00Z",
"last_used_at": "2022-01-16T14:40:00Z",
"expires_at": null,
"revoked_at": null
},
"links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
}
},
{
"type": "user-registration_token",
"id": "01FSHN9AG09AVTNSQFMSR34AJC",
"attributes": {
"token": "token_revoked",
"valid": false,
"usage_limit": 10,
"times_used": 0,
"created_at": "2022-01-16T14:40:00Z",
"last_used_at": null,
"expires_at": null,
"revoked_at": "2022-01-16T14:40:00Z"
},
"links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG09AVTNSQFMSR34AJC"
},
"meta": {
"page": {
"cursor": "01FSHN9AG09AVTNSQFMSR34AJC"
}
}
},
{
"type": "user-registration_token",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"token": "token_unused",
"valid": true,
"usage_limit": 10,
"times_used": 0,
"created_at": "2022-01-16T14:40:00Z",
"last_used_at": null,
"expires_at": null,
"revoked_at": null
},
"links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
},
{
"type": "user-registration_token",
"id": "01FSHN9AG0S3ZJD8CXQ7F11KXN",
"attributes": {
"token": "token_used_revoked",
"valid": false,
"usage_limit": 10,
"times_used": 1,
"created_at": "2022-01-16T14:40:00Z",
"last_used_at": "2022-01-16T14:40:00Z",
"expires_at": null,
"revoked_at": "2022-01-16T14:40:00Z"
},
"links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0S3ZJD8CXQ7F11KXN"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0S3ZJD8CXQ7F11KXN"
}
}
}
],
"links": {
"self": "/api/admin/v1/user-registration-tokens?count=false&page[first]=10",
"first": "/api/admin/v1/user-registration-tokens?count=false&page[first]=10",
"last": "/api/admin/v1/user-registration-tokens?count=false&page[last]=10"
}
}
"#);
// Test count=only
let request = Request::get("/api/admin/v1/user-registration-tokens?count=only")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 5
},
"links": {
"self": "/api/admin/v1/user-registration-tokens?count=only"
}
}
"#);
// Test count=false with filtering
let request =
Request::get("/api/admin/v1/user-registration-tokens?count=false&filter[valid]=true")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "user-registration_token",
"id": "01FSHN9AG07HNEZXNQM2KNBNF6",
"attributes": {
"token": "token_used",
"valid": true,
"usage_limit": 10,
"times_used": 1,
"created_at": "2022-01-16T14:40:00Z",
"last_used_at": "2022-01-16T14:40:00Z",
"expires_at": null,
"revoked_at": null
},
"links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG07HNEZXNQM2KNBNF6"
},
"meta": {
"page": {
"cursor": "01FSHN9AG07HNEZXNQM2KNBNF6"
}
}
},
{
"type": "user-registration_token",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"token": "token_unused",
"valid": true,
"usage_limit": 10,
"times_used": 0,
"created_at": "2022-01-16T14:40:00Z",
"last_used_at": null,
"expires_at": null,
"revoked_at": null
},
"links": {
"self": "/api/admin/v1/user-registration-tokens/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
}
],
"links": {
"self": "/api/admin/v1/user-registration-tokens?filter[valid]=true&count=false&page[first]=10",
"first": "/api/admin/v1/user-registration-tokens?filter[valid]=true&count=false&page[first]=10",
"last": "/api/admin/v1/user-registration-tokens?filter[valid]=true&count=false&page[last]=10"
}
}
"#);
// Test count=only with filtering
let request =
Request::get("/api/admin/v1/user-registration-tokens?count=only&filter[revoked]=true")
.bearer(&admin_token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 2
},
"links": {
"self": "/api/admin/v1/user-registration-tokens?filter[revoked]=true&count=only"
}
}
"#);
}
} }

View File

@@ -0,0 +1,216 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use aide::{OperationIo, transform::TransformOperation};
use axum::{Json, response::IntoResponse};
use hyper::StatusCode;
use mas_axum_utils::record_error;
use ulid::Ulid;
use crate::{
admin::{
call_context::CallContext,
model::{Resource, UserSession},
params::UlidPathParam,
response::{ErrorResponse, SingleResponse},
},
impl_from_error_for_route,
};
#[derive(Debug, thiserror::Error, OperationIo)]
#[aide(output_with = "Json<ErrorResponse>")]
pub enum RouteError {
#[error(transparent)]
Internal(Box<dyn std::error::Error + Send + Sync + 'static>),
#[error("User session with ID {0} not found")]
NotFound(Ulid),
#[error("User session with ID {0} is already finished")]
AlreadyFinished(Ulid),
}
impl_from_error_for_route!(mas_storage::RepositoryError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
let error = ErrorResponse::from_error(&self);
let sentry_event_id = record_error!(self, Self::Internal(_));
let status = match self {
Self::Internal(_) => StatusCode::INTERNAL_SERVER_ERROR,
Self::NotFound(_) => StatusCode::NOT_FOUND,
Self::AlreadyFinished(_) => StatusCode::BAD_REQUEST,
};
(status, sentry_event_id, Json(error)).into_response()
}
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("finishUserSession")
.summary("Finish a user session")
.description(
"Calling this endpoint will finish the user session, preventing any further use.",
)
.tag("user-session")
.response_with::<200, Json<SingleResponse<UserSession>>, _>(|t| {
// Get the finished session sample
let [_, _, finished_session] = UserSession::samples();
let id = finished_session.id();
let response = SingleResponse::new(
finished_session,
format!("/api/admin/v1/user-sessions/{id}/finish"),
);
t.description("User session was finished").example(response)
})
.response_with::<400, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::AlreadyFinished(Ulid::nil()));
t.description("Session is already finished")
.example(response)
})
.response_with::<404, RouteError, _>(|t| {
let response = ErrorResponse::from_error(&RouteError::NotFound(Ulid::nil()));
t.description("User session was not found")
.example(response)
})
}
#[tracing::instrument(name = "handler.admin.v1.user_sessions.finish", skip_all)]
pub async fn handler(
CallContext {
mut repo, clock, ..
}: CallContext,
id: UlidPathParam,
) -> Result<Json<SingleResponse<UserSession>>, RouteError> {
let id = *id;
let session = repo
.browser_session()
.lookup(id)
.await?
.ok_or(RouteError::NotFound(id))?;
// Check if the session is already finished
if session.finished_at.is_some() {
return Err(RouteError::AlreadyFinished(id));
}
// Finish the session
let session = repo.browser_session().finish(&clock, session).await?;
repo.save().await?;
Ok(Json(SingleResponse::new(
UserSession::from(session),
format!("/api/admin/v1/user-sessions/{id}/finish"),
)))
}
#[cfg(test)]
mod tests {
use chrono::Duration;
use hyper::{Request, StatusCode};
use mas_data_model::Clock as _;
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_finish_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let mut rng = state.rng();
// Provision a user and a user session
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let session = repo
.browser_session()
.add(&mut rng, &state.clock, &user, None)
.await
.unwrap();
repo.save().await.unwrap();
let request = Request::post(format!("/api/admin/v1/user-sessions/{}/finish", session.id))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
// The finished_at timestamp should be the same as the current time
assert_eq!(
body["data"]["attributes"]["finished_at"],
serde_json::json!(state.clock.now())
);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_finish_already_finished_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let mut rng = state.rng();
// Provision a user and a user session
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let session = repo
.browser_session()
.add(&mut rng, &state.clock, &user, None)
.await
.unwrap();
// Finish the session first
let session = repo
.browser_session()
.finish(&state.clock, session)
.await
.unwrap();
repo.save().await.unwrap();
// Move the clock forward
state.clock.advance(Duration::try_minutes(1).unwrap());
let request = Request::post(format!("/api/admin/v1/user-sessions/{}/finish", session.id))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::BAD_REQUEST);
let body: serde_json::Value = response.json();
assert_eq!(
body["errors"][0]["title"],
format!("User session with ID {} is already finished", session.id)
);
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_finish_unknown_session(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let request =
Request::post("/api/admin/v1/user-sessions/01040G2081040G2081040G2081/finish")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::NOT_FOUND);
let body: serde_json::Value = response.json();
assert_eq!(
body["errors"][0]["title"],
"User session with ID 01040G2081040G2081040G2081 not found"
);
}
}

View File

@@ -4,11 +4,8 @@
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use aide::{OperationIo, transform::TransformOperation}; use aide::{OperationIo, transform::TransformOperation};
use axum::{ use axum::{Json, response::IntoResponse};
Json, use axum_extra::extract::{Query, QueryRejection};
extract::{Query, rejection::QueryRejection},
response::IntoResponse,
};
use axum_macros::FromRequestParts; use axum_macros::FromRequestParts;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::record_error; use mas_axum_utils::record_error;
@@ -21,7 +18,7 @@ use crate::{
admin::{ admin::{
call_context::CallContext, call_context::CallContext,
model::{Resource, UserSession}, model::{Resource, UserSession},
params::Pagination, params::{IncludeCount, Pagination},
response::{ErrorResponse, PaginatedResponse}, response::{ErrorResponse, PaginatedResponse},
}, },
impl_from_error_for_route, impl_from_error_for_route,
@@ -123,16 +120,22 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
let sessions = UserSession::samples(); let sessions = UserSession::samples();
let pagination = mas_storage::Pagination::first(sessions.len()); let pagination = mas_storage::Pagination::first(sessions.len());
let page = Page { let page = Page {
edges: sessions.into(), edges: sessions
.into_iter()
.map(|node| mas_storage::pagination::Edge {
cursor: node.id(),
node,
})
.collect(),
has_next_page: true, has_next_page: true,
has_previous_page: false, has_previous_page: false,
}; };
t.description("Paginated response of user sessions") t.description("Paginated response of user sessions")
.example(PaginatedResponse::new( .example(PaginatedResponse::for_page(
page, page,
pagination, pagination,
42, Some(42),
UserSession::PATH, UserSession::PATH,
)) ))
}) })
@@ -145,10 +148,11 @@ Use the `filter[status]` parameter to filter the sessions by their status and `p
#[tracing::instrument(name = "handler.admin.v1.user_sessions.list", skip_all)] #[tracing::instrument(name = "handler.admin.v1.user_sessions.list", skip_all)]
pub async fn handler( pub async fn handler(
CallContext { mut repo, .. }: CallContext, CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination, Pagination(pagination, include_count): Pagination,
params: FilterParams, params: FilterParams,
) -> Result<Json<PaginatedResponse<UserSession>>, RouteError> { ) -> Result<Json<PaginatedResponse<UserSession>>, RouteError> {
let base = format!("{path}{params}", path = UserSession::PATH); let base = format!("{path}{params}", path = UserSession::PATH);
let base = include_count.add_to_base(&base);
let filter = BrowserSessionFilter::default(); let filter = BrowserSessionFilter::default();
// Load the user from the filter // Load the user from the filter
@@ -175,15 +179,31 @@ pub async fn handler(
None => filter, None => filter,
}; };
let page = repo.browser_session().list(filter, pagination).await?; let response = match include_count {
let count = repo.browser_session().count(filter).await?; IncludeCount::True => {
let page = repo
.browser_session()
.list(filter, pagination)
.await?
.map(UserSession::from);
let count = repo.browser_session().count(filter).await?;
PaginatedResponse::for_page(page, pagination, Some(count), &base)
}
IncludeCount::False => {
let page = repo
.browser_session()
.list(filter, pagination)
.await?
.map(UserSession::from);
PaginatedResponse::for_page(page, pagination, None, &base)
}
IncludeCount::Only => {
let count = repo.browser_session().count(filter).await?;
PaginatedResponse::for_count_only(count, &base)
}
};
Ok(Json(PaginatedResponse::new( Ok(Json(response))
page.map(UserSession::from),
pagination,
count,
&base,
)))
} }
#[cfg(test)] #[cfg(test)]
@@ -241,7 +261,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###" assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 2 "count": 2
@@ -260,6 +280,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-sessions/01FSHNB5309NMZYX8MFYH578R9" "self": "/api/admin/v1/user-sessions/01FSHNB5309NMZYX8MFYH578R9"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
} }
}, },
{ {
@@ -275,6 +300,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-sessions/01FSHNB530KEPHYQQXW9XPTX6Z" "self": "/api/admin/v1/user-sessions/01FSHNB530KEPHYQQXW9XPTX6Z"
},
"meta": {
"page": {
"cursor": "01FSHNB530AJ6AC5HQ9X6H4RP4"
}
} }
} }
], ],
@@ -284,7 +314,7 @@ mod tests {
"last": "/api/admin/v1/user-sessions?page[last]=10" "last": "/api/admin/v1/user-sessions?page[last]=10"
} }
} }
"###); "#);
// Filter by user // Filter by user
let request = Request::get(format!( let request = Request::get(format!(
@@ -296,7 +326,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###" assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 1 "count": 1
@@ -315,6 +345,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-sessions/01FSHNB5309NMZYX8MFYH578R9" "self": "/api/admin/v1/user-sessions/01FSHNB5309NMZYX8MFYH578R9"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
} }
} }
], ],
@@ -324,7 +359,7 @@ mod tests {
"last": "/api/admin/v1/user-sessions?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&page[last]=10" "last": "/api/admin/v1/user-sessions?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&page[last]=10"
} }
} }
"###); "#);
// Filter by status (active) // Filter by status (active)
let request = Request::get("/api/admin/v1/user-sessions?filter[status]=active") let request = Request::get("/api/admin/v1/user-sessions?filter[status]=active")
@@ -333,7 +368,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###" assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 1 "count": 1
@@ -352,6 +387,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-sessions/01FSHNB5309NMZYX8MFYH578R9" "self": "/api/admin/v1/user-sessions/01FSHNB5309NMZYX8MFYH578R9"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
} }
} }
], ],
@@ -361,7 +401,7 @@ mod tests {
"last": "/api/admin/v1/user-sessions?filter[status]=active&page[last]=10" "last": "/api/admin/v1/user-sessions?filter[status]=active&page[last]=10"
} }
} }
"###); "#);
// Filter by status (finished) // Filter by status (finished)
let request = Request::get("/api/admin/v1/user-sessions?filter[status]=finished") let request = Request::get("/api/admin/v1/user-sessions?filter[status]=finished")
@@ -370,7 +410,7 @@ mod tests {
let response = state.request(request).await; let response = state.request(request).await;
response.assert_status(StatusCode::OK); response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json(); let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###" assert_json_snapshot!(body, @r#"
{ {
"meta": { "meta": {
"count": 1 "count": 1
@@ -389,6 +429,11 @@ mod tests {
}, },
"links": { "links": {
"self": "/api/admin/v1/user-sessions/01FSHNB530KEPHYQQXW9XPTX6Z" "self": "/api/admin/v1/user-sessions/01FSHNB530KEPHYQQXW9XPTX6Z"
},
"meta": {
"page": {
"cursor": "01FSHNB530AJ6AC5HQ9X6H4RP4"
}
} }
} }
], ],
@@ -398,6 +443,143 @@ mod tests {
"last": "/api/admin/v1/user-sessions?filter[status]=finished&page[last]=10" "last": "/api/admin/v1/user-sessions?filter[status]=finished&page[last]=10"
} }
} }
"#);
// Test count=false
let request = Request::get("/api/admin/v1/user-sessions?count=false")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "user-session",
"id": "01FSHNB5309NMZYX8MFYH578R9",
"attributes": {
"created_at": "2022-01-16T14:41:00Z",
"finished_at": null,
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"user_agent": null,
"last_active_at": null,
"last_active_ip": null
},
"links": {
"self": "/api/admin/v1/user-sessions/01FSHNB5309NMZYX8MFYH578R9"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
},
{
"type": "user-session",
"id": "01FSHNB530KEPHYQQXW9XPTX6Z",
"attributes": {
"created_at": "2022-01-16T14:41:00Z",
"finished_at": "2022-01-16T14:42:00Z",
"user_id": "01FSHNB530AJ6AC5HQ9X6H4RP4",
"user_agent": null,
"last_active_at": null,
"last_active_ip": null
},
"links": {
"self": "/api/admin/v1/user-sessions/01FSHNB530KEPHYQQXW9XPTX6Z"
},
"meta": {
"page": {
"cursor": "01FSHNB530AJ6AC5HQ9X6H4RP4"
}
}
}
],
"links": {
"self": "/api/admin/v1/user-sessions?count=false&page[first]=10",
"first": "/api/admin/v1/user-sessions?count=false&page[first]=10",
"last": "/api/admin/v1/user-sessions?count=false&page[last]=10"
}
}
"#);
// Test count=only
let request = Request::get("/api/admin/v1/user-sessions?count=only")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r###"
{
"meta": {
"count": 2
},
"links": {
"self": "/api/admin/v1/user-sessions?count=only"
}
}
"###); "###);
// Test count=false with filtering
let request = Request::get(format!(
"/api/admin/v1/user-sessions?count=false&filter[user]={}",
alice.id
))
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "user-session",
"id": "01FSHNB5309NMZYX8MFYH578R9",
"attributes": {
"created_at": "2022-01-16T14:41:00Z",
"finished_at": null,
"user_id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"user_agent": null,
"last_active_at": null,
"last_active_ip": null
},
"links": {
"self": "/api/admin/v1/user-sessions/01FSHNB5309NMZYX8MFYH578R9"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
}
],
"links": {
"self": "/api/admin/v1/user-sessions?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[first]=10",
"first": "/api/admin/v1/user-sessions?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[first]=10",
"last": "/api/admin/v1/user-sessions?filter[user]=01FSHN9AG0MZAA6S4AF7CTV32E&count=false&page[last]=10"
}
}
"#);
// Test count=only with filtering
let request = Request::get("/api/admin/v1/user-sessions?count=only&filter[status]=active")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 1
},
"links": {
"self": "/api/admin/v1/user-sessions?filter[status]=active&count=only"
}
}
"#);
} }
} }

View File

@@ -3,10 +3,12 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
mod finish;
mod get; mod get;
mod list; mod list;
pub use self::{ pub use self::{
finish::{doc as finish_doc, handler as finish},
get::{doc as get_doc, handler as get}, get::{doc as get_doc, handler as get},
list::{doc as list_doc, handler as list}, list::{doc as list_doc, handler as list},
}; };

View File

@@ -209,7 +209,8 @@ mod tests {
"created_at": "2022-01-16T14:40:00Z", "created_at": "2022-01-16T14:40:00Z",
"locked_at": null, "locked_at": null,
"deactivated_at": "2022-01-16T14:40:00Z", "deactivated_at": "2022-01-16T14:40:00Z",
"admin": false "admin": false,
"legacy_guest": false
}, },
"links": { "links": {
"self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E" "self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E"
@@ -289,7 +290,8 @@ mod tests {
"created_at": "2022-01-16T14:40:00Z", "created_at": "2022-01-16T14:40:00Z",
"locked_at": "2022-01-16T14:40:00Z", "locked_at": "2022-01-16T14:40:00Z",
"deactivated_at": "2022-01-16T14:41:00Z", "deactivated_at": "2022-01-16T14:41:00Z",
"admin": false "admin": false,
"legacy_guest": false
}, },
"links": { "links": {
"self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E" "self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E"

View File

@@ -5,11 +5,8 @@
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use aide::{OperationIo, transform::TransformOperation}; use aide::{OperationIo, transform::TransformOperation};
use axum::{ use axum::{Json, response::IntoResponse};
Json, use axum_extra::extract::{Query, QueryRejection};
extract::{Query, rejection::QueryRejection},
response::IntoResponse,
};
use axum_macros::FromRequestParts; use axum_macros::FromRequestParts;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::record_error; use mas_axum_utils::record_error;
@@ -21,7 +18,7 @@ use crate::{
admin::{ admin::{
call_context::CallContext, call_context::CallContext,
model::{Resource, User}, model::{Resource, User},
params::Pagination, params::{IncludeCount, Pagination},
response::{ErrorResponse, PaginatedResponse}, response::{ErrorResponse, PaginatedResponse},
}, },
impl_from_error_for_route, impl_from_error_for_route,
@@ -54,6 +51,17 @@ pub struct FilterParams {
#[serde(rename = "filter[admin]")] #[serde(rename = "filter[admin]")]
admin: Option<bool>, admin: Option<bool>,
/// Retrieve users with (or without) the `legacy_guest` flag set
#[serde(rename = "filter[legacy-guest]")]
legacy_guest: Option<bool>,
/// Retrieve users where the username matches contains the given string
///
/// Note that this doesn't change the ordering of the result, which are
/// still ordered by ID.
#[serde(rename = "filter[search]")]
search: Option<String>,
/// Retrieve the items with the given status /// Retrieve the items with the given status
/// ///
/// Defaults to retrieve all users, including locked ones. /// Defaults to retrieve all users, including locked ones.
@@ -75,6 +83,14 @@ impl std::fmt::Display for FilterParams {
write!(f, "{sep}filter[admin]={admin}")?; write!(f, "{sep}filter[admin]={admin}")?;
sep = '&'; sep = '&';
} }
if let Some(legacy_guest) = self.legacy_guest {
write!(f, "{sep}filter[legacy-guest]={legacy_guest}")?;
sep = '&';
}
if let Some(search) = &self.search {
write!(f, "{sep}filter[search]={search}")?;
sep = '&';
}
if let Some(status) = self.status { if let Some(status) = self.status {
write!(f, "{sep}filter[status]={status}")?; write!(f, "{sep}filter[status]={status}")?;
sep = '&'; sep = '&';
@@ -118,23 +134,35 @@ pub fn doc(operation: TransformOperation) -> TransformOperation {
let users = User::samples(); let users = User::samples();
let pagination = mas_storage::Pagination::first(users.len()); let pagination = mas_storage::Pagination::first(users.len());
let page = Page { let page = Page {
edges: users.into(), edges: users
.into_iter()
.map(|node| mas_storage::pagination::Edge {
cursor: node.id(),
node,
})
.collect(),
has_next_page: true, has_next_page: true,
has_previous_page: false, has_previous_page: false,
}; };
t.description("Paginated response of users") t.description("Paginated response of users")
.example(PaginatedResponse::new(page, pagination, 42, User::PATH)) .example(PaginatedResponse::for_page(
page,
pagination,
Some(42),
User::PATH,
))
}) })
} }
#[tracing::instrument(name = "handler.admin.v1.users.list", skip_all)] #[tracing::instrument(name = "handler.admin.v1.users.list", skip_all)]
pub async fn handler( pub async fn handler(
CallContext { mut repo, .. }: CallContext, CallContext { mut repo, .. }: CallContext,
Pagination(pagination): Pagination, Pagination(pagination, include_count): Pagination,
params: FilterParams, params: FilterParams,
) -> Result<Json<PaginatedResponse<User>>, RouteError> { ) -> Result<Json<PaginatedResponse<User>>, RouteError> {
let base = format!("{path}{params}", path = User::PATH); let base = format!("{path}{params}", path = User::PATH);
let base = include_count.add_to_base(&base);
let filter = UserFilter::default(); let filter = UserFilter::default();
let filter = match params.admin { let filter = match params.admin {
@@ -143,6 +171,17 @@ pub async fn handler(
None => filter, None => filter,
}; };
let filter = match params.legacy_guest {
Some(true) => filter.guest_only(),
Some(false) => filter.non_guest_only(),
None => filter,
};
let filter = match params.search.as_deref() {
Some(search) => filter.matching_search(search),
None => filter,
};
let filter = match params.status { let filter = match params.status {
Some(UserStatus::Active) => filter.active_only(), Some(UserStatus::Active) => filter.active_only(),
Some(UserStatus::Locked) => filter.locked_only(), Some(UserStatus::Locked) => filter.locked_only(),
@@ -150,13 +189,243 @@ pub async fn handler(
None => filter, None => filter,
}; };
let page = repo.user().list(filter, pagination).await?; let response = match include_count {
let count = repo.user().count(filter).await?; IncludeCount::True => {
let page = repo.user().list(filter, pagination).await?;
let count = repo.user().count(filter).await?;
PaginatedResponse::for_page(page.map(User::from), pagination, Some(count), &base)
}
IncludeCount::False => {
let page = repo.user().list(filter, pagination).await?;
PaginatedResponse::for_page(page.map(User::from), pagination, None, &base)
}
IncludeCount::Only => {
let count = repo.user().count(filter).await?;
PaginatedResponse::for_count_only(count, &base)
}
};
Ok(Json(PaginatedResponse::new( Ok(Json(response))
page.map(User::from), }
pagination,
count, #[cfg(test)]
&base, mod tests {
))) use hyper::{Request, StatusCode};
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_list_users(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let mut rng = state.rng();
// Provision two users
let mut repo = state.repository().await.unwrap();
repo.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
repo.user()
.add(&mut rng, &state.clock, "bob".to_owned())
.await
.unwrap();
repo.save().await.unwrap();
// Test default behavior (count=true)
let request = Request::get("/api/admin/v1/users").bearer(&token).empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 2
},
"data": [
{
"type": "user",
"id": "01FSHN9AG0AJ6AC5HQ9X6H4RP4",
"attributes": {
"username": "bob",
"created_at": "2022-01-16T14:40:00Z",
"locked_at": null,
"deactivated_at": null,
"admin": false,
"legacy_guest": false
},
"links": {
"self": "/api/admin/v1/users/01FSHN9AG0AJ6AC5HQ9X6H4RP4"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0AJ6AC5HQ9X6H4RP4"
}
}
},
{
"type": "user",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"username": "alice",
"created_at": "2022-01-16T14:40:00Z",
"locked_at": null,
"deactivated_at": null,
"admin": false,
"legacy_guest": false
},
"links": {
"self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
}
],
"links": {
"self": "/api/admin/v1/users?page[first]=10",
"first": "/api/admin/v1/users?page[first]=10",
"last": "/api/admin/v1/users?page[last]=10"
}
}
"#);
// Test count=false
let request = Request::get("/api/admin/v1/users?count=false")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "user",
"id": "01FSHN9AG0AJ6AC5HQ9X6H4RP4",
"attributes": {
"username": "bob",
"created_at": "2022-01-16T14:40:00Z",
"locked_at": null,
"deactivated_at": null,
"admin": false,
"legacy_guest": false
},
"links": {
"self": "/api/admin/v1/users/01FSHN9AG0AJ6AC5HQ9X6H4RP4"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0AJ6AC5HQ9X6H4RP4"
}
}
},
{
"type": "user",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"username": "alice",
"created_at": "2022-01-16T14:40:00Z",
"locked_at": null,
"deactivated_at": null,
"admin": false,
"legacy_guest": false
},
"links": {
"self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
}
],
"links": {
"self": "/api/admin/v1/users?count=false&page[first]=10",
"first": "/api/admin/v1/users?count=false&page[first]=10",
"last": "/api/admin/v1/users?count=false&page[last]=10"
}
}
"#);
// Test count=only
let request = Request::get("/api/admin/v1/users?count=only")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r###"
{
"meta": {
"count": 2
},
"links": {
"self": "/api/admin/v1/users?count=only"
}
}
"###);
// Test count=false with filtering
let request = Request::get("/api/admin/v1/users?count=false&filter[search]=alice")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"data": [
{
"type": "user",
"id": "01FSHN9AG0MZAA6S4AF7CTV32E",
"attributes": {
"username": "alice",
"created_at": "2022-01-16T14:40:00Z",
"locked_at": null,
"deactivated_at": null,
"admin": false,
"legacy_guest": false
},
"links": {
"self": "/api/admin/v1/users/01FSHN9AG0MZAA6S4AF7CTV32E"
},
"meta": {
"page": {
"cursor": "01FSHN9AG0MZAA6S4AF7CTV32E"
}
}
}
],
"links": {
"self": "/api/admin/v1/users?filter[search]=alice&count=false&page[first]=10",
"first": "/api/admin/v1/users?filter[search]=alice&count=false&page[first]=10",
"last": "/api/admin/v1/users?filter[search]=alice&count=false&page[last]=10"
}
}
"#);
// Test count=only with filtering
let request = Request::get("/api/admin/v1/users?count=only&filter[search]=alice")
.bearer(&token)
.empty();
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let body: serde_json::Value = response.json();
insta::assert_json_snapshot!(body, @r#"
{
"meta": {
"count": 1
},
"links": {
"self": "/api/admin/v1/users?filter[search]=alice&count=only"
}
}
"#);
}
} }

View File

@@ -55,16 +55,12 @@ impl IntoResponse for RouteError {
} }
} }
fn password_example() -> String {
"hunter2".to_owned()
}
/// # JSON payload for the `POST /api/admin/v1/users/:id/set-password` endpoint /// # JSON payload for the `POST /api/admin/v1/users/:id/set-password` endpoint
#[derive(Deserialize, JsonSchema)] #[derive(Deserialize, JsonSchema)]
#[schemars(rename = "SetUserPasswordRequest")] #[schemars(rename = "SetUserPasswordRequest")]
pub struct Request { pub struct Request {
/// The password to set for the user /// The password to set for the user
#[schemars(example = "password_example")] #[schemars(example = &"hunter2")]
password: String, password: String,
/// Skip the password complexity check /// Skip the password complexity check

View File

@@ -0,0 +1,62 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details.
use aide::transform::TransformOperation;
use axum::{Json, extract::State};
use mas_data_model::AppVersion;
use schemars::JsonSchema;
use serde::Serialize;
use crate::admin::call_context::CallContext;
#[derive(Serialize, JsonSchema)]
pub struct Version {
/// The semver version of the app
pub version: &'static str,
}
pub fn doc(operation: TransformOperation) -> TransformOperation {
operation
.id("version")
.tag("server")
.summary("Get the version currently running")
.response_with::<200, Json<Version>, _>(|t| t.example(Version { version: "v1.0.0" }))
}
#[tracing::instrument(name = "handler.admin.v1.version", skip_all)]
pub async fn handler(
_: CallContext,
State(AppVersion(version)): State<mas_data_model::AppVersion>,
) -> Json<Version> {
Json(Version { version })
}
#[cfg(test)]
mod tests {
use hyper::{Request, StatusCode};
use insta::assert_json_snapshot;
use sqlx::PgPool;
use crate::test_utils::{RequestBuilderExt, ResponseExt, TestState, setup};
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_add_user(pool: PgPool) {
setup();
let mut state = TestState::from_pool(pool).await.unwrap();
let token = state.token_with_scope("urn:mas:admin").await;
let request = Request::get("/api/admin/v1/version").bearer(&token).empty();
let response = state.request(request).await;
assert_eq!(response.status(), StatusCode::OK);
let body: serde_json::Value = response.json();
assert_json_snapshot!(body, @r#"
{
"version": "v0.0.0-test"
}
"#);
}
}

View File

@@ -59,6 +59,8 @@ impl_from_ref!(Arc<dyn mas_matrix::HomeserverConnection>);
impl_from_ref!(mas_keystore::Keystore); impl_from_ref!(mas_keystore::Keystore);
impl_from_ref!(mas_handlers::passwords::PasswordManager); impl_from_ref!(mas_handlers::passwords::PasswordManager);
impl_from_ref!(Arc<mas_policy::PolicyFactory>); impl_from_ref!(Arc<mas_policy::PolicyFactory>);
impl_from_ref!(mas_data_model::SiteConfig);
impl_from_ref!(mas_data_model::AppVersion);
fn main() -> Result<(), Box<dyn std::error::Error>> { fn main() -> Result<(), Box<dyn std::error::Error>> {
let (mut api, _) = mas_handlers::admin_api_router::<DummyState>(); let (mut api, _) = mas_handlers::admin_api_router::<DummyState>();

View File

@@ -8,9 +8,10 @@ use std::collections::HashMap;
use anyhow::Context; use anyhow::Context;
use axum::{ use axum::{
extract::{Form, Path, Query, State}, extract::{Form, Path, State},
response::{Html, IntoResponse, Redirect, Response}, response::{Html, IntoResponse, Redirect, Response},
}; };
use axum_extra::extract::Query;
use chrono::Duration; use chrono::Duration;
use mas_axum_utils::{ use mas_axum_utils::{
InternalError, InternalError,

View File

@@ -4,10 +4,8 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use axum::{ use axum::{extract::State, response::IntoResponse};
extract::{Query, State}, use axum_extra::extract::Query;
response::IntoResponse,
};
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::{GenericError, InternalError}; use mas_axum_utils::{GenericError, InternalError};
use mas_data_model::{BoxClock, BoxRng}; use mas_data_model::{BoxClock, BoxRng};

View File

@@ -172,7 +172,7 @@ impl BrowserSession {
connection connection
.edges .edges
.extend(page.edges.into_iter().map(|s| match s { .extend(page.edges.into_iter().map(|edge| match edge.node {
mas_storage::app_session::AppSession::Compat(session) => Edge::new( mas_storage::app_session::AppSession::Compat(session) => Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSession, session.id)), OpaqueCursor(NodeCursor(NodeType::CompatSession, session.id)),
AppSession::CompatSession(Box::new(CompatSession::new(*session))), AppSession::CompatSession(Box::new(CompatSession::new(*session))),

View File

@@ -125,10 +125,10 @@ impl User {
page.has_next_page, page.has_next_page,
PreloadedTotalCount(count), PreloadedTotalCount(count),
); );
connection.edges.extend(page.edges.into_iter().map(|u| { connection.edges.extend(page.edges.into_iter().map(|edge| {
Edge::new( Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, u.id)), OpaqueCursor(NodeCursor(NodeType::CompatSsoLogin, edge.cursor)),
CompatSsoLogin(u), CompatSsoLogin(edge.node),
) )
})); }));
@@ -219,14 +219,13 @@ impl User {
page.has_next_page, page.has_next_page,
PreloadedTotalCount(count), PreloadedTotalCount(count),
); );
connection connection.edges.extend(page.edges.into_iter().map(|edge| {
.edges let (session, sso_login) = edge.node;
.extend(page.edges.into_iter().map(|(session, sso_login)| { Edge::new(
Edge::new( OpaqueCursor(NodeCursor(NodeType::CompatSession, session.id)),
OpaqueCursor(NodeCursor(NodeType::CompatSession, session.id)), CompatSession::new(session).with_loaded_sso_login(sso_login),
CompatSession::new(session).with_loaded_sso_login(sso_login), )
) }));
}));
Ok::<_, async_graphql::Error>(connection) Ok::<_, async_graphql::Error>(connection)
}, },
@@ -305,10 +304,10 @@ impl User {
page.has_next_page, page.has_next_page,
PreloadedTotalCount(count), PreloadedTotalCount(count),
); );
connection.edges.extend(page.edges.into_iter().map(|u| { connection.edges.extend(page.edges.into_iter().map(|edge| {
Edge::new( Edge::new(
OpaqueCursor(NodeCursor(NodeType::BrowserSession, u.id)), OpaqueCursor(NodeCursor(NodeType::BrowserSession, edge.cursor)),
BrowserSession(u), BrowserSession(edge.node),
) )
})); }));
@@ -373,10 +372,10 @@ impl User {
page.has_next_page, page.has_next_page,
PreloadedTotalCount(count), PreloadedTotalCount(count),
); );
connection.edges.extend(page.edges.into_iter().map(|u| { connection.edges.extend(page.edges.into_iter().map(|edge| {
Edge::new( Edge::new(
OpaqueCursor(NodeCursor(NodeType::UserEmail, u.id)), OpaqueCursor(NodeCursor(NodeType::UserEmail, edge.cursor)),
UserEmail(u), UserEmail(edge.node),
) )
})); }));
@@ -480,10 +479,10 @@ impl User {
PreloadedTotalCount(count), PreloadedTotalCount(count),
); );
connection.edges.extend(page.edges.into_iter().map(|s| { connection.edges.extend(page.edges.into_iter().map(|edge| {
Edge::new( Edge::new(
OpaqueCursor(NodeCursor(NodeType::OAuth2Session, s.id)), OpaqueCursor(NodeCursor(NodeType::OAuth2Session, edge.cursor)),
OAuth2Session(s), OAuth2Session(edge.node),
) )
})); }));
@@ -547,10 +546,10 @@ impl User {
page.has_next_page, page.has_next_page,
PreloadedTotalCount(count), PreloadedTotalCount(count),
); );
connection.edges.extend(page.edges.into_iter().map(|s| { connection.edges.extend(page.edges.into_iter().map(|edge| {
Edge::new( Edge::new(
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, s.id)), OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Link, edge.cursor)),
UpstreamOAuth2Link::new(s), UpstreamOAuth2Link::new(edge.node),
) )
})); }));
@@ -689,13 +688,13 @@ impl User {
connection connection
.edges .edges
.extend(page.edges.into_iter().map(|s| match s { .extend(page.edges.into_iter().map(|edge| match edge.node {
mas_storage::app_session::AppSession::Compat(session) => Edge::new( mas_storage::app_session::AppSession::Compat(session) => Edge::new(
OpaqueCursor(NodeCursor(NodeType::CompatSession, session.id)), OpaqueCursor(NodeCursor(NodeType::CompatSession, edge.cursor)),
AppSession::CompatSession(Box::new(CompatSession::new(*session))), AppSession::CompatSession(Box::new(CompatSession::new(*session))),
), ),
mas_storage::app_session::AppSession::OAuth2(session) => Edge::new( mas_storage::app_session::AppSession::OAuth2(session) => Edge::new(
OpaqueCursor(NodeCursor(NodeType::OAuth2Session, session.id)), OpaqueCursor(NodeCursor(NodeType::OAuth2Session, edge.cursor)),
AppSession::OAuth2Session(Box::new(OAuth2Session(*session))), AppSession::OAuth2Session(Box::new(OAuth2Session(*session))),
), ),
})); }));

View File

@@ -84,7 +84,7 @@ async fn verify_password_if_needed(
password, password,
user_password.hashed_password, user_password.hashed_password,
) )
.await; .await?;
Ok(res.is_ok()) Ok(res.is_success())
} }

View File

@@ -737,13 +737,14 @@ impl UserMutations {
)); ));
}; };
if let Err(_err) = password_manager if !password_manager
.verify( .verify(
active_password.version, active_password.version,
Zeroizing::new(current_password_attempt), Zeroizing::new(current_password_attempt),
active_password.hashed_password, active_password.hashed_password,
) )
.await .await?
.is_success()
{ {
return Ok(SetPasswordPayload { return Ok(SetPasswordPayload {
status: SetPasswordStatus::WrongPassword, status: SetPasswordStatus::WrongPassword,

View File

@@ -68,7 +68,8 @@ impl SessionQuery {
); );
} }
if let Some((compat_session, sso_login)) = compat_sessions.edges.into_iter().next() { if let Some(edge) = compat_sessions.edges.into_iter().next() {
let (compat_session, sso_login) = edge.node;
repo.cancel().await?; repo.cancel().await?;
return Ok(Some(Session::CompatSession(Box::new( return Ok(Some(Session::CompatSession(Box::new(
@@ -92,10 +93,10 @@ impl SessionQuery {
); );
} }
if let Some(session) = sessions.edges.into_iter().next() { if let Some(edge) = sessions.edges.into_iter().next() {
repo.cancel().await?; repo.cancel().await?;
return Ok(Some(Session::OAuth2Session(Box::new(OAuth2Session( return Ok(Some(Session::OAuth2Session(Box::new(OAuth2Session(
session, edge.node,
))))); )))));
} }
repo.cancel().await?; repo.cancel().await?;

View File

@@ -130,10 +130,10 @@ impl UpstreamOAuthQuery {
page.has_next_page, page.has_next_page,
PreloadedTotalCount(count), PreloadedTotalCount(count),
); );
connection.edges.extend(page.edges.into_iter().map(|p| { connection.edges.extend(page.edges.into_iter().map(|edge| {
Edge::new( Edge::new(
OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, p.id)), OpaqueCursor(NodeCursor(NodeType::UpstreamOAuth2Provider, edge.cursor)),
UpstreamOAuth2Provider::new(p), UpstreamOAuth2Provider::new(edge.node),
) )
})); }));

View File

@@ -143,11 +143,12 @@ impl UserQuery {
page.has_next_page, page.has_next_page,
PreloadedTotalCount(count), PreloadedTotalCount(count),
); );
connection.edges.extend( connection.edges.extend(page.edges.into_iter().map(|edge| {
page.edges.into_iter().map(|p| { Edge::new(
Edge::new(OpaqueCursor(NodeCursor(NodeType::User, p.id)), User(p)) OpaqueCursor(NodeCursor(NodeType::User, edge.cursor)),
}), User(edge.node),
); )
}));
Ok::<_, async_graphql::Error>(connection) Ok::<_, async_graphql::Error>(connection)
}, },

View File

@@ -6,6 +6,7 @@
use axum::http::Request; use axum::http::Request;
use hyper::StatusCode; use hyper::StatusCode;
use mas_axum_utils::SessionInfoExt;
use mas_data_model::{AccessToken, Client, TokenType, User}; use mas_data_model::{AccessToken, Client, TokenType, User};
use mas_matrix::{HomeserverConnection, ProvisionRequest}; use mas_matrix::{HomeserverConnection, ProvisionRequest};
use mas_router::SimpleRoute; use mas_router::SimpleRoute;
@@ -19,11 +20,9 @@ use oauth2_types::{
scope::{OPENID, Scope, ScopeToken}, scope::{OPENID, Scope, ScopeToken},
}; };
use sqlx::PgPool; use sqlx::PgPool;
use zeroize::Zeroizing;
use crate::{ use crate::test_utils::{self, CookieHelper, RequestBuilderExt, ResponseExt, TestState, setup};
test_utils,
test_utils::{RequestBuilderExt, ResponseExt, TestState, setup},
};
async fn create_test_client(state: &TestState) -> Client { async fn create_test_client(state: &TestState) -> Client {
let mut repo = state.repository().await.unwrap(); let mut repo = state.repository().await.unwrap();
@@ -781,3 +780,301 @@ async fn test_add_user(pool: PgPool) {
}) })
); );
} }
/// Test the setPassword mutation where the current password provided is
/// wrong.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_set_password_rejected_wrong_password(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let mut rng = state.rng();
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let password = Zeroizing::new("current.password.123".to_owned());
let (version, hashed_password) = state
.password_manager
.hash(&mut rng, password)
.await
.unwrap();
repo.user_password()
.add(
&mut rng,
&state.clock,
&user,
version,
hashed_password,
None,
)
.await
.unwrap();
let browser_session = repo
.browser_session()
.add(&mut rng, &state.clock, &user, None)
.await
.unwrap();
repo.save().await.unwrap();
let cookie_jar = state.cookie_jar();
let cookie_jar = cookie_jar.set_session(&browser_session);
let user_id = user.id;
let request = Request::post("/graphql").json(serde_json::json!({
"query": format!(r#"
mutation {{
setPassword(input: {{
userId: "user:{user_id}",
currentPassword: "wrong.password.123",
newPassword: "new.password.123"
}}) {{
status
}}
}}
"#),
}));
let cookies = CookieHelper::new();
cookies.import(cookie_jar);
let request = cookies.with_cookies(request);
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: GraphQLResponse = response.json();
assert!(response.errors.is_empty(), "{:?}", response.errors);
assert_eq!(
response.data["setPassword"]["status"].as_str(),
Some("WRONG_PASSWORD"),
"{:?}",
response.data
);
}
/// Test the startEmailAuthentication mutation where the current password
/// provided is invalid.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_start_email_authentication_rejected_wrong_password(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let mut rng = state.rng();
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let password = Zeroizing::new("current.password.123".to_owned());
let (version, hashed_password) = state
.password_manager
.hash(&mut rng, password)
.await
.unwrap();
repo.user_password()
.add(
&mut rng,
&state.clock,
&user,
version,
hashed_password,
None,
)
.await
.unwrap();
let browser_session = repo
.browser_session()
.add(&mut rng, &state.clock, &user, None)
.await
.unwrap();
repo.save().await.unwrap();
let cookie_jar = state.cookie_jar();
let cookie_jar = cookie_jar.set_session(&browser_session);
let request = Request::post("/graphql").json(serde_json::json!({
"query": r#"
mutation {
startEmailAuthentication(input: {
email: "alice@example.org",
password: "wrong.password.123"
}) {
status
}
}
"#,
}));
let cookies = CookieHelper::new();
cookies.import(cookie_jar);
let request = cookies.with_cookies(request);
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: GraphQLResponse = response.json();
assert!(response.errors.is_empty(), "{:?}", response.errors);
assert_eq!(
response.data["startEmailAuthentication"]["status"].as_str(),
Some("INCORRECT_PASSWORD"),
"{:?}",
response.data
);
}
/// Test the removeEmail mutation where the current password
/// provided is invalid.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_remove_email_rejected_wrong_password(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let mut rng = state.rng();
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let password = Zeroizing::new("current.password.123".to_owned());
let (version, hashed_password) = state
.password_manager
.hash(&mut rng, password)
.await
.unwrap();
repo.user_password()
.add(
&mut rng,
&state.clock,
&user,
version,
hashed_password,
None,
)
.await
.unwrap();
let user_email_id = repo
.user_email()
.add(
&mut rng,
&state.clock,
&user,
"alice@example.org".to_owned(),
)
.await
.unwrap()
.id;
let browser_session = repo
.browser_session()
.add(&mut rng, &state.clock, &user, None)
.await
.unwrap();
repo.save().await.unwrap();
let cookie_jar = state.cookie_jar();
let cookie_jar = cookie_jar.set_session(&browser_session);
let request = Request::post("/graphql").json(serde_json::json!({
"query": format!(r#"
mutation {{
removeEmail(input: {{
userEmailId: "user_email:{user_email_id}",
password: "wrong.password.123"
}}) {{
status
}}
}}
"#),
}));
let cookies = CookieHelper::new();
cookies.import(cookie_jar);
let request = cookies.with_cookies(request);
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: GraphQLResponse = response.json();
assert!(response.errors.is_empty(), "{:?}", response.errors);
assert_eq!(
response.data["removeEmail"]["status"].as_str(),
Some("INCORRECT_PASSWORD"),
"{:?}",
response.data
);
}
/// Test the deactivateUser mutation where the current password
/// provided is invalid.
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_deactivate_user_rejected_wrong_password(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let mut rng = state.rng();
let mut repo = state.repository().await.unwrap();
let user = repo
.user()
.add(&mut rng, &state.clock, "alice".to_owned())
.await
.unwrap();
let password = Zeroizing::new("current.password.123".to_owned());
let (version, hashed_password) = state
.password_manager
.hash(&mut rng, password)
.await
.unwrap();
repo.user_password()
.add(
&mut rng,
&state.clock,
&user,
version,
hashed_password,
None,
)
.await
.unwrap();
let browser_session = repo
.browser_session()
.add(&mut rng, &state.clock, &user, None)
.await
.unwrap();
repo.save().await.unwrap();
let cookie_jar = state.cookie_jar();
let cookie_jar = cookie_jar.set_session(&browser_session);
let request = Request::post("/graphql").json(serde_json::json!({
"query": r#"
mutation {
deactivateUser(input: {
hsErase: true,
password: "wrong.password.123"
}) {
status
}
}
"#,
}));
let cookies = CookieHelper::new();
cookies.import(cookie_jar);
let request = cookies.with_cookies(request);
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: GraphQLResponse = response.json();
assert!(response.errors.is_empty(), "{:?}", response.errors);
assert_eq!(
response.data["deactivateUser"]["status"].as_str(),
Some("INCORRECT_PASSWORD"),
"{:?}",
response.data
);
}

View File

@@ -5,9 +5,10 @@
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use axum::{ use axum::{
extract::{Query, State}, extract::State,
response::{Html, IntoResponse}, response::{Html, IntoResponse},
}; };
use axum_extra::extract::Query;
use mas_axum_utils::{InternalError, cookies::CookieJar}; use mas_axum_utils::{InternalError, cookies::CookieJar};
use mas_data_model::BoxClock; use mas_data_model::BoxClock;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;

View File

@@ -15,7 +15,9 @@ use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError}, client_authorization::{ClientAuthorization, CredentialsVerificationError},
record_error, record_error,
}; };
use mas_data_model::{BoxClock, Clock, Device, TokenFormatError, TokenType}; use mas_data_model::{
BoxClock, Clock, Device, TokenFormatError, TokenType, personal::session::PersonalSessionOwner,
};
use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint}; use mas_iana::oauth::{OAuthClientAuthenticationMethod, OAuthTokenTypeHint};
use mas_keystore::Encrypter; use mas_keystore::Encrypter;
use mas_matrix::HomeserverConnection; use mas_matrix::HomeserverConnection;
@@ -93,6 +95,14 @@ pub enum RouteError {
#[error("unknown compat session {0}")] #[error("unknown compat session {0}")]
CantLoadCompatSession(Ulid), CantLoadCompatSession(Ulid),
/// The personal access token session is not valid.
#[error("invalid personal access token session {0}")]
InvalidPersonalSession(Ulid),
/// The personal access token session could not be found in the database.
#[error("unknown personal access token session {0}")]
CantLoadPersonalSession(Ulid),
/// The Device ID in the compat session can't be encoded as a scope /// The Device ID in the compat session can't be encoded as a scope
#[error("device ID contains characters that are not allowed in a scope")] #[error("device ID contains characters that are not allowed in a scope")]
CantEncodeDeviceID(#[from] mas_data_model::ToScopeTokenError), CantEncodeDeviceID(#[from] mas_data_model::ToScopeTokenError),
@@ -103,6 +113,9 @@ pub enum RouteError {
#[error("unknown user {0}")] #[error("unknown user {0}")]
CantLoadUser(Ulid), CantLoadUser(Ulid),
#[error("unknown OAuth2 client {0}")]
CantLoadOAuth2Client(Ulid),
#[error("bad request")] #[error("bad request")]
BadRequest, BadRequest,
@@ -131,7 +144,9 @@ impl IntoResponse for RouteError {
e @ (Self::Internal(_) e @ (Self::Internal(_)
| Self::CantLoadCompatSession(_) | Self::CantLoadCompatSession(_)
| Self::CantLoadOAuthSession(_) | Self::CantLoadOAuthSession(_)
| Self::CantLoadPersonalSession(_)
| Self::CantLoadUser(_) | Self::CantLoadUser(_)
| Self::CantLoadOAuth2Client(_)
| Self::FailedToVerifyToken(_)) => ( | Self::FailedToVerifyToken(_)) => (
StatusCode::INTERNAL_SERVER_ERROR, StatusCode::INTERNAL_SERVER_ERROR,
Json( Json(
@@ -167,6 +182,7 @@ impl IntoResponse for RouteError {
| Self::InvalidUser(_) | Self::InvalidUser(_)
| Self::InvalidCompatSession(_) | Self::InvalidCompatSession(_)
| Self::InvalidOAuthSession(_) | Self::InvalidOAuthSession(_)
| Self::InvalidPersonalSession(_)
| Self::InvalidTokenFormat(_) | Self::InvalidTokenFormat(_)
| Self::CantEncodeDeviceID(_) => { | Self::CantEncodeDeviceID(_) => {
INTROSPECTION_COUNTER.add(1, &[KeyValue::new(ACTIVE.clone(), false)]); INTROSPECTION_COUNTER.add(1, &[KeyValue::new(ACTIVE.clone(), false)]);
@@ -625,6 +641,97 @@ pub(crate) async fn post(
device_id: session.device.map(Device::into), device_id: session.device.map(Device::into),
} }
} }
TokenType::PersonalAccessToken => {
let access_token = repo
.personal_access_token()
.find_by_token(token)
.await?
.ok_or(RouteError::UnknownToken(TokenType::AccessToken))?;
if !access_token.is_valid(clock.now()) {
return Err(RouteError::InvalidToken(TokenType::AccessToken));
}
let session = repo
.personal_session()
.lookup(access_token.session_id)
.await?
.ok_or(RouteError::CantLoadPersonalSession(access_token.session_id))?;
if !session.is_valid() {
return Err(RouteError::InvalidPersonalSession(session.id));
}
let actor_user = repo
.user()
.lookup(session.actor_user_id)
.await?
.ok_or(RouteError::CantLoadUser(session.actor_user_id))?;
if !actor_user.is_valid() {
return Err(RouteError::InvalidUser(actor_user.id));
}
let client_id = match session.owner {
PersonalSessionOwner::User(owner_user_id) => {
let owner_user = repo
.user()
.lookup(owner_user_id)
.await?
.ok_or(RouteError::CantLoadUser(owner_user_id))?;
if !owner_user.is_valid() {
return Err(RouteError::InvalidUser(owner_user.id));
}
None
}
PersonalSessionOwner::OAuth2Client(owner_client_id) => {
let owner_client = repo
.oauth2_client()
.lookup(owner_client_id)
.await?
.ok_or(RouteError::CantLoadOAuth2Client(owner_client_id))?;
// OAuth2 clients are always valid if they're in the database
Some(owner_client.client_id.clone())
}
};
activity_tracker
.record_personal_session(&clock, &session, ip)
.await;
INTROSPECTION_COUNTER.add(
1,
&[
KeyValue::new(KIND, "personal_access_token"),
KeyValue::new(ACTIVE, true),
],
);
let scope = normalize_scope(session.scope);
IntrospectionResponse {
active: true,
scope: Some(scope),
client_id,
username: Some(actor_user.username),
token_type: Some(OAuthTokenTypeHint::AccessToken),
exp: access_token.expires_at,
expires_in: access_token
.expires_at
.map(|expires_at| expires_at.signed_duration_since(clock.now())),
iat: Some(access_token.created_at),
nbf: Some(access_token.created_at),
sub: Some(actor_user.sub),
aud: None,
iss: None,
jti: None,
device_id: None,
}
}
}; };
repo.save().await?; repo.save().await?;
@@ -636,7 +743,9 @@ pub(crate) async fn post(
mod tests { mod tests {
use chrono::Duration; use chrono::Duration;
use hyper::{Request, StatusCode}; use hyper::{Request, StatusCode};
use mas_data_model::{AccessToken, Clock, RefreshToken}; use mas_data_model::{
AccessToken, Clock, RefreshToken, TokenType, personal::session::PersonalSessionOwner,
};
use mas_iana::oauth::OAuthTokenTypeHint; use mas_iana::oauth::OAuthTokenTypeHint;
use mas_matrix::{HomeserverConnection, MockHomeserverConnection, ProvisionRequest}; use mas_matrix::{HomeserverConnection, MockHomeserverConnection, ProvisionRequest};
use mas_router::{OAuth2Introspection, OAuth2RegistrationEndpoint, SimpleRoute}; use mas_router::{OAuth2Introspection, OAuth2RegistrationEndpoint, SimpleRoute};
@@ -1069,4 +1178,125 @@ mod tests {
let response: ClientError = response.json(); let response: ClientError = response.json();
assert_eq!(response.error, ClientErrorCode::AccessDenied); assert_eq!(response.error, ClientErrorCode::AccessDenied);
} }
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_introspect_personal_access_tokens(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
// Provision a client which will be used to do introspection requests
let request = Request::post(OAuth2RegistrationEndpoint::PATH).json(json!({
"client_uri": "https://introspecting.com/",
"grant_types": [],
"token_endpoint_auth_method": "client_secret_basic",
}));
let response = state.request(request).await;
response.assert_status(StatusCode::CREATED);
let client: ClientRegistrationResponse = response.json();
let introspecting_client_id = client.client_id;
let introspecting_client_secret = client.client_secret.unwrap();
let mut repo = state.repository().await.unwrap();
// Provision an owner user (who provisions the personal session)
let owner_user = repo
.user()
.add(&mut state.rng(), &state.clock, "admin".to_owned())
.await
.unwrap();
// Provision an actor user (which the token represents)
let actor_user = repo
.user()
.add(&mut state.rng(), &state.clock, "bruce".to_owned())
.await
.unwrap();
// admin creates a personal session to control bruce's account
let personal_session = repo
.personal_session()
.add(
&mut state.rng(),
&state.clock,
PersonalSessionOwner::User(owner_user.id),
&actor_user,
"Test Personal Access Token".to_owned(),
Scope::from_iter([OPENID]),
)
.await
.unwrap();
// Generate a personal access token with proper token format
let token_string = TokenType::PersonalAccessToken.generate(&mut state.rng());
let _personal_access_token = repo
.personal_access_token()
.add(
&mut state.rng(),
&state.clock,
&personal_session,
&token_string,
Some(Duration::try_hours(1).unwrap()),
)
.await
.unwrap();
repo.save().await.unwrap();
// Now that we have a personal access token, we can introspect it
let request = Request::post(OAuth2Introspection::PATH)
.basic_auth(&introspecting_client_id, &introspecting_client_secret)
.form(json!({ "token": token_string }));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: IntrospectionResponse = response.json();
assert!(response.active);
// Actor user
assert_eq!(response.username, Some("bruce".to_owned()));
// Not owned by a client
assert_eq!(response.client_id, None);
assert_eq!(response.token_type, Some(OAuthTokenTypeHint::AccessToken));
assert_eq!(response.scope, Some(Scope::from_iter([OPENID])));
// Do the same request, but with a token_type_hint
let last_active = state.clock.now();
let request = Request::post(OAuth2Introspection::PATH)
.basic_auth(&introspecting_client_id, &introspecting_client_secret)
.form(json!({"token": token_string, "token_type_hint": "access_token"}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: IntrospectionResponse = response.json();
assert!(response.active);
// Do the same request, but with the wrong token_type_hint
let request = Request::post(OAuth2Introspection::PATH)
.basic_auth(&introspecting_client_id, &introspecting_client_secret)
.form(json!({"token": token_string, "token_type_hint": "refresh_token"}));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: IntrospectionResponse = response.json();
assert!(!response.active); // It shouldn't be active with wrong hint
// Advance the clock to invalidate the access token
state.clock.advance(Duration::try_hours(2).unwrap());
let request = Request::post(OAuth2Introspection::PATH)
.basic_auth(&introspecting_client_id, &introspecting_client_secret)
.form(json!({ "token": token_string }));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: IntrospectionResponse = response.json();
assert!(!response.active); // It shouldn't be active anymore
state.activity_tracker.flush().await;
let mut repo = state.repository().await.unwrap();
let session = repo
.personal_session()
.lookup(personal_session.id)
.await
.unwrap()
.unwrap();
assert_eq!(session.last_active_at, Some(last_active));
repo.save().await.unwrap();
}
} }

View File

@@ -4,12 +4,8 @@
// SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial // SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
// Please see LICENSE files in the repository root for full details. // Please see LICENSE files in the repository root for full details.
use axum::{ use axum::{Json, extract::State, response::IntoResponse};
Json, use axum_extra::{extract::Query, typed_header::TypedHeader};
extract::{Query, State},
response::IntoResponse,
};
use axum_extra::typed_header::TypedHeader;
use headers::ContentType; use headers::ContentType;
use mas_router::UrlBuilder; use mas_router::UrlBuilder;
use oauth2_types::webfinger::WebFingerResponse; use oauth2_types::webfinger::WebFingerResponse;

View File

@@ -49,6 +49,11 @@ impl<T> PasswordVerificationResult<T> {
Self::Failure => PasswordVerificationResult::Failure, Self::Failure => PasswordVerificationResult::Failure,
} }
} }
#[must_use]
pub fn is_success(&self) -> bool {
matches!(self, Self::Success(_))
}
} }
impl From<bool> for PasswordVerificationResult<()> { impl From<bool> for PasswordVerificationResult<()> {

Some files were not shown because too many files have changed in this diff Show More