Merge branch 'main' into quenting/dynamic-policy-data

This commit is contained in:
Quentin Gliech
2025-03-14 10:16:16 +01:00
committed by GitHub
142 changed files with 4705 additions and 1869 deletions

View File

@@ -94,6 +94,7 @@ updates:
tanstack-router:
patterns:
- "@tanstack/react-router"
- "@tanstack/react-router-*"
- "@tanstack/router-*"
tanstack-query:
patterns:

View File

@@ -117,7 +117,7 @@ jobs:
${{ matrix.target }}
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.7
uses: mozilla-actions/sccache-action@v0.0.8
- name: Install zig
uses: goto-bus-stop/setup-zig@v2

View File

@@ -153,7 +153,7 @@ jobs:
uses: actions/checkout@v4.2.2
- name: Run `cargo-deny`
uses: EmbarkStudios/cargo-deny-action@v2.0.6
uses: EmbarkStudios/cargo-deny-action@v2.0.11
with:
rust-version: stable
@@ -174,7 +174,7 @@ jobs:
rustup default stable
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.7
uses: mozilla-actions/sccache-action@v0.0.8
- uses: ./.github/actions/build-frontend
@@ -217,7 +217,7 @@ jobs:
- uses: ./.github/actions/build-policies
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.7
uses: mozilla-actions/sccache-action@v0.0.8
- name: Run clippy
run: |
@@ -243,7 +243,7 @@ jobs:
tool: cargo-nextest
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.7
uses: mozilla-actions/sccache-action@v0.0.8
- name: Build and archive tests
run: cargo nextest archive --workspace --archive-file nextest-archive.tar.zst

View File

@@ -102,7 +102,7 @@ jobs:
components: llvm-tools-preview
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.7
uses: mozilla-actions/sccache-action@v0.0.8
- name: Install grcov
uses: taiki-e/install-action@v2

View File

@@ -26,7 +26,7 @@ jobs:
uses: dtolnay/rust-toolchain@stable
- name: Setup sccache
uses: mozilla-actions/sccache-action@v0.0.7
uses: mozilla-actions/sccache-action@v0.0.8
- name: Install mdbook
uses: taiki-e/install-action@v2

View File

@@ -37,7 +37,7 @@ jobs:
- name: Create Pull Request
id: cpr
uses: peter-evans/create-pull-request@v7.0.7
uses: peter-evans/create-pull-request@v7.0.8
with:
sign-commits: true
token: ${{ secrets.BOT_GITHUB_TOKEN }}

319
Cargo.lock generated
View File

@@ -95,11 +95,11 @@ dependencies = [
"bytes",
"cfg-if",
"http",
"indexmap 2.7.1",
"indexmap 2.8.0",
"schemars",
"serde",
"serde_json",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tower-layer",
"tower-service",
"tracing",
@@ -317,7 +317,7 @@ dependencies = [
"futures-timer",
"futures-util",
"http",
"indexmap 2.7.1",
"indexmap 2.8.0",
"mime",
"multer",
"num-traits",
@@ -369,7 +369,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "741110dda927420a28fbc1c310543d3416f789a6ba96859c2c265843a0a96887"
dependencies = [
"bytes",
"indexmap 2.7.1",
"indexmap 2.8.0",
"serde",
"serde_json",
]
@@ -742,9 +742,9 @@ checksum = "349f9b6a179ed607305526ca489b34ad0a41aed5f7980fa90eb03160b69598fb"
[[package]]
name = "bitflags"
version = "2.8.0"
version = "2.9.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f68f53c83ab957f72c32642f3868eec03eb974d1fb82e453128456482613d36"
checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd"
dependencies = [
"serde",
]
@@ -822,9 +822,9 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.10.0"
version = "1.10.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f61dac84819c6588b558454b194026eb1f09c293b9036ae9b159e74e73ab6cf9"
checksum = "d71b6127be86fdcfddb610f7182ac57211d4b18a3e9c82eb2d17662f2227ad6a"
dependencies = [
"serde",
]
@@ -930,9 +930,9 @@ dependencies = [
[[package]]
name = "chrono"
version = "0.4.39"
version = "0.4.40"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7e36cc9d416881d2e24f9a963be5fb1cd90966419ac844274161d10488b3e825"
checksum = "1a7964611d71df112cb1730f2ee67324fcf4d0fc6606acbbe9bfe06df124637c"
dependencies = [
"android-tzdata",
"iana-time-zone",
@@ -940,7 +940,7 @@ dependencies = [
"num-traits",
"serde",
"wasm-bindgen",
"windows-targets 0.52.6",
"windows-link",
]
[[package]]
@@ -1101,9 +1101,9 @@ dependencies = [
[[package]]
name = "console"
version = "0.15.10"
version = "0.15.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea3c6ecd8059b57859df5c69830340ed3c41d30e3da0c1cbed90a96ac853041b"
checksum = "054ccb5b10f9f2cbf51eb355ca1d05c2d279ce1804688d0db74b4733a5aeafd8"
dependencies = [
"encode_unicode",
"libc",
@@ -1120,9 +1120,9 @@ checksum = "c2459377285ad874054d797f3ccebf984978aa39129f6eafde5cdc8315b612f8"
[[package]]
name = "convert_case"
version = "0.7.1"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bb402b8d4c85569410425650ce3eddc7d698ed96d39a73f941b08fb63082f1e7"
checksum = "baaaa0ecca5b51987b9423ccdc971514dd8b0bb7b4060b983d3664dad3f1f89f"
dependencies = [
"unicode-segmentation",
]
@@ -1663,7 +1663,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "64ad6b66883f70e2f38f1ee99e3797b9d7e7b7fb051ed2e23e027c81753056c8"
dependencies = [
"rust_decimal",
"thiserror 2.0.11",
"thiserror 2.0.12",
"winnow",
]
@@ -1719,9 +1719,9 @@ dependencies = [
[[package]]
name = "email-encoding"
version = "0.3.1"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ea3d894bbbab314476b265f9b2d46bf24b123a36dd0e96b06a1b49545b9d9dcc"
checksum = "20b9cde6a71f9f758440470f3de16db6c09a02c443ce66850d87f5410548fb8e"
dependencies = [
"base64 0.22.1",
"memchr",
@@ -2081,7 +2081,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "07e28edb80900c19c28f1072f2e8aeca7fa06b23cd4169cefe1af5aa3260783f"
dependencies = [
"fallible-iterator",
"indexmap 2.7.1",
"indexmap 2.8.0",
"stable_deref_trait",
]
@@ -2147,7 +2147,7 @@ dependencies = [
"futures-core",
"futures-sink",
"http",
"indexmap 2.7.1",
"indexmap 2.8.0",
"slab",
"tokio",
"tokio-util",
@@ -2285,9 +2285,9 @@ dependencies = [
[[package]]
name = "http"
version = "1.2.0"
version = "1.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f16ca2af56261c99fba8bac40a10251ce8188205a4c448fbb745a2e4daa76fea"
checksum = "f4a85d31aea989eead29a3aaf9e1115a180df8282431156e533de47660892565"
dependencies = [
"bytes",
"fnv",
@@ -2306,12 +2306,12 @@ dependencies = [
[[package]]
name = "http-body-util"
version = "0.1.2"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "793429d76616a256bcb62c2a2ec2bed781c8307e797e2598c50010f2bee2544f"
checksum = "b021d93e26becf5dc7e1b75b1bed1fd93124b374ceb73f43d4d4eafec896a64a"
dependencies = [
"bytes",
"futures-util",
"futures-core",
"http",
"http-body",
"pin-project-lite",
@@ -2754,9 +2754,9 @@ dependencies = [
[[package]]
name = "indexmap"
version = "2.7.1"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8c9c992b02b5b4c94ea26e32fe5bccb7aa7d9f390ab5c1221ff895bc7ea8b652"
checksum = "3954d50fe15b02142bf25d3b8bdadb634ec3948f103d04ffe3031bc8fe9d7058"
dependencies = [
"equivalent",
"hashbrown 0.15.2",
@@ -2992,9 +2992,9 @@ checksum = "884e2677b40cc8c339eaefcb701c32ef1fd2493d71118dc0ca4b6a736c93bd67"
[[package]]
name = "lettre"
version = "0.11.14"
version = "0.11.15"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5d476fe7a4a798f392ce34947aa7d53d981127e37523c5251da3c927f7fa901f"
checksum = "759bc2b8eabb6a30b235d6f716f7f36479f4b38cbe65b8747aefee51f89e8437"
dependencies = [
"async-std",
"async-trait",
@@ -3142,7 +3142,7 @@ dependencies = [
"serde",
"serde_json",
"serde_with",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tokio",
"tracing",
"ulid",
@@ -3265,7 +3265,7 @@ dependencies = [
"ruma-common",
"serde",
"serde_json",
"thiserror 2.0.11",
"thiserror 2.0.12",
"ulid",
"url",
"woothee",
@@ -3278,7 +3278,7 @@ dependencies = [
"async-trait",
"lettre",
"mas-templates",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tracing",
]
@@ -3304,7 +3304,7 @@ dependencies = [
"governor",
"headers",
"hyper",
"indexmap 2.7.1",
"indexmap 2.8.0",
"insta",
"lettre",
"mas-axum-utils",
@@ -3342,7 +3342,7 @@ dependencies = [
"serde_urlencoded",
"serde_with",
"sqlx",
"thiserror 2.0.11",
"thiserror 2.0.12",
"time",
"tokio",
"tokio-util",
@@ -3396,7 +3396,7 @@ dependencies = [
"pest_derive",
"serde",
"serde_json",
"thiserror 2.0.11",
"thiserror 2.0.12",
"writeable",
]
@@ -3464,7 +3464,7 @@ dependencies = [
"serde_with",
"sha2",
"signature",
"thiserror 2.0.11",
"thiserror 2.0.12",
"url",
]
@@ -3493,7 +3493,7 @@ dependencies = [
"rsa",
"sec1",
"spki",
"thiserror 2.0.11",
"thiserror 2.0.12",
]
[[package]]
@@ -3509,7 +3509,7 @@ dependencies = [
"pin-project-lite",
"rustls-pemfile",
"socket2",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tokio",
"tokio-rustls",
"tokio-test",
@@ -3541,7 +3541,7 @@ dependencies = [
"mas-matrix",
"reqwest",
"serde",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tracing",
"url",
"urlencoding",
@@ -3576,7 +3576,7 @@ dependencies = [
"serde",
"serde_json",
"serde_urlencoded",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tokio",
"tracing",
"url",
@@ -3595,7 +3595,7 @@ dependencies = [
"schemars",
"serde",
"serde_json",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tokio",
"tracing",
]
@@ -3617,7 +3617,7 @@ version = "0.14.1"
dependencies = [
"camino",
"serde",
"thiserror 2.0.11",
"thiserror 2.0.12",
]
[[package]]
@@ -3635,7 +3635,7 @@ dependencies = [
"rand_core",
"serde",
"serde_json",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tracing",
"tracing-opentelemetry",
"ulid",
@@ -3661,7 +3661,7 @@ dependencies = [
"sea-query-binder",
"serde_json",
"sqlx",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tracing",
"ulid",
"url",
@@ -3691,7 +3691,7 @@ dependencies = [
"serde",
"serde_json",
"sqlx",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tokio",
"tokio-util",
"tracing",
@@ -3720,7 +3720,7 @@ dependencies = [
"serde",
"serde_json",
"serde_urlencoded",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tokio",
"tracing",
"ulid",
@@ -3807,9 +3807,9 @@ dependencies = [
[[package]]
name = "minijinja"
version = "2.7.0"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cff7b8df5e85e30b87c2b0b3f58ba3a87b68e133738bf512a7713769326dbca9"
checksum = "6e36f1329330bb1614c94b78632b9ce45dd7d761f3304a1bed07b2990a7c5097"
dependencies = [
"memo-map",
"self_cell",
@@ -3820,9 +3820,9 @@ dependencies = [
[[package]]
name = "minijinja-contrib"
version = "2.7.0"
version = "2.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ac3e47a9006ed0500425a092c9f8b2e56d10f8aeec8ce870c5e8a7c6ef2d7c3"
checksum = "8e807b6b15e36a4c808e92f78c2ac1f6776519a50d9cf6649819c759a8e7133c"
dependencies = [
"minijinja",
"serde",
@@ -4019,7 +4019,7 @@ dependencies = [
"serde_json",
"serde_with",
"sha2",
"thiserror 2.0.11",
"thiserror 2.0.12",
"url",
]
@@ -4031,7 +4031,7 @@ checksum = "62948e14d923ea95ea2c7c86c71013138b66525b86bdc08d2dcc262bdb497b87"
dependencies = [
"crc32fast",
"hashbrown 0.15.2",
"indexmap 2.7.1",
"indexmap 2.8.0",
"memchr",
]
@@ -4069,7 +4069,7 @@ dependencies = [
"sha1",
"sha2",
"sprintf",
"thiserror 2.0.11",
"thiserror 1.0.69",
"tokio",
"tracing",
"urlencoding",
@@ -4099,7 +4099,7 @@ dependencies = [
"futures-sink",
"js-sys",
"pin-project-lite",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tracing",
]
@@ -4141,7 +4141,7 @@ dependencies = [
"opentelemetry-proto",
"opentelemetry_sdk",
"prost",
"thiserror 2.0.11",
"thiserror 2.0.12",
]
[[package]]
@@ -4199,7 +4199,7 @@ dependencies = [
"opentelemetry",
"opentelemetry_sdk",
"serde",
"thiserror 2.0.11",
"thiserror 2.0.12",
]
[[package]]
@@ -4217,7 +4217,7 @@ dependencies = [
"percent-encoding",
"rand",
"serde_json",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tokio",
"tokio-stream",
"tracing",
@@ -4392,7 +4392,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b7cafe60d6cf8e62e1b9b2ea516a089c008945bb5a275416789e7db0bc199dc"
dependencies = [
"memchr",
"thiserror 2.0.11",
"thiserror 2.0.12",
"ucd-trie",
]
@@ -4722,9 +4722,9 @@ checksum = "106dd99e98437432fed6519dedecfade6a06a73bb7b2a1e019fdd2bee5778d94"
[[package]]
name = "psl"
version = "2.1.86"
version = "2.1.93"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "138e02ed846877ce4044391085ca68b470b0d379cd18a9be0666161764d35448"
checksum = "2e3f76f79643b799e2ddc530120ef7046a05e943297ddc46f91c3b659b6e8456"
dependencies = [
"psl-types",
]
@@ -4783,7 +4783,7 @@ dependencies = [
"rustc-hash 2.1.1",
"rustls",
"socket2",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tokio",
"tracing",
]
@@ -4802,7 +4802,7 @@ dependencies = [
"rustls",
"rustls-pki-types",
"slab",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tinyvec",
"tracing",
"web-time",
@@ -4965,9 +4965,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c"
[[package]]
name = "reqwest"
version = "0.12.12"
version = "0.12.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "43e734407157c3c2034e0258f5e4473ddb361b1e85f95a66690d67264d7cd1da"
checksum = "989e327e510263980e231de548a33e63d34962d29ae61b467389a1a09627a254"
dependencies = [
"base64 0.22.1",
"bytes",
@@ -5021,9 +5021,9 @@ dependencies = [
[[package]]
name = "ring"
version = "0.17.11"
version = "0.17.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "da5349ae27d3887ca812fb375b45a4fbb36d8d12d2df394968cd86e35683fe73"
checksum = "70ac5d832aa16abd7d1def883a8545280c20a60f523a370aa3a9617c2b8550ee"
dependencies = [
"cc",
"cfg-if",
@@ -5063,7 +5063,7 @@ dependencies = [
"base64 0.22.1",
"bytes",
"form_urlencoded",
"indexmap 2.7.1",
"indexmap 2.8.0",
"js_int",
"percent-encoding",
"regex",
@@ -5072,7 +5072,7 @@ dependencies = [
"serde",
"serde_html_form",
"serde_json",
"thiserror 2.0.11",
"thiserror 2.0.12",
"time",
"tracing",
"url",
@@ -5087,7 +5087,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6ad674b5e5368c53a2c90fde7dac7e30747004aaf7b1827b72874a25fc06d4d8"
dependencies = [
"js_int",
"thiserror 2.0.11",
"thiserror 2.0.12",
]
[[package]]
@@ -5289,7 +5289,7 @@ dependencies = [
"chrono",
"dyn-clone",
"indexmap 1.9.3",
"indexmap 2.7.1",
"indexmap 2.8.0",
"schemars_derive",
"serde",
"serde_json",
@@ -5532,18 +5532,18 @@ dependencies = [
[[package]]
name = "serde"
version = "1.0.218"
version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e8dfc9d19bdbf6d17e22319da49161d5d0108e4188e8b680aef6299eed22df60"
checksum = "5f0e2c6ed6606019b4e29e69dbaba95b11854410e5347d525002456dbbb786b6"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.218"
version = "1.0.219"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f09503e191f4e797cb8aac08e9a4a4695c5edf6a2e70e376d961ddd5c969f82b"
checksum = "5b0276cf7f2c73365f7157c8123c21cd9a50fbbd844757af28ca1f5925fc2a00"
dependencies = [
"proc-macro2",
"quote",
@@ -5568,7 +5568,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9d2de91cf02bbc07cde38891769ccd5d4f073d22a40683aa4bc7a95781aaa2c4"
dependencies = [
"form_urlencoded",
"indexmap 2.7.1",
"indexmap 2.8.0",
"itoa",
"ryu",
"serde",
@@ -5580,7 +5580,7 @@ version = "1.0.140"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373"
dependencies = [
"indexmap 2.7.1",
"indexmap 2.8.0",
"itoa",
"memchr",
"ryu",
@@ -5628,7 +5628,7 @@ dependencies = [
"chrono",
"hex",
"indexmap 1.9.3",
"indexmap 2.7.1",
"indexmap 2.8.0",
"serde",
"serde_derive",
"serde_json",
@@ -5654,7 +5654,7 @@ version = "0.9.34+deprecated"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6a8b1a1a2ebf674015cc02edccce75287f1a0130d394307b36743c2f5d504b47"
dependencies = [
"indexmap 2.7.1",
"indexmap 2.8.0",
"itoa",
"ryu",
"serde",
@@ -5837,7 +5837,7 @@ dependencies = [
"futures-util",
"hashbrown 0.15.2",
"hashlink",
"indexmap 2.7.1",
"indexmap 2.8.0",
"ipnetwork",
"log",
"memchr",
@@ -5849,7 +5849,7 @@ dependencies = [
"serde_json",
"sha2",
"smallvec",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tokio",
"tokio-stream",
"tracing",
@@ -5935,7 +5935,7 @@ dependencies = [
"smallvec",
"sqlx-core",
"stringprep",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tracing",
"uuid",
"whoami",
@@ -5975,7 +5975,7 @@ dependencies = [
"smallvec",
"sqlx-core",
"stringprep",
"thiserror 2.0.11",
"thiserror 2.0.12",
"tracing",
"uuid",
"whoami",
@@ -6098,6 +6098,7 @@ name = "syn2mas"
version = "0.14.1"
dependencies = [
"anyhow",
"arc-swap",
"bitflags",
"camino",
"chrono",
@@ -6108,12 +6109,17 @@ dependencies = [
"mas-config",
"mas-storage",
"mas-storage-pg",
"opentelemetry",
"opentelemetry-semantic-conventions",
"rand",
"rand_chacha",
"rustc-hash 2.1.1",
"serde",
"sqlx",
"thiserror 2.0.11",
"thiserror 2.0.12",
"thiserror-ext",
"tokio",
"tokio-util",
"tracing",
"ulid",
"uuid",
@@ -6179,11 +6185,11 @@ dependencies = [
[[package]]
name = "thiserror"
version = "2.0.11"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d452f284b73e6d76dd36758a0c8684b1d5be31f92b89d07fd5822175732206fc"
checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708"
dependencies = [
"thiserror-impl 2.0.11",
"thiserror-impl 2.0.12",
]
[[package]]
@@ -6221,9 +6227,9 @@ dependencies = [
[[package]]
name = "thiserror-impl"
version = "2.0.11"
version = "2.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "26afc1baea8a989337eeb52b6e72a039780ce45c3edfcc9c5b9d112feeb173c2"
checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d"
dependencies = [
"proc-macro2",
"quote",
@@ -6242,9 +6248,9 @@ dependencies = [
[[package]]
name = "time"
version = "0.3.37"
version = "0.3.39"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "35e7868883861bd0e56d9ac6efcaaca0d6d5d82a2a7ec8209ff492c07cf37b21"
checksum = "dad298b01a40a23aac4580b67e3dbedb7cc8402f3592d7f49469de2ea4aecdd8"
dependencies = [
"deranged",
"itoa",
@@ -6259,15 +6265,15 @@ dependencies = [
[[package]]
name = "time-core"
version = "0.1.2"
version = "0.1.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3"
checksum = "765c97a5b985b7c11d7bc27fa927dc4fe6af3a6dfb021d28deb60d3bf51e76ef"
[[package]]
name = "time-macros"
version = "0.2.19"
version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2834e6017e3e5e4b9834939793b282bc03b37a3336245fa820e35e233e2a85de"
checksum = "e8093bc3e81c3bc5f7879de09619d06c9a5a5e45ca44dfeeb7225bae38005c5c"
dependencies = [
"num-conv",
"time-core",
@@ -6300,9 +6306,9 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20"
[[package]]
name = "tokio"
version = "1.43.0"
version = "1.44.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d61fa4ffa3de412bfea335c6ecff681de2b609ba3c77ef3e00e521813a9ed9e"
checksum = "9975ea0f48b5aa3972bf2d888c238182458437cc2a19374b81b25cdf1023fb3a"
dependencies = [
"backtrace",
"bytes",
@@ -6329,9 +6335,9 @@ dependencies = [
[[package]]
name = "tokio-rustls"
version = "0.26.1"
version = "0.26.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5f6d0975eaace0cf0fcadee4e4aaa5da15b5c079146f2cffb67c113be122bf37"
checksum = "8e727b36a1a0e8b74c376ac2211e40c2c8af09fb4013c60d910495810f008e9b"
dependencies = [
"rustls",
"tokio",
@@ -6375,9 +6381,9 @@ dependencies = [
[[package]]
name = "tokio-util"
version = "0.7.13"
version = "0.7.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d7fcaa8d55a2bdd6b83ace262b016eca0d79ee02818c5c1bcdf0305114081078"
checksum = "6b9590b93e6fcc1739458317cccd391ad3955e2bde8913edf6f95f9e65a8f034"
dependencies = [
"bytes",
"futures-core",
@@ -6415,7 +6421,7 @@ version = "0.22.22"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5"
dependencies = [
"indexmap 2.7.1",
"indexmap 2.8.0",
"serde",
"serde_spanned",
"toml_datetime",
@@ -6774,9 +6780,9 @@ checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821"
[[package]]
name = "uuid"
version = "1.14.0"
version = "1.15.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "93d59ca99a559661b96bf898d8fce28ed87935fd2bea9f05983c1464dd6c71b1"
checksum = "e0f540e3240398cce6128b64ba83fdbdd86129c16a3aa1a3a252efd66eb3d587"
dependencies = [
"serde",
]
@@ -6969,7 +6975,7 @@ checksum = "04f17a5917c2ddd3819e84c661fae0d6ba29d7b9c1f0e96c708c65a9c4188e11"
dependencies = [
"bitflags",
"hashbrown 0.15.2",
"indexmap 2.7.1",
"indexmap 2.8.0",
"semver",
"serde",
]
@@ -6999,7 +7005,7 @@ dependencies = [
"cc",
"cfg-if",
"hashbrown 0.15.2",
"indexmap 2.7.1",
"indexmap 2.8.0",
"libc",
"log",
"mach2",
@@ -7097,7 +7103,7 @@ dependencies = [
"cranelift-bitset",
"cranelift-entity",
"gimli",
"indexmap 2.7.1",
"indexmap 2.8.0",
"log",
"object",
"postcard",
@@ -7171,7 +7177,7 @@ checksum = "c8a658273786102da083263eaf2deb76ef7176349b47098bfff15a3dd5776ff2"
dependencies = [
"anyhow",
"heck 0.5.0",
"indexmap 2.7.1",
"indexmap 2.8.0",
"wit-parser",
]
@@ -7292,33 +7298,38 @@ dependencies = [
]
[[package]]
name = "windows-registry"
version = "0.2.0"
name = "windows-link"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e400001bb720a623c1c69032f8e3e4cf09984deec740f007dd2b03ec864804b0"
checksum = "6dccfd733ce2b1753b03b6d3c65edf020262ea35e20ccdf3e288043e6dd620e3"
[[package]]
name = "windows-registry"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4286ad90ddb45071efd1a66dfa43eb02dd0dfbae1545ad6cc3c51cf34d7e8ba3"
dependencies = [
"windows-result",
"windows-strings",
"windows-targets 0.52.6",
"windows-targets 0.53.0",
]
[[package]]
name = "windows-result"
version = "0.2.0"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d1043d8214f791817bab27572aaa8af63732e11bf84aa21a45a78d6c317ae0e"
checksum = "06374efe858fab7e4f881500e6e86ec8bc28f9462c47e5a9941a0142ad86b189"
dependencies = [
"windows-targets 0.52.6",
"windows-link",
]
[[package]]
name = "windows-strings"
version = "0.1.0"
version = "0.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4cd9b125c486025df0eabcb585e62173c6c9eddcec5d117d3b6e8c30e2ee4d10"
checksum = "87fa48cc5d406560701792be122a10132491cff9d0aeb23583cc2dcafc847319"
dependencies = [
"windows-result",
"windows-targets 0.52.6",
"windows-link",
]
[[package]]
@@ -7396,13 +7407,29 @@ dependencies = [
"windows_aarch64_gnullvm 0.52.6",
"windows_aarch64_msvc 0.52.6",
"windows_i686_gnu 0.52.6",
"windows_i686_gnullvm",
"windows_i686_gnullvm 0.52.6",
"windows_i686_msvc 0.52.6",
"windows_x86_64_gnu 0.52.6",
"windows_x86_64_gnullvm 0.52.6",
"windows_x86_64_msvc 0.52.6",
]
[[package]]
name = "windows-targets"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1e4c7e8ceaaf9cb7d7507c974735728ab453b67ef8f18febdd7c11fe59dca8b"
dependencies = [
"windows_aarch64_gnullvm 0.53.0",
"windows_aarch64_msvc 0.53.0",
"windows_i686_gnu 0.53.0",
"windows_i686_gnullvm 0.53.0",
"windows_i686_msvc 0.53.0",
"windows_x86_64_gnu 0.53.0",
"windows_x86_64_gnullvm 0.53.0",
"windows_x86_64_msvc 0.53.0",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.42.2"
@@ -7421,6 +7448,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3"
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764"
[[package]]
name = "windows_aarch64_msvc"
version = "0.42.2"
@@ -7439,6 +7472,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469"
[[package]]
name = "windows_aarch64_msvc"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c"
[[package]]
name = "windows_i686_gnu"
version = "0.42.2"
@@ -7457,12 +7496,24 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b"
[[package]]
name = "windows_i686_gnu"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3"
[[package]]
name = "windows_i686_gnullvm"
version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66"
[[package]]
name = "windows_i686_gnullvm"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11"
[[package]]
name = "windows_i686_msvc"
version = "0.42.2"
@@ -7481,6 +7532,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66"
[[package]]
name = "windows_i686_msvc"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d"
[[package]]
name = "windows_x86_64_gnu"
version = "0.42.2"
@@ -7499,6 +7556,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78"
[[package]]
name = "windows_x86_64_gnu"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.42.2"
@@ -7517,6 +7580,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57"
[[package]]
name = "windows_x86_64_msvc"
version = "0.42.2"
@@ -7535,6 +7604,12 @@ version = "0.52.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec"
[[package]]
name = "windows_x86_64_msvc"
version = "0.53.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486"
[[package]]
name = "winnow"
version = "0.6.26"
@@ -7546,9 +7621,9 @@ dependencies = [
[[package]]
name = "wiremock"
version = "0.6.2"
version = "0.6.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7fff469918e7ca034884c7fd8f93fe27bacb7fcb599fd879df6c7b429a29b646"
checksum = "101681b74cd87b5899e87bcf5a64e83334dd313fcd3053ea72e6dba18928e301"
dependencies = [
"assert-json-diff",
"async-trait",
@@ -7576,7 +7651,7 @@ checksum = "e3477d8d0acb530d76beaa8becbdb1e3face08929db275f39934963eb4f716f8"
dependencies = [
"anyhow",
"id-arena",
"indexmap 2.7.1",
"indexmap 2.8.0",
"log",
"semver",
"serde",

View File

@@ -97,11 +97,11 @@ version = "1.6.0"
# Packed bitfields
[workspace.dependencies.bitflags]
version = "2.8.0"
version = "2.9.0"
# Bytes
[workspace.dependencies.bytes]
version = "1.10.0"
version = "1.10.1"
# UTF-8 paths
[workspace.dependencies.camino]
@@ -113,7 +113,7 @@ version = "0.8.1"
# Time utilities
[workspace.dependencies.chrono]
version = "0.4.39"
version = "0.4.40"
default-features = false
features = ["serde", "clock"]
@@ -150,7 +150,7 @@ version = "0.4.0"
# HTTP request/response
[workspace.dependencies.http]
version = "1.2.0"
version = "1.3.1"
# HTTP body trait
[workspace.dependencies.http-body]
@@ -158,7 +158,7 @@ version = "1.0.1"
# http-body utilities
[workspace.dependencies.http-body-util]
version = "0.1.2"
version = "0.1.3"
# HTTP client and server
[workspace.dependencies.hyper]
@@ -191,7 +191,7 @@ features = ["yaml", "json"]
# Email sending
[workspace.dependencies.lettre]
version = "0.11.14"
version = "0.11.15"
default-features = false
features = [
"tokio1-rustls-tls",
@@ -205,12 +205,12 @@ features = [
# Templates
[workspace.dependencies.minijinja]
version = "2.7.0"
version = "2.8.0"
features = ["loader", "json", "speedups", "unstable_machinery"]
# Additional filters for minijinja
[workspace.dependencies.minijinja-contrib]
version = "2.7.0"
version = "2.8.0"
features = ["pycompat"]
# Utilities to deal with non-zero values
@@ -257,7 +257,7 @@ version = "0.6.4"
# High-level HTTP client
[workspace.dependencies.reqwest]
version = "0.12.12"
version = "0.12.14"
default-features = false
features = ["http2", "rustls-tls-manual-roots", "charset", "json", "socks"]
@@ -311,7 +311,7 @@ version = "0.36.0"
# Serialization and deserialization
[workspace.dependencies.serde]
version = "1.0.218"
version = "1.0.219"
features = ["derive"] # Most of the time, if we need serde, we need derive
# JSON serialization and deserialization
@@ -335,14 +335,14 @@ features = [
# Custom error types
[workspace.dependencies.thiserror]
version = "2.0.11"
version = "2.0.12"
[workspace.dependencies.thiserror-ext]
version = "0.2.1"
# Async runtime
[workspace.dependencies.tokio]
version = "1.43.0"
version = "1.44.0"
features = ["full"]
[workspace.dependencies.tokio-stream]
@@ -350,7 +350,7 @@ version = "0.1.17"
# Useful async utilities
[workspace.dependencies.tokio-util]
version = "0.7.13"
version = "0.7.14"
features = ["rt"]
# Tower services
@@ -427,7 +427,7 @@ features = ["serde"]
# HTTP mock server
[workspace.dependencies.wiremock]
version = "0.6.2"
version = "0.6.3"
[profile.release]
codegen-units = 1 # Reduce the number of codegen units to increase optimizations

View File

@@ -5,7 +5,7 @@
// Please see LICENSE in the repository root for full details.
use mas_data_model::BrowserSession;
use mas_storage::{RepositoryAccess, user::BrowserSessionRepository};
use mas_storage::RepositoryAccess;
use serde::{Deserialize, Serialize};
use ulid::Ulid;
@@ -33,13 +33,12 @@ impl SessionInfo {
self
}
/// Load the [`BrowserSession`] from database
/// Load the active [`BrowserSession`] from database
///
/// # Errors
///
/// Returns an error if the session is not found or if the session is not
/// active anymore
pub async fn load_session<E>(
/// Returns an error if the underlying repository fails to load the session.
pub async fn load_active_session<E>(
&self,
repo: &mut impl RepositoryAccess<Error = E>,
) -> Result<Option<BrowserSession>, E> {
@@ -56,6 +55,12 @@ impl SessionInfo {
Ok(maybe_session)
}
/// Get the current session ID, if any
#[must_use]
pub fn current_session_id(&self) -> Option<Ulid> {
self.current
}
}
pub trait SessionInfoExt {

View File

@@ -19,7 +19,7 @@ axum.workspace = true
bytes.workspace = true
camino.workspace = true
clap.workspace = true
console = "0.15.10"
console = "0.15.11"
dialoguer = { version = "0.11.0", default-features = false, features = [
"fuzzy-select",
"password",

View File

@@ -292,7 +292,7 @@ impl Options {
.context("User not found")?;
let device = if let Some(device_id) = device_id {
device_id.try_into()?
device_id.into()
} else {
Device::generate(&mut rng)
};

View File

@@ -1,4 +1,4 @@
use std::{collections::HashMap, process::ExitCode};
use std::{collections::HashMap, process::ExitCode, sync::atomic::Ordering, time::Duration};
use anyhow::Context;
use camino::Utf8PathBuf;
@@ -12,10 +12,12 @@ use mas_storage::SystemClock;
use mas_storage_pg::MIGRATOR;
use rand::thread_rng;
use sqlx::{Connection, Either, PgConnection, postgres::PgConnectOptions, types::Uuid};
use syn2mas::{LockedMasDatabase, MasWriter, SynapseReader, synapse_config};
use tracing::{Instrument, error, info_span, warn};
use syn2mas::{
LockedMasDatabase, MasWriter, Progress, ProgressStage, SynapseReader, synapse_config,
};
use tracing::{Instrument, error, info, info_span, warn};
use crate::util::database_connection_from_config;
use crate::util::{DatabaseConnectOptions, database_connection_from_config_with_options};
/// The exit code used by `syn2mas check` and `syn2mas migrate` when there are
/// errors preventing migration.
@@ -80,6 +82,7 @@ enum Subcommand {
const NUM_WRITER_CONNECTIONS: usize = 8;
impl Options {
#[tracing::instrument("cli.syn2mas.run", skip_all)]
#[allow(clippy::too_many_lines)]
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
warn!(
@@ -113,7 +116,13 @@ impl Options {
let config = DatabaseConfig::extract_or_default(figment)?;
let mut mas_connection = database_connection_from_config(&config).await?;
let mut mas_connection = database_connection_from_config_with_options(
&config,
&DatabaseConnectOptions {
log_slow_statements: false,
},
)
.await?;
MIGRATOR
.run(&mut mas_connection)
@@ -173,14 +182,14 @@ impl Options {
// Display errors and warnings
if !check_errors.is_empty() {
eprintln!("===== Errors =====");
eprintln!("\n\n===== Errors =====");
eprintln!("These issues prevent migrating from Synapse to MAS right now:\n");
for error in &check_errors {
eprintln!("{error}\n");
}
}
if !check_warnings.is_empty() {
eprintln!("===== Warnings =====");
eprintln!("\n\n===== Warnings =====");
eprintln!(
"These potential issues should be considered before migrating from Synapse to MAS right now:\n"
);
@@ -220,10 +229,19 @@ impl Options {
// TODO how should we handle warnings at this stage?
// TODO this dry-run flag should be set to false in real circumstances !!!
let reader = SynapseReader::new(&mut syn_conn, true).await?;
let mut writer_mas_connections = Vec::with_capacity(NUM_WRITER_CONNECTIONS);
for _ in 0..NUM_WRITER_CONNECTIONS {
writer_mas_connections.push(database_connection_from_config(&config).await?);
writer_mas_connections.push(
database_connection_from_config_with_options(
&config,
&DatabaseConnectOptions {
log_slow_statements: false,
},
)
.await?,
);
}
let writer = MasWriter::new(mas_connection, writer_mas_connections).await?;
@@ -232,8 +250,13 @@ impl Options {
#[allow(clippy::disallowed_methods)]
let mut rng = thread_rng();
// TODO progress reporting
let progress = Progress::default();
let occasional_progress_logger_task =
tokio::spawn(occasional_progress_logger(progress.clone()));
let mas_matrix = MatrixConfig::extract(figment)?;
eprintln!("\n\n");
syn2mas::migrate(
reader,
writer,
@@ -241,11 +264,45 @@ impl Options {
&clock,
&mut rng,
provider_id_mappings,
&progress,
)
.await?;
occasional_progress_logger_task.abort();
Ok(ExitCode::SUCCESS)
}
}
}
}
/// Logs progress every 30 seconds, as a lightweight alternative to a progress
/// bar. For most deployments, the migration will not take 30 seconds so this
/// will not be relevant. In other cases, this will give the operator an idea of
/// what's going on.
async fn occasional_progress_logger(progress: Progress) {
loop {
tokio::time::sleep(Duration::from_secs(30)).await;
match &**progress.get_current_stage() {
ProgressStage::SettingUp => {
info!(name: "progress", "still setting up");
}
ProgressStage::MigratingData {
entity,
migrated,
approx_count,
} => {
let migrated = migrated.load(Ordering::Relaxed);
#[allow(clippy::cast_precision_loss)]
let percent = (f64::from(migrated) / *approx_count as f64) * 100.0;
info!(name: "progress", "migrating {entity}: {migrated}/~{approx_count} (~{percent:.1}%)");
}
ProgressStage::RebuildIndex { index_name } => {
info!(name: "progress", "still waiting for rebuild of index {index_name}");
}
ProgressStage::RebuildConstraint { constraint_name } => {
info!(name: "progress", "still waiting for rebuild of constraint {constraint_name}");
}
}
}
}

View File

@@ -165,11 +165,14 @@ pub async fn config_sync(
}
}
for provider in upstream_oauth2_config.providers {
for (index, provider) in upstream_oauth2_config.providers.into_iter().enumerate() {
if !provider.enabled {
continue;
}
// Use the position in the config of the provider as position in the UI
let ui_order = index.try_into().unwrap_or(i32::MAX);
let _span = info_span!("provider", %provider.id).entered();
if existing_enabled_ids.contains(&provider.id) {
info!("Updating provider");
@@ -293,6 +296,7 @@ pub async fn config_sync(
.additional_authorization_parameters
.into_iter()
.collect(),
ui_order,
},
)
.await?;

View File

@@ -210,6 +210,7 @@ pub fn site_config_from_config(
&& account_config.password_change_allowed,
account_recovery_allowed: password_config.enabled()
&& account_config.password_recovery_enabled,
account_deactivation_allowed: account_config.account_deactivation_allowed,
captcha,
minimum_password_complexity: password_config.minimum_complexity(),
session_expiration,
@@ -234,6 +235,7 @@ pub async fn templates_from_config(
fn database_connect_options_from_config(
config: &DatabaseConfig,
opts: &DatabaseConnectOptions,
) -> Result<PgConnectOptions, anyhow::Error> {
let options = if let Some(uri) = config.uri.as_deref() {
uri.parse()
@@ -318,9 +320,11 @@ fn database_connect_options_from_config(
None => options,
};
let options = options
.log_statements(LevelFilter::Debug)
.log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
let mut options = options.log_statements(LevelFilter::Debug);
if opts.log_slow_statements {
options = options.log_slow_statements(LevelFilter::Warn, Duration::from_millis(100));
}
Ok(options)
}
@@ -328,7 +332,7 @@ fn database_connect_options_from_config(
/// Create a database connection pool from the configuration
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
pub async fn database_pool_from_config(config: &DatabaseConfig) -> Result<PgPool, anyhow::Error> {
let options = database_connect_options_from_config(config)?;
let options = database_connect_options_from_config(config, &DatabaseConnectOptions::default())?;
PgPoolOptions::new()
.max_connections(config.max_connections.into())
.min_connections(config.min_connections)
@@ -340,12 +344,37 @@ pub async fn database_pool_from_config(config: &DatabaseConfig) -> Result<PgPool
.context("could not connect to the database")
}
pub struct DatabaseConnectOptions {
pub log_slow_statements: bool,
}
impl Default for DatabaseConnectOptions {
fn default() -> Self {
Self {
log_slow_statements: true,
}
}
}
/// Create a single database connection from the configuration
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
pub async fn database_connection_from_config(
config: &DatabaseConfig,
) -> Result<PgConnection, anyhow::Error> {
database_connect_options_from_config(config)?
database_connect_options_from_config(config, &DatabaseConnectOptions::default())?
.connect()
.await
.context("could not connect to the database")
}
/// Create a single database connection from the configuration,
/// with specific options.
#[tracing::instrument(name = "db.connect", skip_all, err(Debug))]
pub async fn database_connection_from_config_with_options(
config: &DatabaseConfig,
options: &DatabaseConnectOptions,
) -> Result<PgConnection, anyhow::Error> {
database_connect_options_from_config(config, options)?
.connect()
.await
.context("could not connect to the database")

View File

@@ -61,6 +61,11 @@ pub struct AccountConfig {
/// This has no effect if password login is disabled.
#[serde(default = "default_false", skip_serializing_if = "is_default_false")]
pub password_recovery_enabled: bool,
/// Whether users are allowed to delete their own account. Defaults to
/// `true`.
#[serde(default = "default_true", skip_serializing_if = "is_default_true")]
pub account_deactivation_allowed: bool,
}
impl Default for AccountConfig {
@@ -71,6 +76,7 @@ impl Default for AccountConfig {
password_registration_enabled: default_false(),
password_change_allowed: default_true(),
password_recovery_enabled: default_false(),
account_deactivation_allowed: default_true(),
}
}
}
@@ -83,6 +89,7 @@ impl AccountConfig {
&& is_default_true(&self.displayname_change_allowed)
&& is_default_true(&self.password_change_allowed)
&& is_default_false(&self.password_recovery_enabled)
&& is_default_true(&self.account_deactivation_allowed)
}
}

View File

@@ -22,21 +22,22 @@ pub struct Device {
}
#[derive(Debug, Error)]
pub enum InvalidDeviceID {
#[error("Device ID contains invalid characters")]
pub enum ToScopeTokenError {
#[error("Device ID contains characters that can't be encoded in a scope")]
InvalidCharacters,
}
impl Device {
/// Get the corresponding [`ScopeToken`] for that device
#[must_use]
pub fn to_scope_token(&self) -> ScopeToken {
// SAFETY: the inner id should only have valid scope characters
let Ok(scope_token) = format!("{DEVICE_SCOPE_PREFIX}{}", self.id).parse() else {
unreachable!()
};
scope_token
///
/// # Errors
///
/// Returns an error if the device ID contains characters that can't be
/// encoded in a scope
pub fn to_scope_token(&self) -> Result<ScopeToken, ToScopeTokenError> {
format!("{DEVICE_SCOPE_PREFIX}{}", self.id)
.parse()
.map_err(|_| ToScopeTokenError::InvalidCharacters)
}
/// Get the corresponding [`Device`] from a [`ScopeToken`]
@@ -45,8 +46,7 @@ impl Device {
#[must_use]
pub fn from_scope_token(token: &ScopeToken) -> Option<Self> {
let id = token.as_str().strip_prefix(DEVICE_SCOPE_PREFIX)?;
// XXX: we might be silently ignoring errors here, but it's probably fine?
Device::try_from(id.to_owned()).ok()
Some(Device::from(id.to_owned()))
}
/// Generate a random device ID
@@ -62,39 +62,15 @@ impl Device {
}
}
const fn valid_device_chars(c: char) -> bool {
// This matches the regex in the policy
c.is_ascii_alphanumeric()
|| c == '.'
|| c == '_'
|| c == '~'
|| c == '!'
|| c == '$'
|| c == '&'
|| c == '\''
|| c == '('
|| c == ')'
|| c == '*'
|| c == '+'
|| c == ','
|| c == ';'
|| c == '='
|| c == ':'
|| c == '@'
|| c == '/'
|| c == '-'
impl From<String> for Device {
fn from(id: String) -> Self {
Self { id }
}
}
impl TryFrom<String> for Device {
type Error = InvalidDeviceID;
/// Create a [`Device`] out of an ID, validating the ID has the right shape
fn try_from(id: String) -> Result<Self, Self::Error> {
if !id.chars().all(valid_device_chars) {
return Err(InvalidDeviceID::InvalidCharacters);
}
Ok(Self { id })
impl From<Device> for String {
fn from(device: Device) -> Self {
device.id
}
}
@@ -112,8 +88,8 @@ mod test {
#[test]
fn test_device_id_to_from_scope_token() {
let device = Device::try_from("AABBCCDDEE".to_owned()).unwrap();
let scope_token = device.to_scope_token();
let device = Device::from("AABBCCDDEE".to_owned());
let scope_token = device.to_scope_token().unwrap();
assert_eq!(
scope_token.as_str(),
"urn:matrix:org.matrix.msc2967.client:device:AABBCCDDEE"

View File

@@ -12,7 +12,7 @@ mod session;
mod sso_login;
pub use self::{
device::Device,
device::{Device, ToScopeTokenError},
session::{CompatSession, CompatSessionState},
sso_login::{CompatSsoLogin, CompatSsoLoginState},
};

View File

@@ -27,7 +27,7 @@ pub use ulid::Ulid;
pub use self::{
compat::{
CompatAccessToken, CompatRefreshToken, CompatRefreshTokenState, CompatSession,
CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device,
CompatSessionState, CompatSsoLogin, CompatSsoLoginState, Device, ToScopeTokenError,
},
oauth2::{
AuthorizationCode, AuthorizationGrant, AuthorizationGrantStage, Client, DeviceCodeGrant,

View File

@@ -76,6 +76,9 @@ pub struct SiteConfig {
/// Whether users can recover their account via email.
pub account_recovery_allowed: bool,
/// Whether users can delete their own account.
pub account_deactivation_allowed: bool,
/// Captcha configuration
pub captcha: Option<CaptchaConfig>,

View File

@@ -21,14 +21,15 @@ pub struct User {
pub sub: String,
pub created_at: DateTime<Utc>,
pub locked_at: Option<DateTime<Utc>>,
pub deactivated_at: Option<DateTime<Utc>>,
pub can_request_admin: bool,
}
impl User {
/// Returns `true` unless the user is locked.
/// Returns `true` unless the user is locked or deactivated.
#[must_use]
pub fn is_valid(&self) -> bool {
self.locked_at.is_none()
self.locked_at.is_none() && self.deactivated_at.is_none()
}
}
@@ -42,6 +43,7 @@ impl User {
sub: "123-456".to_owned(),
created_at: now,
locked_at: None,
deactivated_at: None,
can_request_admin: false,
}]
}

View File

@@ -73,10 +73,10 @@ camino.workspace = true
chrono.workspace = true
elliptic-curve.workspace = true
governor.workspace = true
indexmap = "2.7.1"
indexmap = "2.8.0"
pkcs8.workspace = true
psl = "2.1.86"
time = "0.3.37"
psl = "2.1.93"
time = "0.3.39"
url.workspace = true
mime = "0.3.17"
minijinja.workspace = true

View File

@@ -229,7 +229,7 @@ impl CompatSession {
Self {
id: Ulid::from_bytes([0x01; 16]),
user_id: Ulid::from_bytes([0x01; 16]),
device_id: Some("AABBCCDDEE".to_owned().try_into().unwrap()),
device_id: Some("AABBCCDDEE".to_owned().into()),
user_session_id: Some(Ulid::from_bytes([0x11; 16])),
redirect_uri: Some("https://example.com/redirect".parse().unwrap()),
created_at: DateTime::default(),
@@ -241,7 +241,7 @@ impl CompatSession {
Self {
id: Ulid::from_bytes([0x02; 16]),
user_id: Ulid::from_bytes([0x01; 16]),
device_id: Some("FFGGHHIIJJ".to_owned().try_into().unwrap()),
device_id: Some("FFGGHHIIJJ".to_owned().into()),
user_session_id: Some(Ulid::from_bytes([0x12; 16])),
redirect_uri: None,
created_at: DateTime::default(),

View File

@@ -43,6 +43,7 @@ mod test_utils {
userinfo_endpoint_override: None,
jwks_uri_override: None,
additional_authorization_parameters: Vec::new(),
ui_order: 0,
}
}
}

View File

@@ -13,7 +13,7 @@ use axum::{
};
use chrono::Duration;
use mas_axum_utils::{
FancyError, SessionInfoExt,
FancyError,
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
};
@@ -28,7 +28,10 @@ use mas_templates::{CompatSsoContext, ErrorContext, TemplateContext, Templates};
use serde::{Deserialize, Serialize};
use ulid::Ulid;
use crate::PreferredLanguage;
use crate::{
PreferredLanguage,
session::{SessionOrFallback, load_session_or_fallback},
};
#[derive(Serialize)]
struct AllParams<'s> {
@@ -61,10 +64,20 @@ pub async fn get(
Path(id): Path<Ulid>,
Query(params): Query<Params>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
let maybe_session = session_info.load_session(&mut repo).await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let Some(session) = maybe_session else {
// If there is no session, redirect to the login or register screen
@@ -126,10 +139,20 @@ pub async fn post(
Query(params): Query<Params>,
Form(form): Form<ProtectedForm<()>>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
cookie_jar.verify_form(&clock, form)?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
let maybe_session = session_info.load_session(&mut repo).await?;
cookie_jar.verify_form(&clock, form)?;
let Some(session) = maybe_session else {
// If there is no session, redirect to the login or register screen

View File

@@ -288,7 +288,7 @@ async fn get_requester(
RequestingEntity::OAuth2Session(Box::new((session, user)))
} else {
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;
if let Some(session) = maybe_session.as_ref() {
activity_tracker

View File

@@ -46,6 +46,9 @@ pub struct SiteConfig {
/// Whether passwords are enabled and users can register using a password.
password_registration_enabled: bool,
/// Whether users can delete their own account.
account_deactivation_allowed: bool,
/// Minimum password complexity, from 0 to 4, in terms of a zxcvbn score.
/// The exact scorer (including dictionaries and other data tables)
/// in use is <https://crates.io/crates/zxcvbn>.
@@ -93,6 +96,7 @@ impl SiteConfig {
password_login_enabled: data_model.password_login_enabled,
password_change_allowed: data_model.password_change_allowed,
password_registration_enabled: data_model.password_registration_enabled,
account_deactivation_allowed: data_model.account_deactivation_allowed,
minimum_password_complexity: data_model.minimum_password_complexity,
}
}

View File

@@ -11,7 +11,14 @@ mod oauth2_session;
mod user;
mod user_email;
use anyhow::Context as _;
use async_graphql::MergedObject;
use mas_data_model::SiteConfig;
use mas_storage::BoxRepository;
use zeroize::Zeroizing;
use super::Requester;
use crate::passwords::PasswordManager;
/// The mutations root of the GraphQL interface.
#[derive(Default, MergedObject)]
@@ -30,3 +37,54 @@ impl Mutation {
Self::default()
}
}
/// Check the password if neeed
///
/// Returns true if password verification is not needed, or if the password is
/// correct. Returns false if the password is incorrect or missing.
async fn verify_password_if_needed(
requester: &Requester,
config: &SiteConfig,
password_manager: &PasswordManager,
password: Option<String>,
user: &mas_data_model::User,
repo: &mut BoxRepository,
) -> Result<bool, async_graphql::Error> {
// If the requester is admin, they don't need to provide a password
if requester.is_admin() {
return Ok(true);
}
// If password login is disabled, assume we don't want the user to reauth
if !config.password_login_enabled {
return Ok(true);
}
// Else we need to check if the user has a password
let Some(user_password) = repo
.user_password()
.active(user)
.await
.context("Failed to load user password")?
else {
// User has no password, so we don't need to verify the password
return Ok(true);
};
let Some(password) = password else {
// There is a password on the user, but not provided in the input
return Ok(false);
};
let password = Zeroizing::new(password.into_bytes());
let res = password_manager
.verify(
user_password.version,
password,
user_password.hashed_password,
)
.await;
Ok(res.is_ok())
}

View File

@@ -18,6 +18,7 @@ use ulid::Ulid;
use url::Url;
use zeroize::Zeroizing;
use super::verify_password_if_needed;
use crate::graphql::{
UserId,
model::{NodeType, User},
@@ -383,6 +384,61 @@ impl ResendRecoveryEmailPayload {
}
}
/// The input for the `deactivateUser` mutation.
#[derive(InputObject)]
pub struct DeactivateUserInput {
/// Whether to ask the homeserver to GDPR-erase the user
///
/// This is equivalent to the `erase` parameter on the
/// `/_matrix/client/v3/account/deactivate` C-S API, which is
/// implementation-specific.
///
/// What Synapse does is documented here:
/// <https://element-hq.github.io/synapse/latest/admin_api/user_admin_api.html#deactivate-account>
hs_erase: bool,
/// The password of the user to deactivate.
password: Option<String>,
}
/// The payload for the `deactivateUser` mutation.
#[derive(Description)]
pub enum DeactivateUserPayload {
/// The user was deactivated.
Deactivated(mas_data_model::User),
/// The password was wrong or missing.
IncorrectPassword,
}
/// The status of the `deactivateUser` mutation.
#[derive(Enum, Copy, Clone, Eq, PartialEq)]
pub enum DeactivateUserStatus {
/// The user was deactivated.
Deactivated,
/// The password was wrong.
IncorrectPassword,
}
#[Object(use_type_description)]
impl DeactivateUserPayload {
/// Status of the operation
async fn status(&self) -> DeactivateUserStatus {
match self {
Self::Deactivated(_) => DeactivateUserStatus::Deactivated,
Self::IncorrectPassword => DeactivateUserStatus::IncorrectPassword,
}
}
async fn user(&self) -> Option<User> {
match self {
Self::Deactivated(user) => Some(User(user.clone())),
Self::IncorrectPassword => None,
}
}
}
fn valid_username_character(c: char) -> bool {
c.is_ascii_lowercase()
|| c.is_ascii_digit()
@@ -868,4 +924,64 @@ impl UserMutations {
recovery_session_id: recovery_session.id,
})
}
/// Deactivate the current user account
///
/// If the user has a password, it *must* be supplied in the `password`
/// field.
async fn deactivate_user(
&self,
ctx: &Context<'_>,
input: DeactivateUserInput,
) -> Result<DeactivateUserPayload, async_graphql::Error> {
let state = ctx.state();
let mut rng = state.rng();
let clock = state.clock();
let requester = ctx.requester();
let site_config = state.site_config();
// Only allow calling this if the requester is a browser session
let Some(browser_session) = requester.browser_session() else {
return Err(async_graphql::Error::new("Unauthorized"));
};
if !site_config.account_deactivation_allowed {
return Err(async_graphql::Error::new(
"Account deactivation is not allowed on this server",
));
}
let mut repo = state.repository().await?;
if !verify_password_if_needed(
requester,
site_config,
&state.password_manager(),
input.password,
&browser_session.user,
&mut repo,
)
.await?
{
return Ok(DeactivateUserPayload::IncorrectPassword);
}
// Deactivate the user right away
let user = repo
.user()
.deactivate(&state.clock(), browser_session.user.clone())
.await?;
// and then schedule a job to deactivate it fully
repo.queue_job()
.schedule_job(
&mut rng,
&clock,
DeactivateUserJob::new(&user, input.hs_erase),
)
.await?;
repo.save().await?;
Ok(DeactivateUserPayload::Deactivated(user))
}
}

View File

@@ -13,6 +13,7 @@ use mas_storage::{
user::{UserEmailFilter, UserEmailRepository, UserRepository},
};
use super::verify_password_if_needed;
use crate::graphql::{
model::{NodeType, User, UserEmail, UserEmailAuthentication},
state::ContextExt,
@@ -120,6 +121,10 @@ impl AddEmailPayload {
struct RemoveEmailInput {
/// The ID of the email address to remove
user_email_id: ID,
/// The user's current password. This is required if the user is not an
/// admin and it has a password on its account.
password: Option<String>,
}
/// The status of the `removeEmail` mutation
@@ -130,6 +135,9 @@ enum RemoveEmailStatus {
/// The email address was not found
NotFound,
/// The password provided is incorrect
IncorrectPassword,
}
/// The payload of the `removeEmail` mutation
@@ -137,6 +145,7 @@ enum RemoveEmailStatus {
enum RemoveEmailPayload {
Removed(mas_data_model::UserEmail),
NotFound,
IncorrectPassword,
}
#[Object(use_type_description)]
@@ -146,6 +155,7 @@ impl RemoveEmailPayload {
match self {
RemoveEmailPayload::Removed(_) => RemoveEmailStatus::Removed,
RemoveEmailPayload::NotFound => RemoveEmailStatus::NotFound,
RemoveEmailPayload::IncorrectPassword => RemoveEmailStatus::IncorrectPassword,
}
}
@@ -153,20 +163,23 @@ impl RemoveEmailPayload {
async fn email(&self) -> Option<UserEmail> {
match self {
RemoveEmailPayload::Removed(email) => Some(UserEmail(email.clone())),
RemoveEmailPayload::NotFound => None,
RemoveEmailPayload::NotFound | RemoveEmailPayload::IncorrectPassword => None,
}
}
/// The user to whom the email address belonged
async fn user(&self, ctx: &Context<'_>) -> Result<Option<User>, async_graphql::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let user_id = match self {
RemoveEmailPayload::Removed(email) => email.user_id,
RemoveEmailPayload::NotFound => return Ok(None),
RemoveEmailPayload::NotFound | RemoveEmailPayload::IncorrectPassword => {
return Ok(None);
}
};
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(user_id)
@@ -226,6 +239,10 @@ struct StartEmailAuthenticationInput {
/// The email address to add to the account
email: String,
/// The user's current password. This is required if the user has a password
/// on its account.
password: Option<String>,
/// The language to use for the email
#[graphql(default = "en")]
language: String,
@@ -244,6 +261,8 @@ enum StartEmailAuthenticationStatus {
Denied,
/// The email address is already in use on this account
InUse,
/// The password provided is incorrect
IncorrectPassword,
}
/// The payload of the `startEmailAuthentication` mutation
@@ -256,6 +275,7 @@ enum StartEmailAuthenticationPayload {
violations: Vec<mas_policy::Violation>,
},
InUse,
IncorrectPassword,
}
#[Object(use_type_description)]
@@ -268,6 +288,7 @@ impl StartEmailAuthenticationPayload {
Self::RateLimited => StartEmailAuthenticationStatus::RateLimited,
Self::Denied { .. } => StartEmailAuthenticationStatus::Denied,
Self::InUse => StartEmailAuthenticationStatus::InUse,
Self::IncorrectPassword => StartEmailAuthenticationStatus::IncorrectPassword,
}
}
@@ -275,9 +296,11 @@ impl StartEmailAuthenticationPayload {
async fn authentication(&self) -> Option<&UserEmailAuthentication> {
match self {
Self::Started(authentication) => Some(authentication),
Self::InvalidEmailAddress | Self::RateLimited | Self::Denied { .. } | Self::InUse => {
None
}
Self::InvalidEmailAddress
| Self::RateLimited
| Self::Denied { .. }
| Self::InUse
| Self::IncorrectPassword => None,
}
}
@@ -494,6 +517,20 @@ impl UserEmailMutations {
.await?
.context("Failed to load user")?;
// Validate the password input if needed
if !verify_password_if_needed(
requester,
state.site_config(),
&state.password_manager(),
input.password,
&user,
&mut repo,
)
.await?
{
return Ok(RemoveEmailPayload::IncorrectPassword);
}
// TODO: don't allow removing the last email address
repo.user_email().remove(user_email.clone()).await?;
@@ -627,6 +664,20 @@ impl UserEmailMutations {
});
}
// Validate the password input if needed
if !verify_password_if_needed(
requester,
state.site_config(),
&state.password_manager(),
input.password,
&browser_session.user,
&mut repo,
)
.await?
{
return Ok(StartEmailAuthenticationPayload::IncorrectPassword);
}
// Create a new authentication session
let authentication = repo
.user_email()

View File

@@ -44,9 +44,7 @@ impl SessionQuery {
return Ok(None);
}
let Ok(device) = Device::try_from(device_id) else {
return Ok(None);
};
let device = Device::from(device_id);
let state = ctx.state();
let mut repo = state.repository().await?;
@@ -81,7 +79,14 @@ impl SessionQuery {
// Then, try to find an OAuth 2.0 session. Because we don't have any dedicated
// device column, we're looking up using the device scope.
let scope = Scope::from_iter([device.to_scope_token()]);
// All device IDs can't necessarily be encoded as a scope. If it's not the case,
// we'll skip looking for OAuth 2.0 sessions.
let Ok(scope_token) = device.to_scope_token() else {
repo.cancel().await?;
return Ok(None);
};
let scope = Scope::from_iter([scope_token]);
let filter = OAuth2SessionFilter::new()
.for_user(&user)
.active_only()

View File

@@ -64,6 +64,7 @@ mod activity_tracker;
mod captcha;
mod preferred_language;
mod rate_limit;
mod session;
#[cfg(test)]
mod test_utils;

View File

@@ -97,7 +97,7 @@ pub(crate) async fn get(
) -> Result<Response, RouteError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;
let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string());

View File

@@ -176,7 +176,7 @@ pub(crate) async fn get(
let callback_destination = callback_destination.clone();
let locale = locale.clone();
async move {
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;
let prompt = params.auth.prompt.as_deref().unwrap_or_default();
// Check if the request/request_uri/registration params are used. If so, reply

View File

@@ -1,4 +1,4 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2024, 2025 New Vector Ltd.
// Copyright 2022-2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
@@ -11,7 +11,6 @@ use axum::{
use axum_extra::TypedHeader;
use hyper::StatusCode;
use mas_axum_utils::{
SessionInfoExt,
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
sentry::SentryEventID,
@@ -27,7 +26,10 @@ use mas_templates::{ConsentContext, PolicyViolationContext, TemplateContext, Tem
use thiserror::Error;
use ulid::Ulid;
use crate::{BoundActivityTracker, PreferredLanguage, impl_from_error_for_route};
use crate::{
BoundActivityTracker, PreferredLanguage, impl_from_error_for_route,
session::{SessionOrFallback, load_session_or_fallback},
};
#[derive(Debug, Error)]
pub enum RouteError {
@@ -54,6 +56,7 @@ impl_from_error_for_route!(mas_templates::TemplateError);
impl_from_error_for_route!(mas_storage::RepositoryError);
impl_from_error_for_route!(mas_policy::LoadError);
impl_from_error_for_route!(mas_policy::EvaluationError);
impl_from_error_for_route!(crate::session::SessionLoadError);
impl IntoResponse for RouteError {
fn into_response(self) -> axum::response::Response {
@@ -85,9 +88,18 @@ pub(crate) async fn get(
cookie_jar: CookieJar,
Path(grant_id): Path<Ulid>,
) -> Result<Response, RouteError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
let user_agent = user_agent.map(|ua| ua.to_string());
@@ -107,48 +119,48 @@ pub(crate) async fn get(
return Err(RouteError::GrantNotPending);
}
if let Some(session) = maybe_session {
activity_tracker
.record_browser_session(&clock, &session)
.await;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&session.user),
client: &client,
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent,
},
})
.await?;
if res.valid() {
let ctx = ConsentContext::new(grant, client)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let content = templates.render_consent(&ctx)?;
Ok((cookie_jar, Html(content)).into_response())
} else {
let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let content = templates.render_policy_violation(&ctx)?;
Ok((cookie_jar, Html(content)).into_response())
}
} else {
let Some(session) = maybe_session else {
let login = mas_router::Login::and_continue_grant(grant_id);
Ok((cookie_jar, url_builder.redirect(&login)).into_response())
return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
};
activity_tracker
.record_browser_session(&clock, &session)
.await;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let res = policy
.evaluate_authorization_grant(mas_policy::AuthorizationGrantInput {
user: Some(&session.user),
client: &client,
scope: &grant.scope,
grant_type: mas_policy::GrantType::AuthorizationCode,
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent,
},
})
.await?;
if res.valid() {
let ctx = ConsentContext::new(grant, client)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let content = templates.render_consent(&ctx)?;
Ok((cookie_jar, Html(content)).into_response())
} else {
let ctx = PolicyViolationContext::for_authorization_grant(grant, client)
.with_session(session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let content = templates.render_policy_violation(&ctx)?;
Ok((cookie_jar, Html(content)).into_response())
}
}
@@ -161,6 +173,8 @@ pub(crate) async fn get(
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
mut policy: Policy,
mut repo: BoxRepository,
activity_tracker: BoundActivityTracker,
@@ -172,9 +186,18 @@ pub(crate) async fn post(
) -> Result<Response, RouteError> {
cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
let user_agent = user_agent.map(|ua| ua.to_string());

View File

@@ -12,7 +12,7 @@ use axum::{
};
use axum_extra::TypedHeader;
use mas_axum_utils::{
FancyError, SessionInfoExt,
FancyError,
cookies::CookieJar,
csrf::{CsrfExt, ProtectedForm},
};
@@ -24,7 +24,10 @@ use serde::Deserialize;
use tracing::warn;
use ulid::Ulid;
use crate::{BoundActivityTracker, PreferredLanguage};
use crate::{
BoundActivityTracker, PreferredLanguage,
session::{SessionOrFallback, load_session_or_fallback},
};
#[derive(Deserialize, Debug)]
#[serde(rename_all = "lowercase")]
@@ -51,10 +54,20 @@ pub(crate) async fn get(
cookie_jar: CookieJar,
Path(grant_id): Path<Ulid>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
let maybe_session = session_info.load_session(&mut repo).await?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let user_agent = user_agent.map(|ua| ua.to_string());
@@ -137,12 +150,21 @@ pub(crate) async fn post(
Path(grant_id): Path<Ulid>,
Form(form): Form<ProtectedForm<ConsentForm>>,
) -> Result<Response, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let form = cookie_jar.verify_form(&clock, form)?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_session = session_info.load_session(&mut repo).await?;
let user_agent = user_agent.map(|TypedHeader(ua)| ua.to_string());
let Some(session) = maybe_session else {

View File

@@ -4,8 +4,8 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use axum::{Json, extract::State, response::IntoResponse};
use hyper::StatusCode;
use axum::{Json, extract::State, http::HeaderValue, response::IntoResponse};
use hyper::{HeaderMap, StatusCode};
use mas_axum_utils::{
client_authorization::{ClientAuthorization, CredentialsVerificationError},
sentry::SentryEventID,
@@ -74,6 +74,10 @@ pub enum RouteError {
#[error("unknown compat session")]
CantLoadCompatSession,
/// The Device ID in the compat session can't be encoded as a scope
#[error("device ID contains characters that are not allowed in a scope")]
CantEncodeDeviceID(#[from] mas_data_model::ToScopeTokenError),
#[error("invalid user")]
InvalidUser,
@@ -120,7 +124,8 @@ impl IntoResponse for RouteError {
| Self::InvalidUser
| Self::InvalidCompatSession
| Self::InvalidOAuthSession
| Self::InvalidTokenFormat(_) => Json(INACTIVE).into_response(),
| Self::InvalidTokenFormat(_)
| Self::CantEncodeDeviceID(_) => Json(INACTIVE).into_response(),
Self::NotAllowed => (
StatusCode::UNAUTHORIZED,
Json(ClientError::from(ClientErrorCode::AccessDenied)),
@@ -152,6 +157,7 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
aud: None,
iss: None,
jti: None,
device_id: None,
};
const API_SCOPE: ScopeToken = ScopeToken::from_static("urn:matrix:org.matrix.msc2967.client:api:*");
@@ -170,6 +176,7 @@ pub(crate) async fn post(
mut repo: BoxRepository,
activity_tracker: ActivityTracker,
State(encrypter): State<Encrypter>,
headers: HeaderMap,
client_authorization: ClientAuthorization<IntrospectionRequest>,
) -> Result<impl IntoResponse, RouteError> {
let client = client_authorization
@@ -202,6 +209,16 @@ pub(crate) async fn post(
}
}
// Not all device IDs can be encoded as scope. On OAuth 2.0 sessions, we
// don't have this problem, as the device ID *is* already encoded as a scope.
// But on compatibility sessions, it's possible to have device IDs with
// spaces in them, or other weird characters.
// In those cases, we prefer explicitly giving out the device ID as a separate
// field. The client introspecting tells us whether it supports having the
// device ID as a separate field through this header.
let supports_explicit_device_id =
headers.get("X-MAS-Supports-Device-Id") == Some(&HeaderValue::from_static("1"));
// XXX: we should get the IP from the client introspecting the token
let ip = None;
@@ -270,6 +287,7 @@ pub(crate) async fn post(
aud: None,
iss: None,
jti: Some(access_token.jti()),
device_id: None,
}
}
@@ -329,6 +347,7 @@ pub(crate) async fn post(
aud: None,
iss: None,
jti: Some(refresh_token.jti()),
device_id: None,
}
}
@@ -365,7 +384,19 @@ pub(crate) async fn post(
// Grant the synapse admin scope if the session has the admin flag set.
let synapse_admin_scope_opt = session.is_synapse_admin.then_some(SYNAPSE_ADMIN_SCOPE);
let device_scope_opt = session.device.as_ref().map(Device::to_scope_token);
// If the client supports explicitly giving the device ID in the response, skip
// encoding it in the scope
let device_scope_opt = if supports_explicit_device_id {
None
} else {
session
.device
.as_ref()
.map(Device::to_scope_token)
.transpose()?
};
let scope = [API_SCOPE]
.into_iter()
.chain(device_scope_opt)
@@ -389,6 +420,7 @@ pub(crate) async fn post(
aud: None,
iss: None,
jti: None,
device_id: session.device.map(Device::into),
}
}
@@ -425,7 +457,19 @@ pub(crate) async fn post(
// Grant the synapse admin scope if the session has the admin flag set.
let synapse_admin_scope_opt = session.is_synapse_admin.then_some(SYNAPSE_ADMIN_SCOPE);
let device_scope_opt = session.device.as_ref().map(Device::to_scope_token);
// If the client supports explicitly giving the device ID in the response, skip
// encoding it in the scope
let device_scope_opt = if supports_explicit_device_id {
None
} else {
session
.device
.as_ref()
.map(Device::to_scope_token)
.transpose()?
};
let scope = [API_SCOPE]
.into_iter()
.chain(device_scope_opt)
@@ -449,6 +493,7 @@ pub(crate) async fn post(
aud: None,
iss: None,
jti: None,
device_id: session.device.map(Device::into),
}
}
};
@@ -777,10 +822,30 @@ mod tests {
response.assert_status(StatusCode::OK);
let response: IntrospectionResponse = response.json();
assert!(response.active);
assert_eq!(response.username, Some("alice".to_owned()));
assert_eq!(response.client_id, Some("legacy".to_owned()));
assert_eq!(response.username.as_deref(), Some("alice"));
assert_eq!(response.client_id.as_deref(), Some("legacy"));
assert_eq!(response.token_type, Some(OAuthTokenTypeHint::AccessToken));
assert_eq!(response.scope, Some(expected_scope.clone()));
assert_eq!(response.scope.as_ref(), Some(&expected_scope));
assert_eq!(response.device_id.as_deref(), Some(device_id));
// Check that requesting with X-MAS-Supports-Device-Id removes the device ID
// from the scope but not from the explicit device_id field
let request = Request::post(OAuth2Introspection::PATH)
.basic_auth(&introspecting_client_id, &introspecting_client_secret)
.header("X-MAS-Supports-Device-Id", "1")
.form(json!({ "token": access_token }));
let response = state.request(request).await;
response.assert_status(StatusCode::OK);
let response: IntrospectionResponse = response.json();
assert!(response.active);
assert_eq!(response.username.as_deref(), Some("alice"));
assert_eq!(response.client_id.as_deref(), Some("legacy"));
assert_eq!(response.token_type, Some(OAuthTokenTypeHint::AccessToken));
assert_eq!(
response.scope.map(|s| s.to_string()),
Some("urn:matrix:org.matrix.msc2967.client:api:*".to_owned())
);
assert_eq!(response.device_id.as_deref(), Some(device_id));
// Do the same request, but with a token_type_hint
let request = Request::post(OAuth2Introspection::PATH)
@@ -808,10 +873,11 @@ mod tests {
response.assert_status(StatusCode::OK);
let response: IntrospectionResponse = response.json();
assert!(response.active);
assert_eq!(response.username, Some("alice".to_owned()));
assert_eq!(response.client_id, Some("legacy".to_owned()));
assert_eq!(response.username.as_deref(), Some("alice"));
assert_eq!(response.client_id.as_deref(), Some("legacy"));
assert_eq!(response.token_type, Some(OAuthTokenTypeHint::RefreshToken));
assert_eq!(response.scope, Some(expected_scope.clone()));
assert_eq!(response.scope.as_ref(), Some(&expected_scope));
assert_eq!(response.device_id.as_deref(), Some(device_id));
// Do the same request, but with a token_type_hint
let request = Request::post(OAuth2Introspection::PATH)

View File

@@ -327,6 +327,7 @@ mod tests {
sub: "123-456".to_owned(),
created_at: now,
locked_at: None,
deactivated_at: None,
can_request_admin: false,
};
@@ -336,6 +337,7 @@ mod tests {
sub: "123-456".to_owned(),
created_at: now,
locked_at: None,
deactivated_at: None,
can_request_admin: false,
};

View File

@@ -0,0 +1,104 @@
// Copyright 2025 New Vector Ltd.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
//! Utilities for showing proposer HTML fallbacks when the user is logged out,
//! locked or deactivated
use axum::response::{Html, IntoResponse as _, Response};
use mas_axum_utils::{SessionInfoExt, cookies::CookieJar, csrf::CsrfExt};
use mas_data_model::BrowserSession;
use mas_i18n::DataLocale;
use mas_storage::{BoxRepository, Clock, RepositoryError};
use mas_templates::{AccountInactiveContext, TemplateContext, Templates};
use rand::RngCore;
use thiserror::Error;
#[derive(Debug, Error)]
#[error(transparent)]
pub enum SessionLoadError {
Template(#[from] mas_templates::TemplateError),
Repository(#[from] RepositoryError),
}
#[allow(clippy::large_enum_variant)]
pub enum SessionOrFallback {
MaybeSession {
cookie_jar: CookieJar,
maybe_session: Option<BrowserSession>,
},
Fallback {
response: Response,
},
}
/// Load a session from the cookie jar, or fall back to an HTML error page if
/// the account is locked, deactivated or logged out
pub async fn load_session_or_fallback(
cookie_jar: CookieJar,
clock: &impl Clock,
rng: impl RngCore,
templates: &Templates,
locale: &DataLocale,
repo: &mut BoxRepository,
) -> Result<SessionOrFallback, SessionLoadError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let Some(session_id) = session_info.current_session_id() else {
return Ok(SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session: None,
});
};
let Some(session) = repo.browser_session().lookup(session_id).await? else {
// We looked up the session, but it was not found. Still update the cookie
let session_info = session_info.mark_session_ended();
let cookie_jar = cookie_jar.update_session_info(&session_info);
return Ok(SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session: None,
});
};
if session.user.deactivated_at.is_some() {
// The account is deactivated, show the 'account deactivated' fallback
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
let ctx = AccountInactiveContext::new(session.user)
.with_csrf(csrf_token.form_value())
.with_language(locale.clone());
let fallback = templates.render_account_deactivated(&ctx)?;
let response = (cookie_jar, Html(fallback)).into_response();
return Ok(SessionOrFallback::Fallback { response });
}
if session.user.locked_at.is_some() {
// The account is locked, show the 'account locked' fallback
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
let ctx = AccountInactiveContext::new(session.user)
.with_csrf(csrf_token.form_value())
.with_language(locale.clone());
let fallback = templates.render_account_locked(&ctx)?;
let response = (cookie_jar, Html(fallback)).into_response();
return Ok(SessionOrFallback::Fallback { response });
}
if session.finished_at.is_some() {
// The session has finished, but the browser still has the cookie. This is
// likely a 'remote' logout, triggered either by an admin or from the
// user-management UI. In this case, we show the 'account logged out'
// fallback.
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
let ctx = AccountInactiveContext::new(session.user)
.with_csrf(csrf_token.form_value())
.with_language(locale.clone());
let fallback = templates.render_account_logged_out(&ctx)?;
let response = (cookie_jar, Html(fallback)).into_response();
return Ok(SessionOrFallback::Fallback { response });
}
Ok(SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session: Some(session),
})
}

View File

@@ -137,6 +137,7 @@ pub fn test_site_config() -> SiteConfig {
displayname_change_allowed: true,
password_change_allowed: true,
account_recovery_allowed: true,
account_deactivation_allowed: true,
captcha: None,
minimum_password_complexity: 1,
session_expiration: None,

View File

@@ -19,7 +19,7 @@ use mas_axum_utils::{
csrf::{CsrfExt, ProtectedForm},
sentry::SentryEventID,
};
use mas_data_model::{User, UserAgent};
use mas_data_model::UserAgent;
use mas_jose::jwt::Jwt;
use mas_matrix::HomeserverConnection;
use mas_policy::Policy;
@@ -31,8 +31,8 @@ use mas_storage::{
user::{BrowserSessionRepository, UserEmailRepository, UserRepository},
};
use mas_templates::{
ErrorContext, FieldError, FormError, TemplateContext, Templates, ToFormState,
UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
AccountInactiveContext, ErrorContext, FieldError, FormError, TemplateContext, Templates,
ToFormState, UpstreamExistingLinkContext, UpstreamRegister, UpstreamSuggestLink,
};
use minijinja::Environment;
use serde::{Deserialize, Serialize};
@@ -242,7 +242,7 @@ pub(crate) async fn get(
let (user_session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, mut cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
let response = match (maybe_user_session, link.user_id) {
(Some(session), Some(user_id)) if session.user.id == user_id => {
@@ -272,8 +272,6 @@ pub(crate) async fn get(
.user()
.lookup(user_id)
.await?
// XXX: is that right?
.filter(User::is_valid)
.ok_or(RouteError::UserNotFound)?;
let ctx = UpstreamExistingLinkContext::new(user)
@@ -300,9 +298,27 @@ pub(crate) async fn get(
.user()
.lookup(user_id)
.await?
.filter(mas_data_model::User::is_valid)
.ok_or(RouteError::UserNotFound)?;
// Check that the user is not locked or deactivated
if user.deactivated_at.is_some() {
// The account is deactivated, show the 'account deactivated' fallback
let ctx = AccountInactiveContext::new(user)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let fallback = templates.render_account_deactivated(&ctx)?;
return Ok((cookie_jar, Html(fallback).into_response()));
}
if user.locked_at.is_some() {
// The account is locked, show the 'account locked' fallback
let ctx = AccountInactiveContext::new(user)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let fallback = templates.render_account_locked(&ctx)?;
return Ok((cookie_jar, Html(fallback).into_response()));
}
let session = repo
.browser_session()
.add(&mut rng, &clock, &user, user_agent)
@@ -556,7 +572,7 @@ pub(crate) async fn post(
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (user_session_info, cookie_jar) = cookie_jar.session_info();
let maybe_user_session = user_session_info.load_session(&mut repo).await?;
let maybe_user_session = user_session_info.load_active_session(&mut repo).await?;
let form_state = form.to_form_state();
let session = match (maybe_user_session, link.user_id, form) {
@@ -672,7 +688,7 @@ pub(crate) async fn post(
ctx
};
let forced_username = if provider.claims_imports.localpart.is_forced() {
let username = if provider.claims_imports.localpart.is_forced() {
let template = provider
.claims_imports
.localpart
@@ -680,128 +696,108 @@ pub(crate) async fn post(
.as_deref()
.unwrap_or(DEFAULT_LOCALPART_TEMPLATE);
render_attribute_template(
&env,
template,
&context,
provider.claims_imports.email.is_required(),
)?
render_attribute_template(&env, template, &context, true)?
} else {
None
};
// If there is no forced username, we can use the one the user entered
let username = forced_username
.or(username)
.filter(|username| !username.is_empty());
let Some(username) = username else {
// We're missing a username, let's re-render the form with an error
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Required,
);
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
};
// If there is no forced username, we can use the one the user entered
username
}
.unwrap_or_default();
let ctx = ctx.with_localpart(
username.clone(),
provider.claims_imports.localpart.is_forced(),
);
// Check if there is an existing user
let existing_user = repo.user().find_by_username(&username).await?;
// Validate the form
let form_state = {
let mut form_state = form_state;
let mut homeserver_denied_username = false;
if username.is_empty() {
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Required,
);
} else if repo.user().exists(&username).await? {
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Exists,
);
} else if !homeserver
.is_localpart_available(&username)
.await
.map_err(RouteError::HomeserverConnection)?
{
// The user already exists on the homeserver
tracing::warn!(
%username,
"Homeserver denied username provided by user"
);
// Ask the homeserver to make sure the username is valid
let is_available = homeserver
.is_localpart_available(&username)
.await
.map_err(RouteError::HomeserverConnection)?;
// We defer adding the error on the field, until we know whether we had another
// error from the policy, to avoid showing both
homeserver_denied_username = true;
}
if existing_user.is_some() || !is_available {
// If there is an existing user, we can't create a new one
// with the same username, show an error
// If we have a TOS in the config, make sure the user has accepted it
if site_config.tos_uri.is_some() && !accept_terms {
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::AcceptTerms,
FieldError::Required,
);
}
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Exists,
);
// Policy check
let res = policy
.evaluate_register(mas_policy::RegisterInput {
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
username: &username,
email: email.as_deref(),
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent: user_agent.clone().map(|ua| ua.raw),
},
})
.await?;
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
}
// If we need have a TOS in the config, make sure the user has accepted it
if site_config.tos_uri.is_some() && !accept_terms {
let form_state = form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::AcceptTerms,
FieldError::Required,
);
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
)
.into_response());
}
// Policy check
let res = policy
.evaluate_register(mas_policy::RegisterInput {
registration_method: mas_policy::RegistrationMethod::UpstreamOAuth2,
username: &username,
email: email.as_deref(),
requester: mas_policy::Requester {
ip_address: activity_tracker.ip(),
user_agent: user_agent.clone().map(|ua| ua.raw),
},
})
.await?;
if !res.valid() {
let form_state =
res.violations
.into_iter()
.fold(form_state, |form_state, violation| {
match violation.field.as_deref() {
Some("username") => form_state.with_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
},
),
_ => form_state.with_error_on_form(FormError::Policy {
for violation in res.violations {
match violation.field.as_deref() {
Some("username") => {
// If the homeserver denied the username, but we also had an error on
// the policy side, we don't want to show
// both, so we reset the state here
homeserver_denied_username = false;
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
}),
}
});
},
);
}
_ => form_state.add_error_on_form(FormError::Policy {
code: violation.code.map(|c| c.as_str()),
message: violation.msg,
}),
}
}
if homeserver_denied_username {
// XXX: we may want to return different errors like "this username is reserved"
form_state.add_error_on_field(
mas_templates::UpstreamRegisterFormField::Username,
FieldError::Exists,
);
}
form_state
};
if !form_state.is_valid() {
let ctx = ctx
.with_form_state(form_state)
.with_csrf(csrf_token.form_value())
.with_language(locale);
return Ok((
cookie_jar,
Html(templates.render_upstream_oauth2_do_register(&ctx)?),
@@ -953,6 +949,7 @@ mod tests {
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
response_mode: None,
additional_authorization_parameters: Vec::new(),
ui_order: 0,
},
)
.await

View File

@@ -8,13 +8,16 @@ use axum::{
extract::{Query, State},
response::{Html, IntoResponse},
};
use mas_axum_utils::{FancyError, SessionInfoExt, cookies::CookieJar};
use mas_axum_utils::{FancyError, cookies::CookieJar};
use mas_router::{PostAuthAction, UrlBuilder};
use mas_storage::{BoxClock, BoxRepository};
use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_templates::{AppContext, TemplateContext, Templates};
use serde::Deserialize;
use crate::{BoundActivityTracker, PreferredLanguage};
use crate::{
BoundActivityTracker, PreferredLanguage,
session::{SessionOrFallback, load_session_or_fallback},
};
#[derive(Deserialize)]
pub struct Params {
@@ -31,13 +34,24 @@ pub async fn get(
Query(Params { action }): Query<Params>,
mut repo: BoxRepository,
clock: BoxClock,
mut rng: BoxRng,
cookie_jar: CookieJar,
) -> Result<impl IntoResponse, FancyError> {
let (session_info, cookie_jar) = cookie_jar.session_info();
let session = session_info.load_session(&mut repo).await?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
// TODO: keep the full path, not just the action
let Some(session) = session else {
let Some(session) = maybe_session else {
return Ok((
cookie_jar,
url_builder.redirect(&mas_router::Login::and_then(

View File

@@ -6,14 +6,18 @@
use axum::{
extract::State,
response::{Html, IntoResponse},
response::{Html, IntoResponse, Response},
};
use mas_axum_utils::{FancyError, SessionInfoExt, cookies::CookieJar, csrf::CsrfExt};
use mas_axum_utils::{FancyError, cookies::CookieJar, csrf::CsrfExt};
use mas_router::UrlBuilder;
use mas_storage::{BoxClock, BoxRepository, BoxRng};
use mas_templates::{IndexContext, TemplateContext, Templates};
use crate::{BoundActivityTracker, preferred_language::PreferredLanguage};
use crate::{
BoundActivityTracker,
preferred_language::PreferredLanguage,
session::{SessionOrFallback, load_session_or_fallback},
};
#[tracing::instrument(name = "handlers.views.index.get", skip_all, err)]
pub async fn get(
@@ -25,23 +29,34 @@ pub async fn get(
mut repo: BoxRepository,
cookie_jar: CookieJar,
PreferredLanguage(locale): PreferredLanguage,
) -> Result<impl IntoResponse, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let session = session_info.load_session(&mut repo).await?;
) -> Result<Response, FancyError> {
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
if let Some(session) = session.as_ref() {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
if let Some(session) = maybe_session.as_ref() {
activity_tracker
.record_browser_session(&clock, session)
.await;
}
let ctx = IndexContext::new(url_builder.oidc_discovery())
.maybe_with_session(session)
.maybe_with_session(maybe_session)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let content = templates.render_index(&ctx)?;
Ok((cookie_jar, Html(content)))
Ok((cookie_jar, Html(content)).into_response())
}

View File

@@ -15,9 +15,9 @@ use hyper::StatusCode;
use mas_axum_utils::{
FancyError, SessionInfoExt,
cookies::CookieJar,
csrf::{CsrfExt, CsrfToken, ProtectedForm},
csrf::{CsrfExt, ProtectedForm},
};
use mas_data_model::{BrowserSession, UserAgent, oauth2::LoginHint};
use mas_data_model::{UserAgent, oauth2::LoginHint};
use mas_i18n::DataLocale;
use mas_matrix::HomeserverConnection;
use mas_router::{UpstreamOAuth2Authorize, UrlBuilder};
@@ -27,10 +27,10 @@ use mas_storage::{
user::{BrowserSessionRepository, UserPasswordRepository, UserRepository},
};
use mas_templates::{
FieldError, FormError, LoginContext, LoginFormField, PostAuthContext, PostAuthContextInner,
TemplateContext, Templates, ToFormState,
AccountInactiveContext, FieldError, FormError, FormState, LoginContext, LoginFormField,
PostAuthContext, PostAuthContextInner, TemplateContext, Templates, ToFormState,
};
use rand::{CryptoRng, Rng};
use rand::Rng;
use serde::{Deserialize, Serialize};
use zeroize::Zeroizing;
@@ -38,6 +38,7 @@ use super::shared::OptionalPostAuthAction;
use crate::{
BoundActivityTracker, Limiter, PreferredLanguage, RequesterFingerprint, SiteConfig,
passwords::PasswordManager,
session::{SessionOrFallback, load_session_or_fallback},
};
#[derive(Debug, Deserialize, Serialize)]
@@ -64,10 +65,18 @@ pub(crate) async fn get(
Query(query): Query<OptionalPostAuthAction>,
cookie_jar: CookieJar,
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
if let Some(session) = maybe_session {
activity_tracker
@@ -94,18 +103,18 @@ pub(crate) async fn get(
return Ok((cookie_jar, url_builder.redirect(&destination)).into_response());
};
let content = render(
render(
locale,
LoginContext::default().with_upstream_providers(providers),
cookie_jar,
FormState::default(),
query,
csrf_token,
&mut repo,
&clock,
&mut rng,
&templates,
&homeserver,
)
.await?;
Ok((cookie_jar, Html(content)).into_response())
.await
}
#[tracing::instrument(name = "handlers.views.login.post", skip_all, err)]
@@ -135,39 +144,30 @@ pub(crate) async fn post(
let form = cookie_jar.verify_form(&clock, form)?;
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
// Validate the form
let state = {
let mut state = form.to_form_state();
let mut form_state = form.to_form_state();
if form.username.is_empty() {
state.add_error_on_field(LoginFormField::Username, FieldError::Required);
}
if form.username.is_empty() {
form_state.add_error_on_field(LoginFormField::Username, FieldError::Required);
}
if form.password.is_empty() {
state.add_error_on_field(LoginFormField::Password, FieldError::Required);
}
if form.password.is_empty() {
form_state.add_error_on_field(LoginFormField::Password, FieldError::Required);
}
state
};
if !state.is_valid() {
let providers = repo.upstream_oauth_provider().all_enabled().await?;
let content = render(
if !form_state.is_valid() {
return render(
locale,
LoginContext::default()
.with_form_state(state)
.with_upstream_providers(providers),
cookie_jar,
form_state,
query,
csrf_token,
&mut repo,
&clock,
&mut rng,
&templates,
&homeserver,
)
.await?;
return Ok((cookie_jar, Html(content)).into_response());
.await;
}
// Extract the localpart of the MXID, fallback to the bare username
@@ -175,89 +175,64 @@ pub(crate) async fn post(
.localpart(&form.username)
.unwrap_or(&form.username);
match login(
password_manager,
&mut repo,
rng,
&clock,
limiter,
requester,
username,
&form.password,
user_agent,
)
.await
{
Ok(session_info) => {
repo.save().await?;
activity_tracker
.record_browser_session(&clock, &session_info)
.await;
let cookie_jar = cookie_jar.set_session(&session_info);
let reply = query.go_next(&url_builder);
Ok((cookie_jar, reply).into_response())
}
Err(e) => {
let state = state.with_error_on_form(e);
let content = render(
locale,
LoginContext::default().with_form_state(state),
query,
csrf_token,
&mut repo,
&templates,
&homeserver,
)
.await?;
Ok((cookie_jar, Html(content)).into_response())
}
}
}
// TODO: move that logic elsewhere?
async fn login(
password_manager: PasswordManager,
repo: &mut impl RepositoryAccess,
mut rng: impl Rng + CryptoRng + Send,
clock: &impl Clock,
limiter: Limiter,
requester: RequesterFingerprint,
username: &str,
password: &str,
user_agent: Option<UserAgent>,
) -> Result<BrowserSession, FormError> {
// XXX: we're loosing the error context here
// First, lookup the user
let user = repo
.user()
.find_by_username(username)
.await
.map_err(|_e| FormError::Internal)?
.filter(mas_data_model::User::is_valid)
.ok_or(FormError::InvalidCredentials)?;
let Some(user) = repo.user().find_by_username(username).await? else {
let form_state = form_state.with_error_on_form(FormError::InvalidCredentials);
return render(
locale,
cookie_jar,
form_state,
query,
&mut repo,
&clock,
&mut rng,
&templates,
&homeserver,
)
.await;
};
// Check the rate limit
limiter.check_password(requester, &user).map_err(|e| {
if let Err(e) = limiter.check_password(requester, &user) {
tracing::warn!(error = &e as &dyn std::error::Error);
FormError::RateLimitExceeded
})?;
let form_state = form_state.with_error_on_form(FormError::RateLimitExceeded);
return render(
locale,
cookie_jar,
form_state,
query,
&mut repo,
&clock,
&mut rng,
&templates,
&homeserver,
)
.await;
}
// And its password
let user_password = repo
.user_password()
.active(&user)
.await
.map_err(|_e| FormError::Internal)?
.ok_or(FormError::InvalidCredentials)?;
let Some(user_password) = repo.user_password().active(&user).await? else {
// There is no password for this user, but we don't want to disclose that. Show
// a generic 'invalid credentials' error instead
let form_state = form_state.with_error_on_form(FormError::InvalidCredentials);
return render(
locale,
cookie_jar,
form_state,
query,
&mut repo,
&clock,
&mut rng,
&templates,
&homeserver,
)
.await;
};
let password = Zeroizing::new(password.as_bytes().to_vec());
let password = Zeroizing::new(form.password.as_bytes().to_vec());
// Verify the password, and upgrade it on-the-fly if needed
let new_password_hash = password_manager
let user_password = match password_manager
.verify_and_upgrade(
&mut rng,
user_password.version,
@@ -265,51 +240,94 @@ async fn login(
user_password.hashed_password.clone(),
)
.await
.map_err(|_| FormError::InvalidCredentials)?;
let user_password = if let Some((version, new_password_hash)) = new_password_hash {
// Save the upgraded password
repo.user_password()
.add(
{
Ok(Some((version, new_password_hash))) => {
// Save the upgraded password
repo.user_password()
.add(
&mut rng,
&clock,
&user,
version,
new_password_hash,
Some(&user_password),
)
.await?
}
Ok(None) => user_password,
Err(_) => {
let form_state = form_state.with_error_on_form(FormError::InvalidCredentials);
return render(
locale,
cookie_jar,
form_state,
query,
&mut repo,
&clock,
&mut rng,
clock,
&user,
version,
new_password_hash,
Some(&user_password),
&templates,
&homeserver,
)
.await
.map_err(|_| FormError::Internal)?
} else {
user_password
.await;
}
};
// Now that we have checked the user password, we now want to show an error if
// the user is locked or deactivated
if user.deactivated_at.is_some() {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let ctx = AccountInactiveContext::new(user)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let content = templates.render_account_deactivated(&ctx)?;
return Ok((cookie_jar, Html(content)).into_response());
}
if user.locked_at.is_some() {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let ctx = AccountInactiveContext::new(user)
.with_csrf(csrf_token.form_value())
.with_language(locale);
let content = templates.render_account_locked(&ctx)?;
return Ok((cookie_jar, Html(content)).into_response());
}
// At this point, we should have a 'valid' user. In case we missed something, we
// want it to crash in tests/debug builds
debug_assert!(user.is_valid());
// Start a new session
let user_session = repo
.browser_session()
.add(&mut rng, clock, &user, user_agent)
.await
.map_err(|_| FormError::Internal)?;
.add(&mut rng, &clock, &user, user_agent)
.await?;
// And mark it as authenticated by the password
repo.browser_session()
.authenticate_with_password(&mut rng, clock, &user_session, &user_password)
.await
.map_err(|_| FormError::Internal)?;
.authenticate_with_password(&mut rng, &clock, &user_session, &user_password)
.await?;
Ok(user_session)
repo.save().await?;
activity_tracker
.record_browser_session(&clock, &user_session)
.await;
let cookie_jar = cookie_jar.set_session(&user_session);
let reply = query.go_next(&url_builder);
Ok((cookie_jar, reply).into_response())
}
fn handle_login_hint(
ctx: &mut LoginContext,
mut ctx: LoginContext,
next: &PostAuthContext,
homeserver: &dyn HomeserverConnection,
) {
) -> LoginContext {
let form_state = ctx.form_state_mut();
// Do not override username if coming from a failed login attempt
if form_state.has_value(LoginFormField::Username) {
return;
return ctx;
}
if let PostAuthContextInner::ContinueAuthorizationGrant { ref grant } = next.ctx {
@@ -319,21 +337,31 @@ fn handle_login_hint(
};
form_state.set_value(LoginFormField::Username, value);
}
ctx
}
async fn render(
locale: DataLocale,
mut ctx: LoginContext,
cookie_jar: CookieJar,
form_state: FormState<LoginFormField>,
action: OptionalPostAuthAction,
csrf_token: CsrfToken,
repo: &mut impl RepositoryAccess,
clock: &impl Clock,
rng: impl Rng,
templates: &Templates,
homeserver: &dyn HomeserverConnection,
) -> Result<String, FancyError> {
) -> Result<Response, FancyError> {
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(clock, rng);
let providers = repo.upstream_oauth_provider().all_enabled().await?;
let ctx = LoginContext::default()
.with_form_state(form_state)
.with_upstream_providers(providers);
let next = action.load_context(repo).await?;
let ctx = if let Some(next) = next {
handle_login_hint(&mut ctx, &next, homeserver);
let ctx = handle_login_hint(ctx, &next, homeserver);
ctx.with_post_action(next)
} else {
ctx
@@ -341,7 +369,7 @@ async fn render(
let ctx = ctx.with_csrf(csrf_token.form_value()).with_language(locale);
let content = templates.render_login(&ctx)?;
Ok(content)
Ok((cookie_jar, Html(content)).into_response())
}
#[cfg(test)]
@@ -425,6 +453,7 @@ mod test {
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
response_mode: None,
additional_authorization_parameters: Vec::new(),
ui_order: 0,
},
)
.await
@@ -465,6 +494,7 @@ mod test {
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
response_mode: None,
additional_authorization_parameters: Vec::new(),
ui_order: 1,
},
)
.await
@@ -490,7 +520,11 @@ mod test {
);
}
async fn user_with_password(state: &TestState, username: &str, password: &str) {
async fn user_with_password(
state: &TestState,
username: &str,
password: &str,
) -> mas_data_model::User {
let mut rng = state.rng();
let mut repo = state.repository().await.unwrap();
let user = repo
@@ -508,6 +542,7 @@ mod test {
.await
.unwrap();
repo.save().await.unwrap();
user
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
@@ -712,4 +747,122 @@ mod test {
assert!(!body.contains("Invalid credentials"));
assert!(body.contains("too many requests"));
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_locked_account(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let cookies = CookieHelper::new();
// Provision a user with a password
let user = user_with_password(&state, "john", "hunter2").await;
// Lock the user
let mut repo = state.repository().await.unwrap();
repo.user().lock(&state.clock, user).await.unwrap();
repo.save().await.unwrap();
// Render the login page to get a CSRF token
let request = Request::get("/login").empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
// Extract the CSRF token from the response body
let csrf_token = response
.body()
.split("name=\"csrf\" value=\"")
.nth(1)
.unwrap()
.split('\"')
.next()
.unwrap();
// Submit the login form
let request = Request::post("/login").form(serde_json::json!({
"csrf": csrf_token,
"username": "john",
"password": "hunter2",
}));
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
assert!(response.body().contains("Account locked"));
// A bad password should not disclose that the account is locked
let request = Request::post("/login").form(serde_json::json!({
"csrf": csrf_token,
"username": "john",
"password": "badpassword",
}));
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
assert!(!response.body().contains("Account locked"));
assert!(response.body().contains("Invalid credentials"));
}
#[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")]
async fn test_password_login_deactivated_account(pool: PgPool) {
setup();
let state = TestState::from_pool(pool).await.unwrap();
let cookies = CookieHelper::new();
// Provision a user with a password
let user = user_with_password(&state, "john", "hunter2").await;
// Deactivate the user
let mut repo = state.repository().await.unwrap();
repo.user().deactivate(&state.clock, user).await.unwrap();
repo.save().await.unwrap();
// Render the login page to get a CSRF token
let request = Request::get("/login").empty();
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
// Extract the CSRF token from the response body
let csrf_token = response
.body()
.split("name=\"csrf\" value=\"")
.nth(1)
.unwrap()
.split('\"')
.next()
.unwrap();
// Submit the login form
let request = Request::post("/login").form(serde_json::json!({
"csrf": csrf_token,
"username": "john",
"password": "hunter2",
}));
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
assert!(response.body().contains("Account deleted"));
// A bad password should not disclose that the account is deleted
let request = Request::post("/login").form(serde_json::json!({
"csrf": csrf_token,
"username": "john",
"password": "badpassword",
}));
let request = cookies.with_cookies(request);
let response = state.request(request).await;
cookies.save_cookies(&response);
response.assert_status(StatusCode::OK);
response.assert_header_value(CONTENT_TYPE, "text/html; charset=utf-8");
assert!(!response.body().contains("Account deleted"));
assert!(response.body().contains("Invalid credentials"));
}
}

View File

@@ -29,21 +29,27 @@ pub(crate) async fn post(
) -> Result<impl IntoResponse, FancyError> {
let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, mut cookie_jar) = cookie_jar.session_info();
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
if let Some(session_id) = session_info.current_session_id() {
let maybe_session = repo.browser_session().lookup(session_id).await?;
if let Some(session) = maybe_session {
if session.finished_at.is_none() {
activity_tracker
.record_browser_session(&clock, &session)
.await;
if let Some(session) = maybe_session {
activity_tracker
.record_browser_session(&clock, &session)
.await;
repo.browser_session().finish(&clock, session).await?;
cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended());
repo.browser_session().finish(&clock, session).await?;
}
}
}
repo.save().await?;
// We always want to clear out the session cookie, even if the session was
// invalid
let cookie_jar = cookie_jar.update_session_info(&session_info.mark_session_ended());
let destination = if let Some(action) = form {
action.go_next(&url_builder)
} else {

View File

@@ -25,7 +25,11 @@ use serde::Deserialize;
use zeroize::Zeroizing;
use super::shared::OptionalPostAuthAction;
use crate::{BoundActivityTracker, PreferredLanguage, SiteConfig, passwords::PasswordManager};
use crate::{
BoundActivityTracker, PreferredLanguage, SiteConfig,
passwords::PasswordManager,
session::{SessionOrFallback, load_session_or_fallback},
};
#[derive(Deserialize, Debug)]
pub(crate) struct ReauthForm {
@@ -52,10 +56,18 @@ pub(crate) async fn get(
.into_response());
}
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
let Some(session) = maybe_session else {
// If there is no session, redirect to the login screen, keeping the
@@ -64,6 +76,8 @@ pub(crate) async fn get(
return Ok((cookie_jar, url_builder.redirect(&login)).into_response());
};
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
activity_tracker
.record_browser_session(&clock, &session)
.await;
@@ -89,6 +103,8 @@ pub(crate) async fn get(
pub(crate) async fn post(
mut rng: BoxRng,
clock: BoxClock,
PreferredLanguage(locale): PreferredLanguage,
State(templates): State<Templates>,
State(password_manager): State<PasswordManager>,
State(url_builder): State<UrlBuilder>,
State(site_config): State<SiteConfig>,
@@ -104,9 +120,18 @@ pub(crate) async fn post(
let form = cookie_jar.verify_form(&clock, form)?;
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let (cookie_jar, maybe_session) = match load_session_or_fallback(
cookie_jar, &clock, &mut rng, &templates, &locale, &mut repo,
)
.await?
{
SessionOrFallback::MaybeSession {
cookie_jar,
maybe_session,
..
} => (cookie_jar, maybe_session),
SessionOrFallback::Fallback { response } => return Ok(response),
};
let Some(session) = maybe_session else {
// If there is no session, redirect to the login screen, keeping the

View File

@@ -46,7 +46,7 @@ pub(crate) async fn get(
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;
if maybe_session.is_some() {
// TODO: redirect to continue whatever action was going on
return Ok((cookie_jar, url_builder.redirect(&mas_router::Index)).into_response());
@@ -100,7 +100,7 @@ pub(crate) async fn post(
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;
if maybe_session.is_some() {
// TODO: redirect to continue whatever action was going on
return Ok((cookie_jar, url_builder.redirect(&mas_router::Index)).into_response());

View File

@@ -56,7 +56,7 @@ pub(crate) async fn get(
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;
if maybe_session.is_some() {
// TODO: redirect to continue whatever action was going on
return Ok((cookie_jar, url_builder.redirect(&mas_router::Index)).into_response());
@@ -96,7 +96,7 @@ pub(crate) async fn post(
let (session_info, cookie_jar) = cookie_jar.session_info();
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;
if maybe_session.is_some() {
// TODO: redirect to continue whatever action was going on
return Ok((cookie_jar, url_builder.redirect(&mas_router::Index)).into_response());

View File

@@ -36,7 +36,7 @@ pub(crate) async fn get(
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;
if let Some(session) = maybe_session {
activity_tracker

View File

@@ -81,7 +81,7 @@ pub(crate) async fn get(
let (csrf_token, cookie_jar) = cookie_jar.csrf_token(&clock, &mut rng);
let (session_info, cookie_jar) = cookie_jar.session_info();
let maybe_session = session_info.load_session(&mut repo).await?;
let maybe_session = session_info.load_active_session(&mut repo).await?;
if maybe_session.is_some() {
let reply = query.action.go_next(&url_builder);

View File

@@ -15,7 +15,7 @@ workspace = true
anyhow.workspace = true
async-trait.workspace = true
camino.workspace = true
convert_case = "0.7.1"
convert_case = "0.8.0"
csv = "1.3.1"
reqwest.workspace = true
serde.workspace = true

View File

@@ -21,7 +21,7 @@ pin-project-lite = "0.2.16"
socket2 = "0.5.8"
thiserror.workspace = true
tokio.workspace = true
tokio-rustls = "0.26.1"
tokio-rustls = "0.26.2"
tokio-util.workspace = true
tower.workspace = true
tower-http.workspace = true

View File

@@ -786,6 +786,9 @@ pub struct IntrospectionResponse {
/// String identifier for the token.
pub jti: Option<String>,
/// MAS extension: explicit device ID
pub device_id: Option<String>,
}
/// A request to the [Revocation Endpoint].

View File

@@ -1,14 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n DELETE FROM user_email_confirmation_codes\n WHERE user_email_id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid"
]
},
"nullable": []
},
"hash": "1f6297fb323e9f2fbfa1c9e3225c0b3037c8c4714533a6240c62275332aa58dc"
}

View File

@@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE users\n SET deactivated_at = $2\n WHERE user_id = $1\n AND deactivated_at IS NULL\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Uuid",
"Timestamptz"
]
},
"nullable": []
},
"hash": "2f7aba76cd7df75d6a9a6d91d5ddebaedf37437f3bd4f796f5581fab997587d7"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT user_id\n , username\n , created_at\n , locked_at\n , can_request_admin\n FROM users\n WHERE username = $1\n ",
"query": "\n SELECT user_id\n , username\n , created_at\n , locked_at\n , deactivated_at\n , can_request_admin\n FROM users\n WHERE username = $1\n ",
"describe": {
"columns": [
{
@@ -25,6 +25,11 @@
},
{
"ordinal": 4,
"name": "deactivated_at",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "can_request_admin",
"type_info": "Bool"
}
@@ -39,8 +44,9 @@
false,
false,
true,
true,
false
]
},
"hash": "e1a18bd82d28fd86d8b8da8a6ac6eddf224ab32cf96e9c28706dd9aa1d09332b"
"hash": "48213d718a256a12540c0aec595ca3e436be423f2d0c868700c6397745ed0455"
}

View File

@@ -0,0 +1,44 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n id_token_signed_response_alg,\n fetch_userinfo,\n userinfo_signed_response_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters,\n ui_order,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,\n $12, $13, $14, $15, $16, $17, $18, $19, $20,\n $21, $22, $23)\n ON CONFLICT (upstream_oauth_provider_id)\n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,\n fetch_userinfo = EXCLUDED.fetch_userinfo,\n userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n response_mode = EXCLUDED.response_mode,\n additional_parameters = EXCLUDED.additional_parameters,\n ui_order = EXCLUDED.ui_order\n RETURNING created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Bool",
"Text",
"Text",
"Text",
"Jsonb",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Jsonb",
"Int4",
"Timestamptz"
]
},
"nullable": [
false
]
},
"hash": "72de26d5e3c56f4b0658685a95b45b647bb6637e55b662a5a548aa3308c62a8a"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO user_emails (user_email_id, user_id, email, created_at, confirmed_at)\n VALUES ($1, $2, $3, $4, $4)\n ",
"query": "\n INSERT INTO user_emails (user_email_id, user_id, email, created_at)\n VALUES ($1, $2, $3, $4)\n ",
"describe": {
"columns": [],
"parameters": {
@@ -13,5 +13,5 @@
},
"nullable": []
},
"hash": "b697bbc5aaaca219602ac8f19f90097e88faf8052effa84a03cc638ae315ff69"
"hash": "90fe32cb9c88a262a682c0db700fef7d69d6ce0be1f930d9f16c50b921a8b819"
}

View File

@@ -1,43 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO upstream_oauth_providers (\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n token_endpoint_auth_method,\n token_endpoint_signing_alg,\n id_token_signed_response_alg,\n fetch_userinfo,\n userinfo_signed_response_alg,\n client_id,\n encrypted_client_secret,\n claims_imports,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n jwks_uri_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters,\n created_at\n ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,\n $12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22)\n ON CONFLICT (upstream_oauth_provider_id)\n DO UPDATE\n SET\n issuer = EXCLUDED.issuer,\n human_name = EXCLUDED.human_name,\n brand_name = EXCLUDED.brand_name,\n scope = EXCLUDED.scope,\n token_endpoint_auth_method = EXCLUDED.token_endpoint_auth_method,\n token_endpoint_signing_alg = EXCLUDED.token_endpoint_signing_alg,\n id_token_signed_response_alg = EXCLUDED.id_token_signed_response_alg,\n fetch_userinfo = EXCLUDED.fetch_userinfo,\n userinfo_signed_response_alg = EXCLUDED.userinfo_signed_response_alg,\n disabled_at = NULL,\n client_id = EXCLUDED.client_id,\n encrypted_client_secret = EXCLUDED.encrypted_client_secret,\n claims_imports = EXCLUDED.claims_imports,\n authorization_endpoint_override = EXCLUDED.authorization_endpoint_override,\n token_endpoint_override = EXCLUDED.token_endpoint_override,\n userinfo_endpoint_override = EXCLUDED.userinfo_endpoint_override,\n jwks_uri_override = EXCLUDED.jwks_uri_override,\n discovery_mode = EXCLUDED.discovery_mode,\n pkce_mode = EXCLUDED.pkce_mode,\n response_mode = EXCLUDED.response_mode,\n additional_parameters = EXCLUDED.additional_parameters\n RETURNING created_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "created_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Uuid",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Bool",
"Text",
"Text",
"Text",
"Jsonb",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Text",
"Jsonb",
"Timestamptz"
]
},
"nullable": [
false
]
},
"hash": "99f2a0b53e08d23408dc2837d32d734c8a0e706662e72f3b2585b0c38f42c063"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n id_token_signed_response_alg,\n fetch_userinfo,\n userinfo_signed_response_alg,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ",
"query": "\n SELECT\n upstream_oauth_provider_id,\n issuer,\n human_name,\n brand_name,\n scope,\n client_id,\n encrypted_client_secret,\n token_endpoint_signing_alg,\n token_endpoint_auth_method,\n id_token_signed_response_alg,\n fetch_userinfo,\n userinfo_signed_response_alg,\n created_at,\n disabled_at,\n claims_imports as \"claims_imports: Json<UpstreamOAuthProviderClaimsImports>\",\n jwks_uri_override,\n authorization_endpoint_override,\n token_endpoint_override,\n userinfo_endpoint_override,\n discovery_mode,\n pkce_mode,\n response_mode,\n additional_parameters as \"additional_parameters: Json<Vec<(String, String)>>\"\n FROM upstream_oauth_providers\n WHERE disabled_at IS NULL\n ORDER BY ui_order ASC, upstream_oauth_provider_id ASC\n ",
"describe": {
"columns": [
{
@@ -148,5 +148,5 @@
true
]
},
"hash": "27d6f228a9a608b5d03d30cb4074be94dc893df9107e982583aa954b5067dfd1"
"hash": "c1e55ffd09181c0d8ddd0df2843690aeae4a20329045ab23639181a0d0903178"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT user_id\n , username\n , created_at\n , locked_at\n , can_request_admin\n FROM users\n WHERE user_id = $1\n ",
"query": "\n SELECT user_id\n , username\n , created_at\n , locked_at\n , deactivated_at\n , can_request_admin\n FROM users\n WHERE user_id = $1\n ",
"describe": {
"columns": [
{
@@ -25,6 +25,11 @@
},
{
"ordinal": 4,
"name": "deactivated_at",
"type_info": "Timestamptz"
},
{
"ordinal": 5,
"name": "can_request_admin",
"type_info": "Bool"
}
@@ -39,8 +44,9 @@
false,
false,
true,
true,
false
]
},
"hash": "86767be88b7594cc9a98a2f1f1c61cf66118f2fda4b4b0415de15087524f1356"
"hash": "cc332eda5965715607ffa4eeeacc1b6532cbd8fe49904ccdb1afe315804d348d"
}

View File

@@ -1,6 +1,6 @@
{
"db_name": "PostgreSQL",
"query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , s.user_agent AS \"user_session_user_agent\"\n , s.last_active_at AS \"user_session_last_active_at\"\n , s.last_active_ip AS \"user_session_last_active_ip: IpAddr\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.created_at AS \"user_created_at\"\n , u.locked_at AS \"user_locked_at\"\n , u.can_request_admin AS \"user_can_request_admin\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n WHERE s.user_session_id = $1\n ",
"query": "\n SELECT s.user_session_id\n , s.created_at AS \"user_session_created_at\"\n , s.finished_at AS \"user_session_finished_at\"\n , s.user_agent AS \"user_session_user_agent\"\n , s.last_active_at AS \"user_session_last_active_at\"\n , s.last_active_ip AS \"user_session_last_active_ip: IpAddr\"\n , u.user_id\n , u.username AS \"user_username\"\n , u.created_at AS \"user_created_at\"\n , u.locked_at AS \"user_locked_at\"\n , u.deactivated_at AS \"user_deactivated_at\"\n , u.can_request_admin AS \"user_can_request_admin\"\n FROM user_sessions s\n INNER JOIN users u\n USING (user_id)\n WHERE s.user_session_id = $1\n ",
"describe": {
"columns": [
{
@@ -55,6 +55,11 @@
},
{
"ordinal": 10,
"name": "user_deactivated_at",
"type_info": "Timestamptz"
},
{
"ordinal": 11,
"name": "user_can_request_admin",
"type_info": "Bool"
}
@@ -75,8 +80,9 @@
false,
false,
true,
true,
false
]
},
"hash": "7ea1a668480cbfda1439ba80fbd6ef2d751a3bb781e30260383eee3579f3a962"
"hash": "f924db60febad26c9fff24881b05dd1e1f7ba288d7b2f2f8e30a1ea43e98b8c8"
}

View File

@@ -26,7 +26,7 @@ opentelemetry-semantic-conventions.workspace = true
rand.workspace = true
rand_chacha.workspace = true
url.workspace = true
uuid = "1.14.0"
uuid = "1.15.1"
ulid = { workspace = true, features = ["uuid"] }
oauth2-types.workspace = true

View File

@@ -0,0 +1,8 @@
-- Copyright 2025 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.
ALTER TABLE users
-- Track when a user was deactivated.
ADD COLUMN deactivated_at TIMESTAMP WITH TIME ZONE;

View File

@@ -0,0 +1,9 @@
-- Copyright 2025 New Vector Ltd.
--
-- SPDX-License-Identifier: AGPL-3.0-only
-- Please see LICENSE in the repository root for full details.
-- Adds a column to track the 'UI order' of the upstream OAuth2 providers, so
-- that they can be consistently displayed in the UI
ALTER TABLE upstream_oauth_providers
ADD COLUMN ui_order INTEGER NOT NULL DEFAULT 0;

View File

@@ -588,7 +588,7 @@ mod tests {
.unwrap();
let device2 = Device::generate(&mut rng);
let scope = Scope::from_iter([OPENID, device2.to_scope_token()]);
let scope = Scope::from_iter([OPENID, device2.to_scope_token().unwrap()]);
// We're moving the clock forward by 1 minute between each session to ensure
// we're getting consistent ordering in lists.

View File

@@ -59,42 +59,28 @@ struct CompatSessionLookup {
last_active_ip: Option<IpAddr>,
}
impl TryFrom<CompatSessionLookup> for CompatSession {
type Error = DatabaseInconsistencyError;
fn try_from(value: CompatSessionLookup) -> Result<Self, Self::Error> {
impl From<CompatSessionLookup> for CompatSession {
fn from(value: CompatSessionLookup) -> Self {
let id = value.compat_session_id.into();
let device = value
.device_id
.map(Device::try_from)
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match value.finished_at {
None => CompatSessionState::Valid,
Some(finished_at) => CompatSessionState::Finished { finished_at },
};
let session = CompatSession {
CompatSession {
id,
state,
user_id: value.user_id.into(),
user_session_id: value.user_session_id.map(Ulid::from),
device,
device: value.device_id.map(Device::from),
human_name: value.human_name,
created_at: value.created_at,
is_synapse_admin: value.is_synapse_admin,
user_agent: value.user_agent.map(UserAgent::parse),
last_active_at: value.last_active_at,
last_active_ip: value.last_active_ip,
};
Ok(session)
}
}
}
@@ -125,16 +111,6 @@ impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSs
fn try_from(value: CompatSessionAndSsoLoginLookup) -> Result<Self, Self::Error> {
let id = value.compat_session_id.into();
let device = value
.device_id
.map(Device::try_from)
.transpose()
.map_err(|e| {
DatabaseInconsistencyError::on("compat_sessions")
.column("device_id")
.row(id)
.source(e)
})?;
let state = match value.finished_at {
None => CompatSessionState::Valid,
@@ -145,7 +121,7 @@ impl TryFrom<CompatSessionAndSsoLoginLookup> for (CompatSession, Option<CompatSs
id,
state,
user_id: value.user_id.into(),
device,
device: value.device_id.map(Device::from),
human_name: value.human_name,
user_session_id: value.user_session_id.map(Ulid::from),
created_at: value.created_at,
@@ -310,7 +286,7 @@ impl CompatSessionRepository for PgCompatSessionRepository<'_> {
let Some(res) = res else { return Ok(None) };
Ok(Some(res.try_into()?))
Ok(Some(res.into()))
}
#[tracing::instrument(

View File

@@ -25,6 +25,7 @@ pub enum Users {
Username,
CreatedAt,
LockedAt,
DeactivatedAt,
CanRequestAdmin,
}

View File

@@ -125,10 +125,15 @@ impl Filter for OAuth2SessionFilter<'_> {
}
}))
.add_option(self.device().map(|device| {
Expr::val(device.to_scope_token().to_string()).eq(PgFunc::any(Expr::col((
OAuth2Sessions::Table,
OAuth2Sessions::ScopeList,
))))
if let Ok(scope_token) = device.to_scope_token() {
Expr::val(scope_token.to_string()).eq(PgFunc::any(Expr::col((
OAuth2Sessions::Table,
OAuth2Sessions::ScopeList,
))))
} else {
// If the device ID can't be encoded as a scope token, match no rows
Expr::val(false).into()
}
}))
.add_option(self.browser_session().map(|browser_session| {
Expr::col((OAuth2Sessions::Table, OAuth2Sessions::UserSessionId))

View File

@@ -76,6 +76,7 @@ mod tests {
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
response_mode: None,
additional_authorization_parameters: Vec::new(),
ui_order: 0,
},
)
.await
@@ -322,6 +323,7 @@ mod tests {
pkce_mode: mas_data_model::UpstreamOAuthProviderPkceMode::Auto,
response_mode: None,
additional_authorization_parameters: Vec::new(),
ui_order: 0,
},
)
.await

View File

@@ -517,9 +517,11 @@ impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
pkce_mode,
response_mode,
additional_parameters,
ui_order,
created_at
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11,
$12, $13, $14, $15, $16, $17, $18, $19, $20, $21, $22)
$12, $13, $14, $15, $16, $17, $18, $19, $20,
$21, $22, $23)
ON CONFLICT (upstream_oauth_provider_id)
DO UPDATE
SET
@@ -543,7 +545,8 @@ impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
discovery_mode = EXCLUDED.discovery_mode,
pkce_mode = EXCLUDED.pkce_mode,
response_mode = EXCLUDED.response_mode,
additional_parameters = EXCLUDED.additional_parameters
additional_parameters = EXCLUDED.additional_parameters,
ui_order = EXCLUDED.ui_order
RETURNING created_at
"#,
Uuid::from(id),
@@ -582,6 +585,7 @@ impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
params.pkce_mode.as_str(),
params.response_mode.as_ref().map(ToString::to_string),
Json(&params.additional_authorization_parameters) as _,
params.ui_order,
created_at,
)
.traced()
@@ -917,6 +921,7 @@ impl UpstreamOAuthProviderRepository for PgUpstreamOAuthProviderRepository<'_> {
additional_parameters as "additional_parameters: Json<Vec<(String, String)>>"
FROM upstream_oauth_providers
WHERE disabled_at IS NULL
ORDER BY ui_order ASC, upstream_oauth_provider_id ASC
"#,
)
.traced()

View File

@@ -14,12 +14,10 @@ use mas_storage::{
Clock, Page, Pagination,
user::{UserEmailFilter, UserEmailRepository},
};
use opentelemetry_semantic_conventions::attribute::DB_QUERY_TEXT;
use rand::RngCore;
use sea_query::{Expr, PostgresQueryBuilder, Query, enum_def};
use sea_query_binder::SqlxBinder;
use sqlx::PgConnection;
use tracing::{Instrument, info_span};
use ulid::Ulid;
use uuid::Uuid;
@@ -317,12 +315,10 @@ impl UserEmailRepository for PgUserEmailRepository<'_> {
let id = Ulid::from_datetime_with_source(created_at.into(), rng);
tracing::Span::current().record("user_email.id", tracing::field::display(id));
// We now always set the 'confirmed_at' field, so that older app version
// consider those emails as verified.
sqlx::query!(
r#"
INSERT INTO user_emails (user_email_id, user_id, email, created_at, confirmed_at)
VALUES ($1, $2, $3, $4, $4)
INSERT INTO user_emails (user_email_id, user_id, email, created_at)
VALUES ($1, $2, $3, $4)
"#,
Uuid::from(id),
Uuid::from(user.id),
@@ -353,22 +349,6 @@ impl UserEmailRepository for PgUserEmailRepository<'_> {
err,
)]
async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error> {
let span = info_span!(
"db.user_email.remove.codes",
{ DB_QUERY_TEXT } = tracing::field::Empty
);
sqlx::query!(
r#"
DELETE FROM user_email_confirmation_codes
WHERE user_email_id = $1
"#,
Uuid::from(user_email.id),
)
.record(&span)
.execute(&mut *self.conn)
.instrument(span)
.await?;
let res = sqlx::query!(
r#"
DELETE FROM user_emails
@@ -385,6 +365,28 @@ impl UserEmailRepository for PgUserEmailRepository<'_> {
Ok(())
}
#[tracing::instrument(
name = "db.user_email.remove_bulk",
skip_all,
fields(
db.query.text,
),
err,
)]
async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error> {
let (sql, arguments) = Query::delete()
.from_table(UserEmails::Table)
.apply_filter(filter)
.build_sqlx(PostgresQueryBuilder);
let res = sqlx::query_with(&sql, arguments)
.traced()
.execute(&mut *self.conn)
.await?;
Ok(res.rows_affected().try_into().unwrap_or(usize::MAX))
}
#[tracing::instrument(
name = "db.user_email.add_authentication_for_session",
skip_all,

View File

@@ -72,6 +72,7 @@ mod priv_ {
pub(super) username: String,
pub(super) created_at: DateTime<Utc>,
pub(super) locked_at: Option<DateTime<Utc>>,
pub(super) deactivated_at: Option<DateTime<Utc>>,
pub(super) can_request_admin: bool,
}
}
@@ -87,6 +88,7 @@ impl From<UserLookup> for User {
sub: id.to_string(),
created_at: value.created_at,
locked_at: value.locked_at,
deactivated_at: value.deactivated_at,
can_request_admin: value.can_request_admin,
}
}
@@ -96,10 +98,18 @@ impl Filter for UserFilter<'_> {
fn generate_condition(&self, _has_joins: bool) -> impl sea_query::IntoCondition {
sea_query::Condition::all()
.add_option(self.state().map(|state| {
if state.is_locked() {
Expr::col((Users::Table, Users::LockedAt)).is_not_null()
} else {
Expr::col((Users::Table, Users::LockedAt)).is_null()
match state {
mas_storage::user::UserState::Deactivated => {
Expr::col((Users::Table, Users::DeactivatedAt)).is_not_null()
}
mas_storage::user::UserState::Locked => {
Expr::col((Users::Table, Users::LockedAt)).is_not_null()
}
mas_storage::user::UserState::Active => {
Expr::col((Users::Table, Users::LockedAt))
.is_null()
.and(Expr::col((Users::Table, Users::DeactivatedAt)).is_null())
}
}
}))
.add_option(self.can_request_admin().map(|can_request_admin| {
@@ -129,6 +139,7 @@ impl UserRepository for PgUserRepository<'_> {
, username
, created_at
, locked_at
, deactivated_at
, can_request_admin
FROM users
WHERE user_id = $1
@@ -161,6 +172,7 @@ impl UserRepository for PgUserRepository<'_> {
, username
, created_at
, locked_at
, deactivated_at
, can_request_admin
FROM users
WHERE username = $1
@@ -220,6 +232,7 @@ impl UserRepository for PgUserRepository<'_> {
sub: id.to_string(),
created_at,
locked_at: None,
deactivated_at: None,
can_request_admin: false,
})
}
@@ -317,6 +330,42 @@ impl UserRepository for PgUserRepository<'_> {
Ok(user)
}
#[tracing::instrument(
name = "db.user.deactivate",
skip_all,
fields(
db.query.text,
%user.id,
),
err,
)]
async fn deactivate(&mut self, clock: &dyn Clock, mut user: User) -> Result<User, Self::Error> {
if user.deactivated_at.is_some() {
return Ok(user);
}
let deactivated_at = clock.now();
let res = sqlx::query!(
r#"
UPDATE users
SET deactivated_at = $2
WHERE user_id = $1
AND deactivated_at IS NULL
"#,
Uuid::from(user.id),
deactivated_at,
)
.traced()
.execute(&mut *self.conn)
.await?;
DatabaseError::ensure_affected_rows(&res, 1)?;
user.deactivated_at = Some(user.created_at);
Ok(user)
}
#[tracing::instrument(
name = "db.user.set_can_request_admin",
skip_all,
@@ -382,6 +431,10 @@ impl UserRepository for PgUserRepository<'_> {
Expr::col((Users::Table, Users::LockedAt)),
UserLookupIden::LockedAt,
)
.expr_as(
Expr::col((Users::Table, Users::DeactivatedAt)),
UserLookupIden::DeactivatedAt,
)
.expr_as(
Expr::col((Users::Table, Users::CanRequestAdmin)),
UserLookupIden::CanRequestAdmin,

View File

@@ -59,6 +59,7 @@ struct SessionLookup {
user_username: String,
user_created_at: DateTime<Utc>,
user_locked_at: Option<DateTime<Utc>>,
user_deactivated_at: Option<DateTime<Utc>>,
user_can_request_admin: bool,
}
@@ -73,6 +74,7 @@ impl TryFrom<SessionLookup> for BrowserSession {
sub: id.to_string(),
created_at: value.user_created_at,
locked_at: value.user_locked_at,
deactivated_at: value.user_deactivated_at,
can_request_admin: value.user_can_request_admin,
};
@@ -173,6 +175,7 @@ impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
, u.username AS "user_username"
, u.created_at AS "user_created_at"
, u.locked_at AS "user_locked_at"
, u.deactivated_at AS "user_deactivated_at"
, u.can_request_admin AS "user_can_request_admin"
FROM user_sessions s
INNER JOIN users u
@@ -356,6 +359,10 @@ impl BrowserSessionRepository for PgBrowserSessionRepository<'_> {
Expr::col((Users::Table, Users::LockedAt)),
SessionLookupIden::UserLockedAt,
)
.expr_as(
Expr::col((Users::Table, Users::DeactivatedAt)),
SessionLookupIden::UserDeactivatedAt,
)
.expr_as(
Expr::col((Users::Table, Users::CanRequestAdmin)),
SessionLookupIden::UserCanRequestAdmin,

View File

@@ -33,6 +33,7 @@ async fn test_user_repo(pool: PgPool) {
let non_admin = all.cannot_request_admin_only();
let active = all.active_only();
let locked = all.locked_only();
let deactivated = all.deactivated_only();
// Initially, the user shouldn't exist
assert!(!repo.user().exists(USERNAME).await.unwrap());
@@ -49,6 +50,7 @@ async fn test_user_repo(pool: PgPool) {
assert_eq!(repo.user().count(non_admin).await.unwrap(), 0);
assert_eq!(repo.user().count(active).await.unwrap(), 0);
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
assert_eq!(repo.user().count(deactivated).await.unwrap(), 0);
// Adding the user should work
let user = repo
@@ -73,6 +75,7 @@ async fn test_user_repo(pool: PgPool) {
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
assert_eq!(repo.user().count(active).await.unwrap(), 1);
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
assert_eq!(repo.user().count(deactivated).await.unwrap(), 0);
// Adding a second time should give a conflict
// It should not poison the transaction though
@@ -93,6 +96,7 @@ async fn test_user_repo(pool: PgPool) {
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
assert_eq!(repo.user().count(active).await.unwrap(), 0);
assert_eq!(repo.user().count(locked).await.unwrap(), 1);
assert_eq!(repo.user().count(deactivated).await.unwrap(), 0);
// Check that the property is retrieved on lookup
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
@@ -123,6 +127,7 @@ async fn test_user_repo(pool: PgPool) {
assert_eq!(repo.user().count(non_admin).await.unwrap(), 0);
assert_eq!(repo.user().count(active).await.unwrap(), 1);
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
assert_eq!(repo.user().count(deactivated).await.unwrap(), 0);
// Check that the property is retrieved on lookup
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
@@ -145,6 +150,26 @@ async fn test_user_repo(pool: PgPool) {
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
assert_eq!(repo.user().count(active).await.unwrap(), 1);
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
assert_eq!(repo.user().count(deactivated).await.unwrap(), 0);
// Deactivating the user should work
let user = repo.user().deactivate(&clock, user).await.unwrap();
assert!(user.deactivated_at.is_some());
// Check that the property is retrieved on lookup
let user = repo.user().lookup(user.id).await.unwrap().unwrap();
assert!(user.deactivated_at.is_some());
// Deactivating a second time should not fail
let user = repo.user().deactivate(&clock, user).await.unwrap();
assert!(user.deactivated_at.is_some());
assert_eq!(repo.user().count(all).await.unwrap(), 1);
assert_eq!(repo.user().count(admin).await.unwrap(), 0);
assert_eq!(repo.user().count(non_admin).await.unwrap(), 1);
assert_eq!(repo.user().count(active).await.unwrap(), 0);
assert_eq!(repo.user().count(locked).await.unwrap(), 0);
assert_eq!(repo.user().count(deactivated).await.unwrap(), 1);
// Check the list method
let list = repo.user().list(all, Pagination::first(10)).await.unwrap();
@@ -171,8 +196,7 @@ async fn test_user_repo(pool: PgPool) {
.list(active, Pagination::first(10))
.await
.unwrap();
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0].id, user.id);
assert_eq!(list.edges.len(), 0);
let list = repo
.user()
@@ -181,6 +205,14 @@ async fn test_user_repo(pool: PgPool) {
.unwrap();
assert_eq!(list.edges.len(), 0);
let list = repo
.user()
.list(deactivated, Pagination::first(10))
.await
.unwrap();
assert_eq!(list.edges.len(), 1);
assert_eq!(list.edges[0].id, user.id);
repo.save().await.unwrap();
}
@@ -290,6 +322,21 @@ async fn test_user_email_repo(pool: PgPool) {
repo.user_email().remove(user_email).await.unwrap();
assert_eq!(repo.user_email().count(all).await.unwrap(), 0);
// Add a few emails
for i in 0..5 {
let email = format!("email{i}@example.com");
repo.user_email()
.add(&mut rng, &clock, &user, email)
.await
.unwrap();
}
assert_eq!(repo.user_email().count(all).await.unwrap(), 5);
// Try removing all the emails
let affected = repo.user_email().remove_bulk(all).await.unwrap();
assert_eq!(affected, 5);
assert_eq!(repo.user_email().count(all).await.unwrap(), 0);
repo.save().await.unwrap();
}

View File

@@ -95,6 +95,9 @@ pub struct UpstreamOAuthProviderParams {
/// Additional parameters to include in the authorization request
pub additional_authorization_parameters: Vec<(String, String)>,
/// The position of the provider in the UI
pub ui_order: i32,
}
/// Filter parameters for listing upstream OAuth 2.0 providers

View File

@@ -164,6 +164,19 @@ pub trait UserEmailRepository: Send + Sync {
/// Returns [`Self::Error`] if the underlying repository fails
async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error>;
/// Delete all [`UserEmail`] with the given filter
///
/// Returns the number of deleted [`UserEmail`]s
///
/// # Parameters
///
/// * `filter`: The filter parameters
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error>;
/// Add a new [`UserEmailAuthentication`] for a [`BrowserSession`]
///
/// # Parameters
@@ -303,6 +316,8 @@ repository_impl!(UserEmailRepository:
) -> Result<UserEmail, Self::Error>;
async fn remove(&mut self, user_email: UserEmail) -> Result<(), Self::Error>;
async fn remove_bulk(&mut self, filter: UserEmailFilter<'_>) -> Result<usize, Self::Error>;
async fn add_authentication_for_session(
&mut self,
rng: &mut (dyn RngCore + Send),

View File

@@ -32,6 +32,9 @@ pub use self::{
/// The state of a user account
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum UserState {
/// The account is deactivated, it has the `deactivated_at` timestamp set
Deactivated,
/// The account is locked, it has the `locked_at` timestamp set
Locked,
@@ -48,6 +51,14 @@ impl UserState {
matches!(self, Self::Locked)
}
/// Returns `true` if the user state is [`Deactivated`].
///
/// [`Deactivated`]: UserState::Deactivated
#[must_use]
pub fn is_deactivated(&self) -> bool {
matches!(self, Self::Deactivated)
}
/// Returns `true` if the user state is [`Active`].
///
/// [`Active`]: UserState::Active
@@ -86,6 +97,13 @@ impl UserFilter<'_> {
self
}
/// Filter for deactivated users
#[must_use]
pub fn deactivated_only(mut self) -> Self {
self.state = Some(UserState::Deactivated);
self
}
/// Filter for users that can request admin privileges
#[must_use]
pub fn can_request_admin_only(mut self) -> Self {
@@ -210,6 +228,20 @@ pub trait UserRepository: Send + Sync {
/// Returns [`Self::Error`] if the underlying repository fails
async fn unlock(&mut self, user: User) -> Result<User, Self::Error>;
/// Deactivate a [`User`]
///
/// Returns the deactivated [`User`]
///
/// # Parameters
///
/// * `clock`: The clock used to generate timestamps
/// * `user`: The [`User`] to deactivate
///
/// # Errors
///
/// Returns [`Self::Error`] if the underlying repository fails
async fn deactivate(&mut self, clock: &dyn Clock, user: User) -> Result<User, Self::Error>;
/// Set whether a [`User`] can request admin
///
/// Returns the [`User`] with the new `can_request_admin` value
@@ -280,6 +312,7 @@ repository_impl!(UserRepository:
async fn exists(&mut self, username: &str) -> Result<bool, Self::Error>;
async fn lock(&mut self, clock: &dyn Clock, user: User) -> Result<User, Self::Error>;
async fn unlock(&mut self, user: User) -> Result<User, Self::Error>;
async fn deactivate(&mut self, clock: &dyn Clock, user: User) -> Result<User, Self::Error>;
async fn set_can_request_admin(
&mut self,
user: User,

View File

@@ -1,19 +0,0 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO syn2mas__users (\n user_id, username,\n created_at, locked_at,\n can_request_admin, is_guest)\n SELECT * FROM UNNEST(\n $1::UUID[], $2::TEXT[],\n $3::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[],\n $5::BOOL[], $6::BOOL[])\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"UuidArray",
"TextArray",
"TimestamptzArray",
"TimestamptzArray",
"BoolArray",
"BoolArray"
]
},
"nullable": []
},
"hash": "06cd6bff12000db3e64e98c344cc9e3b5de7af6a497ad84036ae104576ae0575"
}

View File

@@ -0,0 +1,20 @@
{
"db_name": "PostgreSQL",
"query": "\n INSERT INTO syn2mas__users (\n user_id, username,\n created_at, locked_at,\n deactivated_at,\n can_request_admin, is_guest)\n SELECT * FROM UNNEST(\n $1::UUID[], $2::TEXT[],\n $3::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[],\n $5::TIMESTAMP WITH TIME ZONE[],\n $6::BOOL[], $7::BOOL[])\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"UuidArray",
"TextArray",
"TimestamptzArray",
"TimestamptzArray",
"TimestamptzArray",
"BoolArray",
"BoolArray"
]
},
"nullable": []
},
"hash": "f2820b3752cf66669551ef90a10817cb6fe71203b2d3471e838f841b53e688d1"
}

View File

@@ -11,6 +11,7 @@ repository.workspace = true
[dependencies]
anyhow.workspace = true
arc-swap.workspace = true
bitflags.workspace = true
camino.workspace = true
figment.workspace = true
@@ -18,19 +19,25 @@ serde.workspace = true
thiserror.workspace = true
thiserror-ext.workspace = true
tokio.workspace = true
tokio-util.workspace = true
sqlx.workspace = true
chrono.workspace = true
compact_str.workspace = true
tracing.workspace = true
futures-util = "0.3.31"
rustc-hash = "2.1.1"
rand.workspace = true
uuid = "1.14.0"
rand_chacha = "0.3.1"
uuid = "1.15.1"
ulid = { workspace = true, features = ["uuid"] }
mas-config.workspace = true
mas-storage.workspace = true
opentelemetry.workspace = true
opentelemetry-semantic-conventions.workspace = true
[dev-dependencies]
mas-storage-pg.workspace = true

View File

@@ -7,10 +7,16 @@ mod mas_writer;
mod synapse_reader;
mod migration;
mod progress;
mod telemetry;
type RandomState = rustc_hash::FxBuildHasher;
type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;
pub use self::{
mas_writer::{MasWriter, checks::mas_pre_migration_checks, locking::LockedMasDatabase},
migration::migrate,
progress::{Progress, ProgressStage},
synapse_reader::{
SynapseReader,
checks::{

View File

@@ -10,6 +10,7 @@
use thiserror::Error;
use thiserror_ext::ContextInto;
use tracing::Instrument as _;
use super::{MAS_TABLES_AFFECTED_BY_MIGRATION, is_syn2mas_in_progress, locking::LockedMasDatabase};
@@ -46,7 +47,7 @@ pub enum Error {
/// - If any MAS tables involved in the migration are not empty.
/// - If we can't check whether syn2mas is already in progress on this database
/// or not.
#[tracing::instrument(skip_all)]
#[tracing::instrument(name = "syn2mas.mas_pre_migration_checks", skip_all)]
pub async fn mas_pre_migration_checks(mas_connection: &mut LockedMasDatabase) -> Result<(), Error> {
if is_syn2mas_in_progress(mas_connection.as_mut())
.await
@@ -60,8 +61,11 @@ pub async fn mas_pre_migration_checks(mas_connection: &mut LockedMasDatabase) ->
// empty database.
for &table in MAS_TABLES_AFFECTED_BY_MIGRATION {
let row_present = sqlx::query(&format!("SELECT 1 AS dummy FROM {table} LIMIT 1"))
let query = format!("SELECT 1 AS dummy FROM {table} LIMIT 1");
let span = tracing::info_span!("db.query", db.query.text = query);
let row_present = sqlx::query(&query)
.fetch_optional(mas_connection.as_mut())
.instrument(span)
.await
.into_maybe_not_mas(table)?
.is_some();

View File

@@ -3,8 +3,10 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::time::Instant;
use sqlx::PgConnection;
use tracing::debug;
use tracing::{debug, info};
use super::{Error, IntoDatabase};
@@ -109,15 +111,20 @@ pub async fn drop_index(conn: &mut PgConnection, index: &IndexDescription) -> Re
/// Restores (recreates) a constraint.
///
/// The constraint must not exist prior to this call.
#[tracing::instrument(name = "syn2mas.restore_constraint", skip_all, fields(constraint.name = constraint.name))]
pub async fn restore_constraint(
conn: &mut PgConnection,
constraint: &ConstraintDescription,
) -> Result<(), Error> {
let start = Instant::now();
let ConstraintDescription {
name,
table_name,
definition,
} = &constraint;
info!("rebuilding constraint {name}");
sqlx::query(&format!(
"ALTER TABLE {table_name} ADD CONSTRAINT {name} {definition};"
))
@@ -127,13 +134,21 @@ pub async fn restore_constraint(
format!("failed to recreate constraint {name} on {table_name} with {definition}")
})?;
info!(
"constraint {name} rebuilt in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok(())
}
/// Restores (recreates) a index.
///
/// The index must not exist prior to this call.
#[tracing::instrument(name = "syn2mas.restore_index", skip_all, fields(index.name = index.name))]
pub async fn restore_index(conn: &mut PgConnection, index: &IndexDescription) -> Result<(), Error> {
let start = Instant::now();
let IndexDescription {
name,
table_name,
@@ -147,5 +162,10 @@ pub async fn restore_index(conn: &mut PgConnection, index: &IndexDescription) ->
format!("failed to recreate index {name} on {table_name} with {definition}")
})?;
info!(
"index {name} rebuilt in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok(())
}

View File

@@ -7,7 +7,14 @@
//!
//! This module is responsible for writing new records to MAS' database.
use std::{fmt::Display, net::IpAddr};
use std::{
fmt::Display,
net::IpAddr,
sync::{
Arc,
atomic::{AtomicU32, Ordering},
},
};
use chrono::{DateTime, Utc};
use futures_util::{FutureExt, TryStreamExt, future::BoxFuture};
@@ -15,7 +22,7 @@ use sqlx::{Executor, PgConnection, query, query_as};
use thiserror::Error;
use thiserror_ext::{Construct, ContextInto};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tracing::{Level, error, info, warn};
use tracing::{Instrument, Level, error, info, warn};
use uuid::{NonNilUuid, Uuid};
use self::{
@@ -44,6 +51,9 @@ pub enum Error {
#[error("inconsistent database: {0}")]
Inconsistent(String),
#[error("bug in syn2mas: write buffers not finished")]
WriteBuffersNotFinished,
#[error("{0}")]
Multiple(MultipleErrors),
}
@@ -109,18 +119,21 @@ impl WriterConnectionPool {
match self.connection_rx.recv().await {
Some(Ok(mut connection)) => {
let connection_tx = self.connection_tx.clone();
tokio::task::spawn(async move {
let to_return = match task(&mut connection).await {
Ok(()) => Ok(connection),
Err(error) => {
error!("error in writer: {error}");
Err(error)
}
};
// This should always succeed in sending unless we're already shutting
// down for some other reason.
let _: Result<_, _> = connection_tx.send(to_return).await;
});
tokio::task::spawn(
async move {
let to_return = match task(&mut connection).await {
Ok(()) => Ok(connection),
Err(error) => {
error!("error in writer: {error}");
Err(error)
}
};
// This should always succeed in sending unless we're already shutting
// down for some other reason.
let _: Result<_, _> = connection_tx.send(to_return).await;
}
.instrument(tracing::debug_span!("spawn_with_connection")),
);
Ok(())
}
@@ -188,12 +201,52 @@ impl WriterConnectionPool {
}
}
/// Small utility to make sure `finish()` is called on all write buffers
/// before committing to the database.
#[derive(Default)]
struct FinishChecker {
counter: Arc<AtomicU32>,
}
struct FinishCheckerHandle {
counter: Arc<AtomicU32>,
}
impl FinishChecker {
/// Acquire a new handle, for a task that should declare when it has
/// finished.
pub fn handle(&self) -> FinishCheckerHandle {
self.counter.fetch_add(1, Ordering::SeqCst);
FinishCheckerHandle {
counter: Arc::clone(&self.counter),
}
}
/// Check that all handles have been declared as finished.
pub fn check_all_finished(self) -> Result<(), Error> {
if self.counter.load(Ordering::SeqCst) == 0 {
Ok(())
} else {
Err(Error::WriteBuffersNotFinished)
}
}
}
impl FinishCheckerHandle {
/// Declare that the task this handle represents has been finished.
pub fn declare_finished(self) {
self.counter.fetch_sub(1, Ordering::SeqCst);
}
}
pub struct MasWriter {
conn: LockedMasDatabase,
writer_pool: WriterConnectionPool,
indices_to_restore: Vec<IndexDescription>,
constraints_to_restore: Vec<ConstraintDescription>,
write_buffer_finish_checker: FinishChecker,
}
pub struct MasNewUser {
@@ -201,6 +254,7 @@ pub struct MasNewUser {
pub username: String,
pub created_at: DateTime<Utc>,
pub locked_at: Option<DateTime<Utc>>,
pub deactivated_at: Option<DateTime<Utc>>,
pub can_request_admin: bool,
/// Whether the user was a Synapse guest.
/// Although MAS doesn't support guest access, it's still useful to track
@@ -336,7 +390,7 @@ impl MasWriter {
///
/// - If the database connection experiences an error.
#[allow(clippy::missing_panics_doc)] // not real
#[tracing::instrument(skip_all)]
#[tracing::instrument(name = "syn2mas.mas_writer.new", skip_all)]
pub async fn new(
mut conn: LockedMasDatabase,
mut writer_connections: Vec<PgConnection>,
@@ -453,6 +507,7 @@ impl MasWriter {
writer_pool: WriterConnectionPool::new(writer_connections),
indices_to_restore,
constraints_to_restore,
write_buffer_finish_checker: FinishChecker::default(),
})
}
@@ -520,6 +575,8 @@ impl MasWriter {
/// - If the database connection experiences an error.
#[tracing::instrument(skip_all)]
pub async fn finish(mut self) -> Result<PgConnection, Error> {
self.write_buffer_finish_checker.check_all_finished()?;
// Commit all writer transactions to the database.
self.writer_pool
.finish()
@@ -587,6 +644,8 @@ impl MasWriter {
let mut created_ats: Vec<DateTime<Utc>> = Vec::with_capacity(users.len());
let mut locked_ats: Vec<Option<DateTime<Utc>>> =
Vec::with_capacity(users.len());
let mut deactivated_ats: Vec<Option<DateTime<Utc>>> =
Vec::with_capacity(users.len());
let mut can_request_admins: Vec<bool> = Vec::with_capacity(users.len());
let mut is_guests: Vec<bool> = Vec::with_capacity(users.len());
for MasNewUser {
@@ -594,6 +653,7 @@ impl MasWriter {
username,
created_at,
locked_at,
deactivated_at,
can_request_admin,
is_guest,
} in users
@@ -602,6 +662,7 @@ impl MasWriter {
usernames.push(username);
created_ats.push(created_at);
locked_ats.push(locked_at);
deactivated_ats.push(deactivated_at);
can_request_admins.push(can_request_admin);
is_guests.push(is_guest);
}
@@ -611,17 +672,20 @@ impl MasWriter {
INSERT INTO syn2mas__users (
user_id, username,
created_at, locked_at,
deactivated_at,
can_request_admin, is_guest)
SELECT * FROM UNNEST(
$1::UUID[], $2::TEXT[],
$3::TIMESTAMP WITH TIME ZONE[], $4::TIMESTAMP WITH TIME ZONE[],
$5::BOOL[], $6::BOOL[])
$5::TIMESTAMP WITH TIME ZONE[],
$6::BOOL[], $7::BOOL[])
"#,
&user_ids[..],
&usernames[..],
&created_ats[..],
// We need to override the typing for arrays of optionals (sqlx limitation)
&locked_ats[..] as &[Option<DateTime<Utc>>],
&deactivated_ats[..] as &[Option<DateTime<Utc>>],
&can_request_admins[..],
&is_guests[..],
)
@@ -1033,28 +1097,24 @@ type WriteBufferFlusher<T> =
/// A buffer for writing rows to the MAS database.
/// Generic over the type of rows.
///
/// # Panics
///
/// Panics if dropped before `finish()` has been called.
pub struct MasWriteBuffer<T> {
rows: Vec<T>,
flusher: WriteBufferFlusher<T>,
finished: bool,
finish_checker_handle: FinishCheckerHandle,
}
impl<T> MasWriteBuffer<T> {
pub fn new(flusher: WriteBufferFlusher<T>) -> Self {
pub fn new(writer: &MasWriter, flusher: WriteBufferFlusher<T>) -> Self {
MasWriteBuffer {
rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
flusher,
finished: false,
finish_checker_handle: writer.write_buffer_finish_checker.handle(),
}
}
pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
self.finished = true;
self.flush(writer).await?;
self.finish_checker_handle.declare_finished();
Ok(())
}
@@ -1077,12 +1137,6 @@ impl<T> MasWriteBuffer<T> {
}
}
impl<T> Drop for MasWriteBuffer<T> {
fn drop(&mut self) {
assert!(self.finished, "MasWriteBuffer dropped but not finished!");
}
}
#[cfg(test)]
mod test {
use std::collections::{BTreeMap, BTreeSet};
@@ -1217,6 +1271,7 @@ mod test {
username: "alice".to_owned(),
created_at: DateTime::default(),
locked_at: None,
deactivated_at: None,
can_request_admin: false,
is_guest: false,
}])
@@ -1241,6 +1296,7 @@ mod test {
username: "alice".to_owned(),
created_at: DateTime::default(),
locked_at: None,
deactivated_at: None,
can_request_admin: false,
is_guest: false,
}])
@@ -1272,6 +1328,7 @@ mod test {
username: "alice".to_owned(),
created_at: DateTime::default(),
locked_at: None,
deactivated_at: None,
can_request_admin: false,
is_guest: false,
}])
@@ -1305,6 +1362,7 @@ mod test {
username: "alice".to_owned(),
created_at: DateTime::default(),
locked_at: None,
deactivated_at: None,
can_request_admin: false,
is_guest: false,
}])
@@ -1339,6 +1397,7 @@ mod test {
username: "alice".to_owned(),
created_at: DateTime::default(),
locked_at: None,
deactivated_at: None,
can_request_admin: false,
is_guest: false,
}])
@@ -1372,6 +1431,7 @@ mod test {
username: "alice".to_owned(),
created_at: DateTime::default(),
locked_at: None,
deactivated_at: None,
can_request_admin: false,
is_guest: false,
}])
@@ -1409,6 +1469,7 @@ mod test {
username: "alice".to_owned(),
created_at: DateTime::default(),
locked_at: None,
deactivated_at: None,
can_request_admin: false,
is_guest: false,
}])
@@ -1458,6 +1519,7 @@ mod test {
username: "alice".to_owned(),
created_at: DateTime::default(),
locked_at: None,
deactivated_at: None,
can_request_admin: false,
is_guest: false,
}])

View File

@@ -5,6 +5,7 @@ expression: db_snapshot
users:
- can_request_admin: "false"
created_at: "1970-01-01 00:00:00+00"
deactivated_at: ~
is_guest: "false"
locked_at: ~
primary_user_email_id: ~

View File

@@ -23,6 +23,7 @@ compat_sessions:
users:
- can_request_admin: "false"
created_at: "1970-01-01 00:00:00+00"
deactivated_at: ~
is_guest: "false"
locked_at: ~
primary_user_email_id: ~

View File

@@ -17,6 +17,7 @@ compat_sessions:
users:
- can_request_admin: "false"
created_at: "1970-01-01 00:00:00+00"
deactivated_at: ~
is_guest: "false"
locked_at: ~
primary_user_email_id: ~

View File

@@ -11,6 +11,7 @@ user_emails:
users:
- can_request_admin: "false"
created_at: "1970-01-01 00:00:00+00"
deactivated_at: ~
is_guest: "false"
locked_at: ~
primary_user_email_id: ~

View File

@@ -12,6 +12,7 @@ user_passwords:
users:
- can_request_admin: "false"
created_at: "1970-01-01 00:00:00+00"
deactivated_at: ~
is_guest: "false"
locked_at: ~
primary_user_email_id: ~

View File

@@ -30,6 +30,7 @@ compat_sessions:
users:
- can_request_admin: "false"
created_at: "1970-01-01 00:00:00+00"
deactivated_at: ~
is_guest: "false"
locked_at: ~
primary_user_email_id: ~

View File

@@ -10,6 +10,7 @@ user_unsupported_third_party_ids:
users:
- can_request_admin: "false"
created_at: "1970-01-01 00:00:00+00"
deactivated_at: ~
is_guest: "false"
locked_at: ~
primary_user_email_id: ~

View File

@@ -30,12 +30,14 @@ upstream_oauth_providers:
token_endpoint_auth_method: client_secret_basic
token_endpoint_override: ~
token_endpoint_signing_alg: ~
ui_order: "0"
upstream_oauth_provider_id: 00000000-0000-0000-0000-000000000004
userinfo_endpoint_override: ~
userinfo_signed_response_alg: ~
users:
- can_request_admin: "false"
created_at: "1970-01-01 00:00:00+00"
deactivated_at: ~
is_guest: "false"
locked_at: ~
primary_user_email_id: ~

View File

@@ -11,30 +11,45 @@
//! This module does not implement any of the safety checks that should be run
//! *before* the migration.
use std::{collections::HashMap, pin::pin};
use std::{
pin::pin,
sync::{
Arc,
atomic::{AtomicU32, Ordering},
},
time::Instant,
};
use chrono::{DateTime, Utc};
use compact_str::CompactString;
use futures_util::StreamExt as _;
use futures_util::{SinkExt, StreamExt as _, TryFutureExt, TryStreamExt as _};
use mas_storage::Clock;
use rand::RngCore;
use opentelemetry::{KeyValue, metrics::Counter};
use rand::{RngCore, SeedableRng};
use thiserror::Error;
use thiserror_ext::ContextInto;
use tracing::Level;
use tokio_util::sync::PollSender;
use tracing::{Instrument as _, Level, info};
use ulid::Ulid;
use uuid::{NonNilUuid, Uuid};
use crate::{
SynapseReader,
HashMap, RandomState, SynapseReader,
mas_writer::{
self, MasNewCompatAccessToken, MasNewCompatRefreshToken, MasNewCompatSession,
MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
MasNewUserPassword, MasWriteBuffer, MasWriter,
},
progress::{Progress, ProgressStage},
synapse_reader::{
self, ExtractLocalpartError, FullUserId, SynapseAccessToken, SynapseDevice,
SynapseExternalId, SynapseRefreshableTokenPair, SynapseThreepid, SynapseUser,
},
telemetry::{
K_ENTITY, METER, V_ENTITY_DEVICES, V_ENTITY_EXTERNAL_IDS,
V_ENTITY_NONREFRESHABLE_ACCESS_TOKENS, V_ENTITY_REFRESHABLE_TOKEN_PAIRS,
V_ENTITY_THREEPIDS, V_ENTITY_USERS,
},
};
#[derive(Debug, Error, ContextInto)]
@@ -54,6 +69,15 @@ pub enum Error {
source: ExtractLocalpartError,
user: FullUserId,
},
#[error("channel closed")]
ChannelClosed,
#[error("task failed ({context}): {source}")]
Join {
source: tokio::task::JoinError,
context: String,
},
#[error("user {user} was not found for migration but a row in {table} was found for them")]
MissingUserFromDependentTable { table: String, user: FullUserId },
#[error(
@@ -114,7 +138,7 @@ struct MigrationState {
/// A mapping of Synapse external ID providers to MAS upstream OAuth 2.0
/// provider ID
provider_id_mapping: HashMap<String, Uuid>,
provider_id_mapping: std::collections::HashMap<String, Uuid>,
}
/// Performs a migration from Synapse's database to MAS' database.
@@ -129,32 +153,166 @@ struct MigrationState {
///
/// - An underlying database access error, either to MAS or to Synapse.
/// - Invalid data in the Synapse database.
#[allow(clippy::implicit_hasher)]
#[allow(clippy::implicit_hasher, clippy::too_many_lines)]
pub async fn migrate(
mut synapse: SynapseReader<'_>,
mas: MasWriter,
server_name: String,
clock: &dyn Clock,
rng: &mut impl RngCore,
provider_id_mapping: HashMap<String, Uuid>,
provider_id_mapping: std::collections::HashMap<String, Uuid>,
progress: &Progress,
) -> Result<(), Error> {
let counts = synapse.count_rows().await.into_synapse("counting users")?;
let approx_total_counter = METER
.u64_counter("syn2mas.entity.approx_total")
.with_description("Approximate number of entities of this type to be migrated")
.build();
let migrated_otel_counter = METER
.u64_counter("syn2mas.entity.migrated")
.with_description("Number of entities of this type that have been migrated so far")
.build();
approx_total_counter.add(
counts.users as u64,
&[KeyValue::new(K_ENTITY, V_ENTITY_USERS)],
);
approx_total_counter.add(
counts.devices as u64,
&[KeyValue::new(K_ENTITY, V_ENTITY_DEVICES)],
);
approx_total_counter.add(
counts.threepids as u64,
&[KeyValue::new(K_ENTITY, V_ENTITY_THREEPIDS)],
);
approx_total_counter.add(
counts.external_ids as u64,
&[KeyValue::new(K_ENTITY, V_ENTITY_EXTERNAL_IDS)],
);
// assume 1 refreshable access token per refresh token.
let approx_nonrefreshable_access_tokens = counts.access_tokens - counts.refresh_tokens;
approx_total_counter.add(
approx_nonrefreshable_access_tokens as u64,
&[KeyValue::new(
K_ENTITY,
V_ENTITY_NONREFRESHABLE_ACCESS_TOKENS,
)],
);
approx_total_counter.add(
counts.refresh_tokens as u64,
&[KeyValue::new(K_ENTITY, V_ENTITY_REFRESHABLE_TOKEN_PAIRS)],
);
let state = MigrationState {
server_name,
users: HashMap::with_capacity(counts.users),
devices_to_compat_sessions: HashMap::with_capacity(counts.devices),
// We oversize the hashmaps, as the estimates are innaccurate, and we would like to avoid
// reallocations.
users: HashMap::with_capacity_and_hasher(counts.users * 9 / 8, RandomState::default()),
devices_to_compat_sessions: HashMap::with_capacity_and_hasher(
counts.devices * 9 / 8,
RandomState::default(),
),
provider_id_mapping,
};
let (mas, state) = migrate_users(&mut synapse, mas, state, rng).await?;
let (mas, state) = migrate_threepids(&mut synapse, mas, rng, state).await?;
let (mas, state) = migrate_external_ids(&mut synapse, mas, rng, state).await?;
let (mas, state) =
migrate_unrefreshable_access_tokens(&mut synapse, mas, clock, rng, state).await?;
let (mas, state) =
migrate_refreshable_token_pairs(&mut synapse, mas, clock, rng, state).await?;
let (mas, _state) = migrate_devices(&mut synapse, mas, rng, state).await?;
let migrated_counter = Arc::new(AtomicU32::new(0));
progress.set_current_stage(ProgressStage::MigratingData {
entity: V_ENTITY_USERS,
migrated: migrated_counter.clone(),
approx_count: counts.users as u64,
});
let (mas, state) = migrate_users(
&mut synapse,
mas,
state,
rng,
migrated_counter,
migrated_otel_counter.clone(),
)
.await?;
let migrated_counter = Arc::new(AtomicU32::new(0));
progress.set_current_stage(ProgressStage::MigratingData {
entity: V_ENTITY_THREEPIDS,
migrated: migrated_counter.clone(),
approx_count: counts.threepids as u64,
});
let (mas, state) = migrate_threepids(
&mut synapse,
mas,
rng,
state,
&migrated_counter,
migrated_otel_counter.clone(),
)
.await?;
let migrated_counter = Arc::new(AtomicU32::new(0));
progress.set_current_stage(ProgressStage::MigratingData {
entity: V_ENTITY_EXTERNAL_IDS,
migrated: migrated_counter.clone(),
approx_count: counts.external_ids as u64,
});
let (mas, state) = migrate_external_ids(
&mut synapse,
mas,
rng,
state,
&migrated_counter,
migrated_otel_counter.clone(),
)
.await?;
let migrated_counter = Arc::new(AtomicU32::new(0));
progress.set_current_stage(ProgressStage::MigratingData {
entity: V_ENTITY_NONREFRESHABLE_ACCESS_TOKENS,
migrated: migrated_counter.clone(),
approx_count: (counts.access_tokens - counts.refresh_tokens) as u64,
});
let (mas, state) = migrate_unrefreshable_access_tokens(
&mut synapse,
mas,
clock,
rng,
state,
migrated_counter,
migrated_otel_counter.clone(),
)
.await?;
let migrated_counter = Arc::new(AtomicU32::new(0));
progress.set_current_stage(ProgressStage::MigratingData {
entity: V_ENTITY_REFRESHABLE_TOKEN_PAIRS,
migrated: migrated_counter.clone(),
approx_count: counts.refresh_tokens as u64,
});
let (mas, state) = migrate_refreshable_token_pairs(
&mut synapse,
mas,
clock,
rng,
state,
&migrated_counter,
migrated_otel_counter.clone(),
)
.await?;
let migrated_counter = Arc::new(AtomicU32::new(0));
progress.set_current_stage(ProgressStage::MigratingData {
entity: "devices",
migrated: migrated_counter.clone(),
approx_count: counts.devices as u64,
});
let (mas, _state) = migrate_devices(
&mut synapse,
mas,
rng,
state,
migrated_counter,
migrated_otel_counter.clone(),
)
.await?;
synapse
.finish()
@@ -174,83 +332,117 @@ async fn migrate_users(
mut mas: MasWriter,
mut state: MigrationState,
rng: &mut impl RngCore,
progress_counter: Arc<AtomicU32>,
migrated_otel_counter: Counter<u64>,
) -> Result<(MasWriter, MigrationState), Error> {
let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users);
let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords);
let mut users_stream = pin!(synapse.read_users());
let start = Instant::now();
let otel_kv = [KeyValue::new(K_ENTITY, V_ENTITY_USERS)];
while let Some(user_res) = users_stream.next().await {
let user = user_res.into_synapse("reading user")?;
let (tx, mut rx) = tokio::sync::mpsc::channel::<SynapseUser>(10 * 1024 * 1024);
// Handling an edge case: some AS users may have invalid localparts containing
// extra `:` characters. These users are ignored and a warning is logged.
if user.appservice_id.is_some()
&& user
.name
.0
.strip_suffix(&format!(":{}", state.server_name))
.is_some_and(|localpart| localpart.contains(':'))
{
tracing::warn!("AS user {} has invalid localpart, ignoring!", user.name.0);
continue;
}
let mut rng = rand_chacha::ChaCha8Rng::from_rng(rng).expect("failed to seed rng");
let task = tokio::spawn(
async move {
let mut user_buffer = MasWriteBuffer::new(&mas, MasWriter::write_users);
let mut password_buffer = MasWriteBuffer::new(&mas, MasWriter::write_passwords);
let (mas_user, mas_password_opt) = transform_user(&user, &state.server_name, rng)?;
while let Some(user) = rx.recv().await {
// Handling an edge case: some AS users may have invalid localparts containing
// extra `:` characters. These users are ignored and a warning is logged.
if user.appservice_id.is_some()
&& user
.name
.0
.strip_suffix(&format!(":{}", state.server_name))
.is_some_and(|localpart| localpart.contains(':'))
{
tracing::warn!("AS user {} has invalid localpart, ignoring!", user.name.0);
continue;
}
let mut flags = UserFlags::empty();
if bool::from(user.admin) {
flags |= UserFlags::IS_SYNAPSE_ADMIN;
}
if bool::from(user.deactivated) {
flags |= UserFlags::IS_DEACTIVATED;
}
if bool::from(user.is_guest) {
flags |= UserFlags::IS_GUEST;
}
if user.appservice_id.is_some() {
flags |= UserFlags::IS_APPSERVICE;
let (mas_user, mas_password_opt) =
transform_user(&user, &state.server_name, &mut rng)?;
// Special case for appservice users: we don't insert them into the database
// We just record the user's information in the state and continue
state.users.insert(
CompactString::new(&mas_user.username),
UserInfo {
mas_user_id: None,
flags,
},
);
continue;
}
let mut flags = UserFlags::empty();
if bool::from(user.admin) {
flags |= UserFlags::IS_SYNAPSE_ADMIN;
}
if bool::from(user.deactivated) {
flags |= UserFlags::IS_DEACTIVATED;
}
if bool::from(user.is_guest) {
flags |= UserFlags::IS_GUEST;
}
if user.appservice_id.is_some() {
flags |= UserFlags::IS_APPSERVICE;
state.users.insert(
CompactString::new(&mas_user.username),
UserInfo {
mas_user_id: Some(mas_user.user_id),
flags,
},
);
// Special case for appservice users: we don't insert them into the database
// We just record the user's information in the state and continue
state.users.insert(
CompactString::new(&mas_user.username),
UserInfo {
mas_user_id: None,
flags,
},
);
continue;
}
user_buffer
.write(&mut mas, mas_user)
.await
.into_mas("writing user")?;
state.users.insert(
CompactString::new(&mas_user.username),
UserInfo {
mas_user_id: Some(mas_user.user_id),
flags,
},
);
if let Some(mas_password) = mas_password_opt {
password_buffer
.write(&mut mas, mas_password)
user_buffer
.write(&mut mas, mas_user)
.await
.into_mas("writing user")?;
if let Some(mas_password) = mas_password_opt {
password_buffer
.write(&mut mas, mas_password)
.await
.into_mas("writing password")?;
}
migrated_otel_counter.add(1, &otel_kv);
progress_counter.fetch_add(1, Ordering::Relaxed);
}
user_buffer
.finish(&mut mas)
.await
.into_mas("writing password")?;
}
}
.into_mas("writing users")?;
password_buffer
.finish(&mut mas)
.await
.into_mas("writing passwords")?;
user_buffer
.finish(&mut mas)
.await
.into_mas("writing users")?;
password_buffer
.finish(&mut mas)
.await
.into_mas("writing passwords")?;
Ok((mas, state))
}
.instrument(tracing::info_span!("ingest_task")),
);
// In case this has an error, we still want to join the task, so we look at the
// error later
let res = synapse
.read_users()
.map_err(|e| e.into_synapse("reading users"))
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
.inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
.await;
let (mas, state) = task.await.into_join("user write task")??;
res?;
info!(
"users migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -261,9 +453,14 @@ async fn migrate_threepids(
mut mas: MasWriter,
rng: &mut impl RngCore,
state: MigrationState,
progress_counter: &AtomicU32,
migrated_otel_counter: Counter<u64>,
) -> Result<(MasWriter, MigrationState), Error> {
let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids);
let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids);
let start = Instant::now();
let otel_kv = [KeyValue::new(K_ENTITY, V_ENTITY_THREEPIDS)];
let mut email_buffer = MasWriteBuffer::new(&mas, MasWriter::write_email_threepids);
let mut unsupported_buffer = MasWriteBuffer::new(&mas, MasWriter::write_unsupported_threepids);
let mut users_stream = pin!(synapse.read_threepids());
while let Some(threepid_res) = users_stream.next().await {
@@ -320,6 +517,9 @@ async fn migrate_threepids(
.await
.into_mas("writing unsupported threepid")?;
}
migrated_otel_counter.add(1, &otel_kv);
progress_counter.fetch_add(1, Ordering::Relaxed);
}
email_buffer
@@ -331,6 +531,11 @@ async fn migrate_threepids(
.await
.into_mas("writing unsupported threepids")?;
info!(
"third-party IDs migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -344,8 +549,13 @@ async fn migrate_external_ids(
mut mas: MasWriter,
rng: &mut impl RngCore,
state: MigrationState,
progress_counter: &AtomicU32,
migrated_otel_counter: Counter<u64>,
) -> Result<(MasWriter, MigrationState), Error> {
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links);
let start = Instant::now();
let otel_kv = [KeyValue::new(K_ENTITY, V_ENTITY_EXTERNAL_IDS)];
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_upstream_oauth_links);
let mut extids_stream = pin!(synapse.read_user_external_ids());
while let Some(extid_res) = extids_stream.next().await {
@@ -395,12 +605,20 @@ async fn migrate_external_ids(
)
.await
.into_mas("failed to write upstream link")?;
migrated_otel_counter.add(1, &otel_kv);
progress_counter.fetch_add(1, Ordering::Relaxed);
}
write_buffer
.finish(&mut mas)
.await
.into_mas("writing threepids")?;
.into_mas("writing upstream links")?;
info!(
"upstream links (external IDs) migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -419,93 +637,128 @@ async fn migrate_devices(
mut mas: MasWriter,
rng: &mut impl RngCore,
mut state: MigrationState,
progress_counter: Arc<AtomicU32>,
migrated_otel_counter: Counter<u64>,
) -> Result<(MasWriter, MigrationState), Error> {
let mut devices_stream = pin!(synapse.read_devices());
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
let start = Instant::now();
let otel_kv = [KeyValue::new(K_ENTITY, V_ENTITY_DEVICES)];
while let Some(device_res) = devices_stream.next().await {
let SynapseDevice {
user_id: synapse_user_id,
device_id,
display_name,
last_seen,
ip,
user_agent,
} = device_res.into_synapse("reading Synapse device")?;
let (tx, mut rx) = tokio::sync::mpsc::channel(10 * 1024 * 1024);
let username = synapse_user_id
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "devices".to_owned(),
user: synapse_user_id,
});
};
let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
let task = tokio::spawn(
async move {
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
let Some(mas_user_id) = user_infos.mas_user_id else {
continue;
};
while let Some(device) = rx.recv().await {
let SynapseDevice {
user_id: synapse_user_id,
device_id,
display_name,
last_seen,
ip,
user_agent,
} = device;
let username = synapse_user_id
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "devices".to_owned(),
user: synapse_user_id,
});
};
if user_infos.flags.is_deactivated()
|| user_infos.flags.is_guest()
|| user_infos.flags.is_appservice()
{
continue;
}
let Some(mas_user_id) = user_infos.mas_user_id else {
continue;
};
let session_id = *state
.devices_to_compat_sessions
.entry((mas_user_id, CompactString::new(&device_id)))
.or_insert_with(||
if user_infos.flags.is_deactivated()
|| user_infos.flags.is_guest()
|| user_infos.flags.is_appservice()
{
continue;
}
let session_id = *state
.devices_to_compat_sessions
.entry((mas_user_id, CompactString::new(&device_id)))
.or_insert_with(||
// We don't have a creation time for this device (as it has no access token),
// so use now as a least-evil fallback.
Ulid::with_source(rng).into());
let created_at = Ulid::from(session_id).datetime().into();
Ulid::with_source(&mut rng).into());
let created_at = Ulid::from(session_id).datetime().into();
// As we're using a real IP type in the MAS database, it is possible
// that we encounter invalid IP addresses in the Synapse database.
// In that case, we should ignore them, but still log a warning.
// One special case: Synapse will record '-' as IP in some cases, we don't want
// to log about those
let last_active_ip = ip.filter(|ip| ip != "-").and_then(|ip| {
ip.parse()
.map_err(|e| {
tracing::warn!(
error = &e as &dyn std::error::Error,
mxid = %synapse_user_id,
%device_id,
%ip,
"Failed to parse device IP, ignoring"
);
})
.ok()
});
// As we're using a real IP type in the MAS database, it is possible
// that we encounter invalid IP addresses in the Synapse database.
// In that case, we should ignore them, but still log a warning.
// One special case: Synapse will record '-' as IP in some cases, we don't want
// to log about those
let last_active_ip = ip.filter(|ip| ip != "-").and_then(|ip| {
ip.parse()
.map_err(|e| {
tracing::warn!(
error = &e as &dyn std::error::Error,
mxid = %synapse_user_id,
%device_id,
%ip,
"Failed to parse device IP, ignoring"
);
})
.ok()
});
write_buffer
.write(
&mut mas,
MasNewCompatSession {
session_id,
user_id: mas_user_id,
device_id: Some(device_id),
human_name: display_name,
created_at,
is_synapse_admin: user_infos.flags.is_synapse_admin(),
last_active_at: last_seen.map(DateTime::from),
last_active_ip,
user_agent,
},
)
.await
.into_mas("writing compat sessions")?;
}
// TODO skip access tokens for deactivated users
write_buffer
.write(
&mut mas,
MasNewCompatSession {
session_id,
user_id: mas_user_id,
device_id: Some(device_id),
human_name: display_name,
created_at,
is_synapse_admin: user_infos.flags.is_synapse_admin(),
last_active_at: last_seen.map(DateTime::from),
last_active_ip,
user_agent,
},
)
.await
.into_mas("writing compat sessions")?;
write_buffer
.finish(&mut mas)
.await
.into_mas("writing compat sessions")?;
migrated_otel_counter.add(1, &otel_kv);
progress_counter.fetch_add(1, Ordering::Relaxed);
}
write_buffer
.finish(&mut mas)
.await
.into_mas("writing compat sessions")?;
Ok((mas, state))
}
.instrument(tracing::info_span!("ingest_task")),
);
// In case this has an error, we still want to join the task, so we look at the
// error later
let res = synapse
.read_devices()
.map_err(|e| e.into_synapse("reading devices"))
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
.inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
.await;
let (mas, state) = task.await.into_join("device write task")??;
res?;
info!(
"devices migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -519,107 +772,146 @@ async fn migrate_unrefreshable_access_tokens(
clock: &dyn Clock,
rng: &mut impl RngCore,
mut state: MigrationState,
progress_counter: Arc<AtomicU32>,
migrated_otel_counter: Counter<u64>,
) -> Result<(MasWriter, MigrationState), Error> {
let mut token_stream = pin!(synapse.read_unrefreshable_access_tokens());
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
let mut deviceless_session_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
let start = Instant::now();
let otel_kv = [KeyValue::new(
K_ENTITY,
V_ENTITY_NONREFRESHABLE_ACCESS_TOKENS,
)];
while let Some(token_res) = token_stream.next().await {
let SynapseAccessToken {
user_id: synapse_user_id,
device_id,
token,
valid_until_ms,
last_validated,
} = token_res.into_synapse("reading Synapse access token")?;
let (tx, mut rx) = tokio::sync::mpsc::channel(10 * 1024 * 1024);
let username = synapse_user_id
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "access_tokens".to_owned(),
user: synapse_user_id,
});
};
let now = clock.now();
let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
let task = tokio::spawn(
async move {
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
let mut deviceless_session_write_buffer =
MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
let Some(mas_user_id) = user_infos.mas_user_id else {
continue;
};
while let Some(token) = rx.recv().await {
let SynapseAccessToken {
user_id: synapse_user_id,
device_id,
token,
valid_until_ms,
last_validated,
} = token;
let username = synapse_user_id
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "access_tokens".to_owned(),
user: synapse_user_id,
});
};
if user_infos.flags.is_deactivated()
|| user_infos.flags.is_guest()
|| user_infos.flags.is_appservice()
{
continue;
}
let Some(mas_user_id) = user_infos.mas_user_id else {
continue;
};
// It's not always accurate, but last_validated is *often* the creation time of
// the device If we don't have one, then use the current time as a
// fallback.
let created_at = last_validated.map_or_else(|| clock.now(), DateTime::from);
if user_infos.flags.is_deactivated()
|| user_infos.flags.is_guest()
|| user_infos.flags.is_appservice()
{
continue;
}
let session_id = if let Some(device_id) = device_id {
// Use the existing device_id if this is the second token for a device
*state
.devices_to_compat_sessions
.entry((mas_user_id, CompactString::new(&device_id)))
.or_insert_with(|| {
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng))
})
} else {
// If this is a deviceless access token, create a deviceless compat session
// for it (since otherwise we won't create one whilst migrating devices)
let deviceless_session_id =
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
// It's not always accurate, but last_validated is *often* the creation time of
// the device If we don't have one, then use the current time as a
// fallback.
let created_at = last_validated.map_or_else(|| now, DateTime::from);
deviceless_session_write_buffer
.write(
&mut mas,
MasNewCompatSession {
session_id: deviceless_session_id,
user_id: mas_user_id,
device_id: None,
human_name: None,
created_at,
is_synapse_admin: false,
last_active_at: None,
last_active_ip: None,
user_agent: None,
},
)
let session_id = if let Some(device_id) = device_id {
// Use the existing device_id if this is the second token for a device
*state
.devices_to_compat_sessions
.entry((mas_user_id, CompactString::new(&device_id)))
.or_insert_with(|| {
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng))
})
} else {
// If this is a deviceless access token, create a deviceless compat session
// for it (since otherwise we won't create one whilst migrating devices)
let deviceless_session_id =
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
deviceless_session_write_buffer
.write(
&mut mas,
MasNewCompatSession {
session_id: deviceless_session_id,
user_id: mas_user_id,
device_id: None,
human_name: None,
created_at,
is_synapse_admin: false,
last_active_at: None,
last_active_ip: None,
user_agent: None,
},
)
.await
.into_mas("failed to write deviceless compat sessions")?;
deviceless_session_id
};
let token_id =
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
write_buffer
.write(
&mut mas,
MasNewCompatAccessToken {
token_id,
session_id,
access_token: token,
created_at,
expires_at: valid_until_ms.map(DateTime::from),
},
)
.await
.into_mas("writing compat access tokens")?;
migrated_otel_counter.add(1, &otel_kv);
progress_counter.fetch_add(1, Ordering::Relaxed);
}
write_buffer
.finish(&mut mas)
.await
.into_mas("failed to write deviceless compat sessions")?;
.into_mas("writing compat access tokens")?;
deviceless_session_write_buffer
.finish(&mut mas)
.await
.into_mas("writing deviceless compat sessions")?;
deviceless_session_id
};
Ok((mas, state))
}
.instrument(tracing::info_span!("ingest_task")),
);
let token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
// In case this has an error, we still want to join the task, so we look at the
// error later
let res = synapse
.read_unrefreshable_access_tokens()
.map_err(|e| e.into_synapse("reading tokens"))
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
.inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
.await;
write_buffer
.write(
&mut mas,
MasNewCompatAccessToken {
token_id,
session_id,
access_token: token,
created_at,
expires_at: valid_until_ms.map(DateTime::from),
},
)
.await
.into_mas("writing compat access tokens")?;
}
let (mas, state) = task.await.into_join("token write task")??;
write_buffer
.finish(&mut mas)
.await
.into_mas("writing compat access tokens")?;
deviceless_session_write_buffer
.finish(&mut mas)
.await
.into_mas("writing deviceless compat sessions")?;
res?;
info!(
"non-refreshable access tokens migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -633,11 +925,17 @@ async fn migrate_refreshable_token_pairs(
clock: &dyn Clock,
rng: &mut impl RngCore,
mut state: MigrationState,
progress_counter: &AtomicU32,
migrated_otel_counter: Counter<u64>,
) -> Result<(MasWriter, MigrationState), Error> {
let start = Instant::now();
let otel_kv = [KeyValue::new(K_ENTITY, V_ENTITY_REFRESHABLE_TOKEN_PAIRS)];
let mut token_stream = pin!(synapse.read_refreshable_token_pairs());
let mut access_token_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
let mut access_token_write_buffer =
MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
let mut refresh_token_write_buffer =
MasWriteBuffer::new(MasWriter::write_compat_refresh_tokens);
MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens);
while let Some(token_res) = token_stream.next().await {
let SynapseRefreshableTokenPair {
@@ -711,6 +1009,9 @@ async fn migrate_refreshable_token_pairs(
)
.await
.into_mas("writing compat refresh tokens")?;
migrated_otel_counter.add(1, &otel_kv);
progress_counter.fetch_add(1, Ordering::Relaxed);
}
access_token_write_buffer
@@ -723,6 +1024,11 @@ async fn migrate_refreshable_token_pairs(
.await
.into_mas("writing compat refresh tokens")?;
info!(
"refreshable token pairs migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -748,7 +1054,8 @@ fn transform_user(
user_id,
username,
created_at: user.creation_ts.into(),
locked_at: bool::from(user.deactivated).then_some(user.creation_ts.into()),
locked_at: user.locked.then_some(user.creation_ts.into()),
deactivated_at: bool::from(user.deactivated).then_some(user.creation_ts.into()),
can_request_admin: bool::from(user.admin),
is_guest: bool::from(user.is_guest),
};

View File

@@ -0,0 +1,54 @@
use std::sync::{Arc, atomic::AtomicU32};
use arc_swap::ArcSwap;
/// Tracker for the progress of the migration
///
/// Cloning this struct intuitively gives a 'handle' to the same counters,
/// which means it can be shared between tasks/threads.
#[derive(Clone)]
pub struct Progress {
current_stage: Arc<ArcSwap<ProgressStage>>,
}
impl Progress {
/// Sets the current stage of progress.
///
/// This is probably not cheap enough to use for every individual row,
/// so use of atomic integers for the fields that will be updated is
/// recommended.
#[inline]
pub fn set_current_stage(&self, stage: ProgressStage) {
self.current_stage.store(Arc::new(stage));
}
/// Returns the current stage of progress.
#[inline]
#[must_use]
pub fn get_current_stage(&self) -> arc_swap::Guard<Arc<ProgressStage>> {
self.current_stage.load()
}
}
impl Default for Progress {
fn default() -> Self {
Self {
current_stage: Arc::new(ArcSwap::new(Arc::new(ProgressStage::SettingUp))),
}
}
}
pub enum ProgressStage {
SettingUp,
MigratingData {
entity: &'static str,
migrated: Arc<AtomicU32>,
approx_count: u64,
},
RebuildIndex {
index_name: String,
},
RebuildConstraint {
constraint_name: String,
},
}

View File

@@ -48,21 +48,11 @@ pub enum CheckError {
)]
PasswordSchemeWrongPepper,
#[error(
"Synapse database contains {num_guests} guests which aren't supported by MAS. See https://github.com/element-hq/matrix-authentication-service/issues/1445"
)]
GuestsInDatabase { num_guests: i64 },
#[error(
"Guest support is enabled in the Synapse configuration. Guests aren't supported by MAS, but if you don't have any then you could disable the option. See https://github.com/element-hq/matrix-authentication-service/issues/1445"
)]
GuestsEnabled,
#[error(
"Synapse database contains {num_non_email_3pids} non-email 3PIDs (probably phone numbers), which are not supported by MAS."
)]
NonEmailThreepidsInDatabase { num_non_email_3pids: i64 },
#[error(
"Synapse config has `enable_3pid_changes` explicitly enabled, which must be disabled or removed."
)]
@@ -125,6 +115,16 @@ pub enum CheckWarning {
"Synapse config has a registration CAPTCHA enabled, but no CAPTCHA has been configured in MAS. You may wish to manually configure this."
)]
ShouldPortRegistrationCaptcha,
#[error(
"Synapse database contains {num_guests} guests which will be migrated are not supported by MAS. See https://github.com/element-hq/matrix-authentication-service/issues/1445"
)]
GuestsInDatabase { num_guests: i64 },
#[error(
"Synapse database contains {num_non_email_3pids} non-email 3PIDs (probably phone numbers), which will be migrated but are not supported by MAS."
)]
NonEmailThreepidsInDatabase { num_non_email_3pids: i64 },
}
/// Check that the Synapse configuration is sane for migration.
@@ -140,15 +140,6 @@ pub fn synapse_config_check(synapse_config: &Config) -> (Vec<CheckWarning>, Vec<
warnings.push(CheckWarning::DisableUserConsentAfterMigration);
}
// TODO check the settings directly against the MAS settings
for provider in synapse_config.all_oidc_providers().values() {
if let Some(ref issuer) = provider.issuer {
warnings.push(CheckWarning::UpstreamOidcProvider {
issuer: issuer.clone(),
});
}
}
// TODO provide guidance on migrating these
if synapse_config.cas_config.enabled {
warnings.push(CheckWarning::ExternalAuthSystem("CAS"));
@@ -269,13 +260,13 @@ pub async fn synapse_database_check(
}
let mut errors = Vec::new();
let warnings = Vec::new();
let mut warnings = Vec::new();
let num_guests: i64 = query_scalar("SELECT COUNT(1) FROM users WHERE is_guest <> 0")
.fetch_one(&mut *synapse_connection)
.await?;
if num_guests > 0 {
errors.push(CheckError::GuestsInDatabase { num_guests });
warnings.push(CheckWarning::GuestsInDatabase { num_guests });
}
let num_non_email_3pids: i64 =
@@ -283,7 +274,7 @@ pub async fn synapse_database_check(
.fetch_one(&mut *synapse_connection)
.await?;
if num_non_email_3pids > 0 {
errors.push(CheckError::NonEmailThreepidsInDatabase {
warnings.push(CheckWarning::NonEmailThreepidsInDatabase {
num_non_email_3pids,
});
}

View File

@@ -185,6 +185,8 @@ pub struct SynapseUser {
pub admin: SynapseBool,
/// Whether the user is deactivated
pub deactivated: SynapseBool,
/// Whether the user is locked
pub locked: bool,
/// When the user was created
pub creation_ts: SecondsTimestamp,
/// Whether the user is a guest.
@@ -266,6 +268,10 @@ const TABLES_TO_LOCK: &[&str] = &[
pub struct SynapseRowCounts {
pub users: usize,
pub devices: usize,
pub threepids: usize,
pub external_ids: usize,
pub access_tokens: usize,
pub refresh_tokens: usize,
}
pub struct SynapseReader<'c> {
@@ -336,33 +342,91 @@ impl<'conn> SynapseReader<'conn> {
///
/// - An underlying database error
pub async fn count_rows(&mut self) -> Result<SynapseRowCounts, Error> {
let users: usize = sqlx::query_scalar::<_, i64>(
// We don't get to filter out application service users by using this estimate,
// which is a shame, but on a large database this is way faster.
// On matrix.org, counting users and devices properly takes around 1m10s,
// which is unnecessary extra downtime during the migration, just to
// show a more accurate progress bar and size a hash map accurately.
let users = sqlx::query_scalar::<_, i64>(
"
SELECT COUNT(1) FROM users
WHERE appservice_id IS NULL
SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'users'::regclass;
",
)
.fetch_one(&mut *self.txn)
.await
.into_database("counting Synapse users")?
.into_database("estimating count of users")?
.max(0)
.try_into()
.unwrap_or(usize::MAX);
let devices = sqlx::query_scalar::<_, i64>(
"
SELECT COUNT(1) FROM devices
WHERE NOT hidden
SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'devices'::regclass;
",
)
.fetch_one(&mut *self.txn)
.await
.into_database("counting Synapse devices")?
.into_database("estimating count of devices")?
.max(0)
.try_into()
.unwrap_or(usize::MAX);
Ok(SynapseRowCounts { users, devices })
let threepids = sqlx::query_scalar::<_, i64>(
"
SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'user_threepids'::regclass;
"
)
.fetch_one(&mut *self.txn)
.await
.into_database("estimating count of threepids")?
.max(0)
.try_into()
.unwrap_or(usize::MAX);
let access_tokens = sqlx::query_scalar::<_, i64>(
"
SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'access_tokens'::regclass;
"
)
.fetch_one(&mut *self.txn)
.await
.into_database("estimating count of access tokens")?
.max(0)
.try_into()
.unwrap_or(usize::MAX);
let refresh_tokens = sqlx::query_scalar::<_, i64>(
"
SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'refresh_tokens'::regclass;
"
)
.fetch_one(&mut *self.txn)
.await
.into_database("estimating count of refresh tokens")?
.max(0)
.try_into()
.unwrap_or(usize::MAX);
let external_ids = sqlx::query_scalar::<_, i64>(
"
SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'user_external_ids'::regclass;
"
)
.fetch_one(&mut *self.txn)
.await
.into_database("estimating count of external IDs")?
.max(0)
.try_into()
.unwrap_or(usize::MAX);
Ok(SynapseRowCounts {
users,
devices,
threepids,
external_ids,
access_tokens,
refresh_tokens,
})
}
/// Reads Synapse users, excluding application service users (which do not
@@ -371,7 +435,7 @@ impl<'conn> SynapseReader<'conn> {
sqlx::query_as(
"
SELECT
name, password_hash, admin, deactivated, creation_ts, is_guest, appservice_id
name, password_hash, admin, deactivated, locked, creation_ts, is_guest, appservice_id
FROM users
",
)
@@ -427,6 +491,12 @@ impl<'conn> SynapseReader<'conn> {
/// Reads unrefreshable access tokens from the Synapse database.
/// This does not include access tokens used for puppetting users, as those
/// are not supported by MAS.
///
/// This also excludes access tokens whose referenced device ID does not
/// exist, except for deviceless access tokens.
/// (It's unclear what mechanism led to these, but since Synapse has no
/// foreign key constraints and is not consistently atomic about this,
/// it should be no surprise really)
pub fn read_unrefreshable_access_tokens(
&mut self,
) -> impl Stream<Item = Result<SynapseAccessToken, Error>> + '_ {
@@ -435,7 +505,15 @@ impl<'conn> SynapseReader<'conn> {
SELECT
at0.user_id, at0.device_id, at0.token, at0.valid_until_ms, at0.last_validated
FROM access_tokens at0
INNER JOIN devices USING (user_id, device_id)
WHERE at0.puppets_user_id IS NULL AND at0.refresh_token_id IS NULL
UNION ALL
SELECT
at0.user_id, at0.device_id, at0.token, at0.valid_until_ms, at0.last_validated
FROM access_tokens at0
WHERE at0.puppets_user_id IS NULL AND at0.refresh_token_id IS NULL AND at0.device_id IS NULL
",
)
.fetch(&mut *self.txn)
@@ -459,7 +537,8 @@ impl<'conn> SynapseReader<'conn> {
SELECT
rt0.user_id, rt0.device_id, at0.token AS access_token, rt0.token AS refresh_token, at0.valid_until_ms, at0.last_validated
FROM refresh_tokens rt0
LEFT JOIN access_tokens at0 ON at0.refresh_token_id = rt0.id AND at0.user_id = rt0.user_id AND at0.device_id = rt0.device_id
INNER JOIN devices USING (user_id, device_id)
INNER JOIN access_tokens at0 ON at0.refresh_token_id = rt0.id AND at0.user_id = rt0.user_id AND at0.device_id = rt0.device_id
LEFT JOIN access_tokens at1 ON at1.refresh_token_id = rt0.next_token_id
WHERE NOT at1.used OR at1.used IS NULL
",
@@ -485,7 +564,6 @@ mod test {
},
};
// TODO test me
static MIGRATOR: Migrator = sqlx::migrate!("./test_synapse_migrations");
#[sqlx::test(migrator = "MIGRATOR", fixtures("user_alice"))]
@@ -552,7 +630,10 @@ mod test {
assert_debug_snapshot!(devices);
}
#[sqlx::test(migrator = "MIGRATOR", fixtures("user_alice", "access_token_alice"))]
#[sqlx::test(
migrator = "MIGRATOR",
fixtures("user_alice", "devices_alice", "access_token_alice")
)]
async fn test_read_access_token(pool: PgPool) {
let mut conn = pool.acquire().await.expect("failed to get connection");
let mut reader = SynapseReader::new(&mut conn, false)
@@ -571,7 +652,7 @@ mod test {
/// Tests that puppetting access tokens are ignored.
#[sqlx::test(
migrator = "MIGRATOR",
fixtures("user_alice", "access_token_alice_with_puppet")
fixtures("user_alice", "devices_alice", "access_token_alice_with_puppet")
)]
async fn test_read_access_token_puppet(pool: PgPool) {
let mut conn = pool.acquire().await.expect("failed to get connection");
@@ -590,7 +671,7 @@ mod test {
#[sqlx::test(
migrator = "MIGRATOR",
fixtures("user_alice", "access_token_alice_with_refresh_token")
fixtures("user_alice", "devices_alice", "access_token_alice_with_refresh_token")
)]
async fn test_read_access_and_refresh_tokens(pool: PgPool) {
let mut conn = pool.acquire().await.expect("failed to get connection");
@@ -619,7 +700,11 @@ mod test {
#[sqlx::test(
migrator = "MIGRATOR",
fixtures("user_alice", "access_token_alice_with_unused_refresh_token")
fixtures(
"user_alice",
"devices_alice",
"access_token_alice_with_unused_refresh_token"
)
)]
async fn test_read_access_and_unused_refresh_tokens(pool: PgPool) {
let mut conn = pool.acquire().await.expect("failed to get connection");

View File

@@ -16,6 +16,7 @@ expression: users
deactivated: SynapseBool(
false,
),
locked: false,
creation_ts: SecondsTimestamp(
2018-06-30T21:26:02Z,
),

View File

@@ -0,0 +1,32 @@
use std::sync::LazyLock;
use opentelemetry::{InstrumentationScope, metrics::Meter};
use opentelemetry_semantic_conventions as semcov;
static SCOPE: LazyLock<InstrumentationScope> = LazyLock::new(|| {
InstrumentationScope::builder(env!("CARGO_PKG_NAME"))
.with_version(env!("CARGO_PKG_VERSION"))
.with_schema_url(semcov::SCHEMA_URL)
.build()
});
pub static METER: LazyLock<Meter> =
LazyLock::new(|| opentelemetry::global::meter_with_scope(SCOPE.clone()));
/// Attribute key for syn2mas.entity metrics representing what entity.
pub const K_ENTITY: &str = "entity";
/// Attribute value for syn2mas.entity metrics representing users.
pub const V_ENTITY_USERS: &str = "users";
/// Attribute value for syn2mas.entity metrics representing devices.
pub const V_ENTITY_DEVICES: &str = "devices";
/// Attribute value for syn2mas.entity metrics representing threepids.
pub const V_ENTITY_THREEPIDS: &str = "threepids";
/// Attribute value for syn2mas.entity metrics representing external IDs.
pub const V_ENTITY_EXTERNAL_IDS: &str = "external_ids";
/// Attribute value for syn2mas.entity metrics representing non-refreshable
/// access token entities.
pub const V_ENTITY_NONREFRESHABLE_ACCESS_TOKENS: &str = "nonrefreshable_access_tokens";
/// Attribute value for syn2mas.entity metrics representing refreshable
/// access/refresh token pairs.
pub const V_ENTITY_REFRESHABLE_TOKEN_PAIRS: &str = "refreshable_token_pairs";

View File

@@ -172,14 +172,15 @@ const MAX_CONCURRENT_JOBS: usize = 10;
const MAX_JOBS_TO_FETCH: usize = 5;
// How many attempts a job should be retried
const MAX_ATTEMPTS: usize = 5;
const MAX_ATTEMPTS: usize = 10;
/// Returns the delay to wait before retrying a job
///
/// Uses an exponential backoff: 1s, 2s, 4s, 8s, 16s
/// Uses an exponential backoff: 5s, 10s, 20s, 40s, 1m20s, 2m40s, 5m20s, 10m50s,
/// 21m40s, 43m20s
fn retry_delay(attempt: usize) -> Duration {
let attempt = u32::try_from(attempt).unwrap_or(u32::MAX);
Duration::milliseconds(2_i64.saturating_pow(attempt) * 1000)
Duration::milliseconds(2_i64.saturating_pow(attempt) * 5_000)
}
type JobResult = Result<(), JobError>;

View File

@@ -11,7 +11,7 @@ use mas_storage::{
compat::CompatSessionFilter,
oauth2::OAuth2SessionFilter,
queue::{DeactivateUserJob, ReactivateUserJob},
user::{BrowserSessionFilter, UserRepository},
user::{BrowserSessionFilter, UserEmailFilter, UserRepository},
};
use tracing::info;
@@ -42,7 +42,7 @@ impl RunnableJob for DeactivateUserJob {
.context("User not found")
.map_err(JobError::fail)?;
// Let's first lock the user
// Let's first lock & deactivate the user
let user = repo
.user()
.lock(&clock, user)
@@ -50,6 +50,13 @@ impl RunnableJob for DeactivateUserJob {
.context("Failed to lock user")
.map_err(JobError::retry)?;
let user = repo
.user()
.deactivate(&clock, user)
.await
.context("Failed to deactivate user")
.map_err(JobError::retry)?;
// Kill all sessions for the user
let n = repo
.browser_session()
@@ -81,6 +88,14 @@ impl RunnableJob for DeactivateUserJob {
.map_err(JobError::retry)?;
info!(affected = n, "Killed all compatibility sessions for user");
// Delete all the email addresses for the user
let n = repo
.user_email()
.remove_bulk(UserEmailFilter::new().for_user(&user))
.await
.map_err(JobError::retry)?;
info!(affected = n, "Removed all email addresses for user");
// Before calling back to the homeserver, commit the changes to the database, as
// we want the user to be locked out as soon as possible
repo.save().await.map_err(JobError::retry)?;

Some files were not shown because too many files have changed in this diff Show More