syn2mas: better performance, output tweaks, tracing tweaks, access token fixes (#4175)

This commit is contained in:
reivilibre
2025-03-12 10:56:56 +00:00
committed by GitHub
10 changed files with 535 additions and 313 deletions

3
Cargo.lock generated
View File

@@ -6108,11 +6108,14 @@ dependencies = [
"mas-storage",
"mas-storage-pg",
"rand",
"rand_chacha",
"rustc-hash 2.1.1",
"serde",
"sqlx",
"thiserror 2.0.11",
"thiserror-ext",
"tokio",
"tokio-util",
"tracing",
"ulid",
"uuid",

View File

@@ -80,6 +80,7 @@ enum Subcommand {
const NUM_WRITER_CONNECTIONS: usize = 8;
impl Options {
#[tracing::instrument("cli.syn2mas.run", skip_all)]
#[allow(clippy::too_many_lines)]
pub async fn run(self, figment: &Figment) -> anyhow::Result<ExitCode> {
warn!(
@@ -173,14 +174,14 @@ impl Options {
// Display errors and warnings
if !check_errors.is_empty() {
eprintln!("===== Errors =====");
eprintln!("\n\n===== Errors =====");
eprintln!("These issues prevent migrating from Synapse to MAS right now:\n");
for error in &check_errors {
eprintln!("{error}\n");
}
}
if !check_warnings.is_empty() {
eprintln!("===== Warnings =====");
eprintln!("\n\n===== Warnings =====");
eprintln!(
"These potential issues should be considered before migrating from Synapse to MAS right now:\n"
);
@@ -220,6 +221,7 @@ impl Options {
// TODO how should we handle warnings at this stage?
// TODO this dry-run flag should be set to false in real circumstances !!!
let reader = SynapseReader::new(&mut syn_conn, true).await?;
let mut writer_mas_connections = Vec::with_capacity(NUM_WRITER_CONNECTIONS);
for _ in 0..NUM_WRITER_CONNECTIONS {
@@ -234,6 +236,7 @@ impl Options {
// TODO progress reporting
let mas_matrix = MatrixConfig::extract(figment)?;
eprintln!("\n\n");
syn2mas::migrate(
reader,
writer,

View File

@@ -18,13 +18,16 @@ serde.workspace = true
thiserror.workspace = true
thiserror-ext.workspace = true
tokio.workspace = true
tokio-util.workspace = true
sqlx.workspace = true
chrono.workspace = true
compact_str.workspace = true
tracing.workspace = true
futures-util = "0.3.31"
rustc-hash = "2.1.1"
rand.workspace = true
rand_chacha = "0.3.1"
uuid = "1.15.1"
ulid = { workspace = true, features = ["uuid"] }

View File

@@ -8,6 +8,9 @@ mod synapse_reader;
mod migration;
type RandomState = rustc_hash::FxBuildHasher;
type HashMap<K, V> = rustc_hash::FxHashMap<K, V>;
pub use self::{
mas_writer::{MasWriter, checks::mas_pre_migration_checks, locking::LockedMasDatabase},
migration::migrate,

View File

@@ -10,6 +10,7 @@
use thiserror::Error;
use thiserror_ext::ContextInto;
use tracing::Instrument as _;
use super::{MAS_TABLES_AFFECTED_BY_MIGRATION, is_syn2mas_in_progress, locking::LockedMasDatabase};
@@ -46,7 +47,7 @@ pub enum Error {
/// - If any MAS tables involved in the migration are not empty.
/// - If we can't check whether syn2mas is already in progress on this database
/// or not.
#[tracing::instrument(skip_all)]
#[tracing::instrument(name = "syn2mas.mas_pre_migration_checks", skip_all)]
pub async fn mas_pre_migration_checks(mas_connection: &mut LockedMasDatabase) -> Result<(), Error> {
if is_syn2mas_in_progress(mas_connection.as_mut())
.await
@@ -60,8 +61,11 @@ pub async fn mas_pre_migration_checks(mas_connection: &mut LockedMasDatabase) ->
// empty database.
for &table in MAS_TABLES_AFFECTED_BY_MIGRATION {
let row_present = sqlx::query(&format!("SELECT 1 AS dummy FROM {table} LIMIT 1"))
let query = format!("SELECT 1 AS dummy FROM {table} LIMIT 1");
let span = tracing::info_span!("db.query", db.query.text = query);
let row_present = sqlx::query(&query)
.fetch_optional(mas_connection.as_mut())
.instrument(span)
.await
.into_maybe_not_mas(table)?
.is_some();

View File

@@ -3,8 +3,10 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::time::Instant;
use sqlx::PgConnection;
use tracing::debug;
use tracing::{debug, info};
use super::{Error, IntoDatabase};
@@ -109,15 +111,20 @@ pub async fn drop_index(conn: &mut PgConnection, index: &IndexDescription) -> Re
/// Restores (recreates) a constraint.
///
/// The constraint must not exist prior to this call.
#[tracing::instrument(name = "syn2mas.restore_constraint", skip_all, fields(constraint.name = constraint.name))]
pub async fn restore_constraint(
conn: &mut PgConnection,
constraint: &ConstraintDescription,
) -> Result<(), Error> {
let start = Instant::now();
let ConstraintDescription {
name,
table_name,
definition,
} = &constraint;
info!("rebuilding constraint {name}");
sqlx::query(&format!(
"ALTER TABLE {table_name} ADD CONSTRAINT {name} {definition};"
))
@@ -127,13 +134,21 @@ pub async fn restore_constraint(
format!("failed to recreate constraint {name} on {table_name} with {definition}")
})?;
info!(
"constraint {name} rebuilt in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok(())
}
/// Restores (recreates) a index.
///
/// The index must not exist prior to this call.
#[tracing::instrument(name = "syn2mas.restore_index", skip_all, fields(index.name = index.name))]
pub async fn restore_index(conn: &mut PgConnection, index: &IndexDescription) -> Result<(), Error> {
let start = Instant::now();
let IndexDescription {
name,
table_name,
@@ -147,5 +162,10 @@ pub async fn restore_index(conn: &mut PgConnection, index: &IndexDescription) ->
format!("failed to recreate index {name} on {table_name} with {definition}")
})?;
info!(
"index {name} rebuilt in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok(())
}

View File

@@ -7,7 +7,14 @@
//!
//! This module is responsible for writing new records to MAS' database.
use std::{fmt::Display, net::IpAddr};
use std::{
fmt::Display,
net::IpAddr,
sync::{
Arc,
atomic::{AtomicU32, Ordering},
},
};
use chrono::{DateTime, Utc};
use futures_util::{FutureExt, TryStreamExt, future::BoxFuture};
@@ -15,7 +22,7 @@ use sqlx::{Executor, PgConnection, query, query_as};
use thiserror::Error;
use thiserror_ext::{Construct, ContextInto};
use tokio::sync::mpsc::{self, Receiver, Sender};
use tracing::{Level, error, info, warn};
use tracing::{Instrument, Level, error, info, warn};
use uuid::{NonNilUuid, Uuid};
use self::{
@@ -44,6 +51,9 @@ pub enum Error {
#[error("inconsistent database: {0}")]
Inconsistent(String),
#[error("bug in syn2mas: write buffers not finished")]
WriteBuffersNotFinished,
#[error("{0}")]
Multiple(MultipleErrors),
}
@@ -109,18 +119,21 @@ impl WriterConnectionPool {
match self.connection_rx.recv().await {
Some(Ok(mut connection)) => {
let connection_tx = self.connection_tx.clone();
tokio::task::spawn(async move {
let to_return = match task(&mut connection).await {
Ok(()) => Ok(connection),
Err(error) => {
error!("error in writer: {error}");
Err(error)
}
};
// This should always succeed in sending unless we're already shutting
// down for some other reason.
let _: Result<_, _> = connection_tx.send(to_return).await;
});
tokio::task::spawn(
async move {
let to_return = match task(&mut connection).await {
Ok(()) => Ok(connection),
Err(error) => {
error!("error in writer: {error}");
Err(error)
}
};
// This should always succeed in sending unless we're already shutting
// down for some other reason.
let _: Result<_, _> = connection_tx.send(to_return).await;
}
.instrument(tracing::debug_span!("spawn_with_connection")),
);
Ok(())
}
@@ -188,12 +201,52 @@ impl WriterConnectionPool {
}
}
/// Small utility to make sure `finish()` is called on all write buffers
/// before committing to the database.
#[derive(Default)]
struct FinishChecker {
counter: Arc<AtomicU32>,
}
struct FinishCheckerHandle {
counter: Arc<AtomicU32>,
}
impl FinishChecker {
/// Acquire a new handle, for a task that should declare when it has
/// finished.
pub fn handle(&self) -> FinishCheckerHandle {
self.counter.fetch_add(1, Ordering::SeqCst);
FinishCheckerHandle {
counter: Arc::clone(&self.counter),
}
}
/// Check that all handles have been declared as finished.
pub fn check_all_finished(self) -> Result<(), Error> {
if self.counter.load(Ordering::SeqCst) == 0 {
Ok(())
} else {
Err(Error::WriteBuffersNotFinished)
}
}
}
impl FinishCheckerHandle {
/// Declare that the task this handle represents has been finished.
pub fn declare_finished(self) {
self.counter.fetch_sub(1, Ordering::SeqCst);
}
}
pub struct MasWriter {
conn: LockedMasDatabase,
writer_pool: WriterConnectionPool,
indices_to_restore: Vec<IndexDescription>,
constraints_to_restore: Vec<ConstraintDescription>,
write_buffer_finish_checker: FinishChecker,
}
pub struct MasNewUser {
@@ -337,7 +390,7 @@ impl MasWriter {
///
/// - If the database connection experiences an error.
#[allow(clippy::missing_panics_doc)] // not real
#[tracing::instrument(skip_all)]
#[tracing::instrument(name = "syn2mas.mas_writer.new", skip_all)]
pub async fn new(
mut conn: LockedMasDatabase,
mut writer_connections: Vec<PgConnection>,
@@ -454,6 +507,7 @@ impl MasWriter {
writer_pool: WriterConnectionPool::new(writer_connections),
indices_to_restore,
constraints_to_restore,
write_buffer_finish_checker: FinishChecker::default(),
})
}
@@ -521,6 +575,8 @@ impl MasWriter {
/// - If the database connection experiences an error.
#[tracing::instrument(skip_all)]
pub async fn finish(mut self) -> Result<PgConnection, Error> {
self.write_buffer_finish_checker.check_all_finished()?;
// Commit all writer transactions to the database.
self.writer_pool
.finish()
@@ -1041,28 +1097,24 @@ type WriteBufferFlusher<T> =
/// A buffer for writing rows to the MAS database.
/// Generic over the type of rows.
///
/// # Panics
///
/// Panics if dropped before `finish()` has been called.
pub struct MasWriteBuffer<T> {
rows: Vec<T>,
flusher: WriteBufferFlusher<T>,
finished: bool,
finish_checker_handle: FinishCheckerHandle,
}
impl<T> MasWriteBuffer<T> {
pub fn new(flusher: WriteBufferFlusher<T>) -> Self {
pub fn new(writer: &MasWriter, flusher: WriteBufferFlusher<T>) -> Self {
MasWriteBuffer {
rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE),
flusher,
finished: false,
finish_checker_handle: writer.write_buffer_finish_checker.handle(),
}
}
pub async fn finish(mut self, writer: &mut MasWriter) -> Result<(), Error> {
self.finished = true;
self.flush(writer).await?;
self.finish_checker_handle.declare_finished();
Ok(())
}
@@ -1085,12 +1137,6 @@ impl<T> MasWriteBuffer<T> {
}
}
impl<T> Drop for MasWriteBuffer<T> {
fn drop(&mut self) {
assert!(self.finished, "MasWriteBuffer dropped but not finished!");
}
}
#[cfg(test)]
mod test {
use std::collections::{BTreeMap, BTreeSet};

View File

@@ -11,21 +11,22 @@
//! This module does not implement any of the safety checks that should be run
//! *before* the migration.
use std::{collections::HashMap, pin::pin};
use std::{pin::pin, time::Instant};
use chrono::{DateTime, Utc};
use compact_str::CompactString;
use futures_util::StreamExt as _;
use futures_util::{SinkExt, StreamExt as _, TryFutureExt, TryStreamExt as _};
use mas_storage::Clock;
use rand::RngCore;
use rand::{RngCore, SeedableRng};
use thiserror::Error;
use thiserror_ext::ContextInto;
use tracing::Level;
use tokio_util::sync::PollSender;
use tracing::{Instrument as _, Level, info};
use ulid::Ulid;
use uuid::{NonNilUuid, Uuid};
use crate::{
SynapseReader,
HashMap, RandomState, SynapseReader,
mas_writer::{
self, MasNewCompatAccessToken, MasNewCompatRefreshToken, MasNewCompatSession,
MasNewEmailThreepid, MasNewUnsupportedThreepid, MasNewUpstreamOauthLink, MasNewUser,
@@ -54,6 +55,15 @@ pub enum Error {
source: ExtractLocalpartError,
user: FullUserId,
},
#[error("channel closed")]
ChannelClosed,
#[error("task failed ({context}): {source}")]
Join {
source: tokio::task::JoinError,
context: String,
},
#[error("user {user} was not found for migration but a row in {table} was found for them")]
MissingUserFromDependentTable { table: String, user: FullUserId },
#[error(
@@ -114,7 +124,7 @@ struct MigrationState {
/// A mapping of Synapse external ID providers to MAS upstream OAuth 2.0
/// provider ID
provider_id_mapping: HashMap<String, Uuid>,
provider_id_mapping: std::collections::HashMap<String, Uuid>,
}
/// Performs a migration from Synapse's database to MAS' database.
@@ -136,14 +146,19 @@ pub async fn migrate(
server_name: String,
clock: &dyn Clock,
rng: &mut impl RngCore,
provider_id_mapping: HashMap<String, Uuid>,
provider_id_mapping: std::collections::HashMap<String, Uuid>,
) -> Result<(), Error> {
let counts = synapse.count_rows().await.into_synapse("counting users")?;
let state = MigrationState {
server_name,
users: HashMap::with_capacity(counts.users),
devices_to_compat_sessions: HashMap::with_capacity(counts.devices),
// We oversize the hashmaps, as the estimates are innaccurate, and we would like to avoid
// reallocations.
users: HashMap::with_capacity_and_hasher(counts.users * 9 / 8, RandomState::default()),
devices_to_compat_sessions: HashMap::with_capacity_and_hasher(
counts.devices * 9 / 8,
RandomState::default(),
),
provider_id_mapping,
};
@@ -175,82 +190,110 @@ async fn migrate_users(
mut state: MigrationState,
rng: &mut impl RngCore,
) -> Result<(MasWriter, MigrationState), Error> {
let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users);
let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords);
let mut users_stream = pin!(synapse.read_users());
let start = Instant::now();
while let Some(user_res) = users_stream.next().await {
let user = user_res.into_synapse("reading user")?;
let (tx, mut rx) = tokio::sync::mpsc::channel::<SynapseUser>(10 * 1024 * 1024);
// Handling an edge case: some AS users may have invalid localparts containing
// extra `:` characters. These users are ignored and a warning is logged.
if user.appservice_id.is_some()
&& user
.name
.0
.strip_suffix(&format!(":{}", state.server_name))
.is_some_and(|localpart| localpart.contains(':'))
{
tracing::warn!("AS user {} has invalid localpart, ignoring!", user.name.0);
continue;
}
let mut rng = rand_chacha::ChaCha8Rng::from_rng(rng).expect("failed to seed rng");
let task = tokio::spawn(
async move {
let mut user_buffer = MasWriteBuffer::new(&mas, MasWriter::write_users);
let mut password_buffer = MasWriteBuffer::new(&mas, MasWriter::write_passwords);
let (mas_user, mas_password_opt) = transform_user(&user, &state.server_name, rng)?;
while let Some(user) = rx.recv().await {
// Handling an edge case: some AS users may have invalid localparts containing
// extra `:` characters. These users are ignored and a warning is logged.
if user.appservice_id.is_some()
&& user
.name
.0
.strip_suffix(&format!(":{}", state.server_name))
.is_some_and(|localpart| localpart.contains(':'))
{
tracing::warn!("AS user {} has invalid localpart, ignoring!", user.name.0);
continue;
}
let mut flags = UserFlags::empty();
if bool::from(user.admin) {
flags |= UserFlags::IS_SYNAPSE_ADMIN;
}
if bool::from(user.deactivated) {
flags |= UserFlags::IS_DEACTIVATED;
}
if bool::from(user.is_guest) {
flags |= UserFlags::IS_GUEST;
}
if user.appservice_id.is_some() {
flags |= UserFlags::IS_APPSERVICE;
let (mas_user, mas_password_opt) =
transform_user(&user, &state.server_name, &mut rng)?;
// Special case for appservice users: we don't insert them into the database
// We just record the user's information in the state and continue
state.users.insert(
CompactString::new(&mas_user.username),
UserInfo {
mas_user_id: None,
flags,
},
);
continue;
}
let mut flags = UserFlags::empty();
if bool::from(user.admin) {
flags |= UserFlags::IS_SYNAPSE_ADMIN;
}
if bool::from(user.deactivated) {
flags |= UserFlags::IS_DEACTIVATED;
}
if bool::from(user.is_guest) {
flags |= UserFlags::IS_GUEST;
}
if user.appservice_id.is_some() {
flags |= UserFlags::IS_APPSERVICE;
state.users.insert(
CompactString::new(&mas_user.username),
UserInfo {
mas_user_id: Some(mas_user.user_id),
flags,
},
);
// Special case for appservice users: we don't insert them into the database
// We just record the user's information in the state and continue
state.users.insert(
CompactString::new(&mas_user.username),
UserInfo {
mas_user_id: None,
flags,
},
);
continue;
}
user_buffer
.write(&mut mas, mas_user)
.await
.into_mas("writing user")?;
state.users.insert(
CompactString::new(&mas_user.username),
UserInfo {
mas_user_id: Some(mas_user.user_id),
flags,
},
);
if let Some(mas_password) = mas_password_opt {
password_buffer
.write(&mut mas, mas_password)
user_buffer
.write(&mut mas, mas_user)
.await
.into_mas("writing user")?;
if let Some(mas_password) = mas_password_opt {
password_buffer
.write(&mut mas, mas_password)
.await
.into_mas("writing password")?;
}
}
user_buffer
.finish(&mut mas)
.await
.into_mas("writing password")?;
}
}
.into_mas("writing users")?;
password_buffer
.finish(&mut mas)
.await
.into_mas("writing passwords")?;
user_buffer
.finish(&mut mas)
.await
.into_mas("writing users")?;
password_buffer
.finish(&mut mas)
.await
.into_mas("writing passwords")?;
Ok((mas, state))
}
.instrument(tracing::info_span!("ingest_task")),
);
// In case this has an error, we still want to join the task, so we look at the
// error later
let res = synapse
.read_users()
.map_err(|e| e.into_synapse("reading users"))
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
.inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
.await;
let (mas, state) = task.await.into_join("user write task")??;
res?;
info!(
"users migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -262,8 +305,10 @@ async fn migrate_threepids(
rng: &mut impl RngCore,
state: MigrationState,
) -> Result<(MasWriter, MigrationState), Error> {
let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids);
let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids);
let start = Instant::now();
let mut email_buffer = MasWriteBuffer::new(&mas, MasWriter::write_email_threepids);
let mut unsupported_buffer = MasWriteBuffer::new(&mas, MasWriter::write_unsupported_threepids);
let mut users_stream = pin!(synapse.read_threepids());
while let Some(threepid_res) = users_stream.next().await {
@@ -331,6 +376,11 @@ async fn migrate_threepids(
.await
.into_mas("writing unsupported threepids")?;
info!(
"third-party IDs migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -345,7 +395,9 @@ async fn migrate_external_ids(
rng: &mut impl RngCore,
state: MigrationState,
) -> Result<(MasWriter, MigrationState), Error> {
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links);
let start = Instant::now();
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_upstream_oauth_links);
let mut extids_stream = pin!(synapse.read_user_external_ids());
while let Some(extid_res) = extids_stream.next().await {
@@ -400,7 +452,12 @@ async fn migrate_external_ids(
write_buffer
.finish(&mut mas)
.await
.into_mas("writing threepids")?;
.into_mas("writing upstream links")?;
info!(
"upstream links (external IDs) migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -420,92 +477,121 @@ async fn migrate_devices(
rng: &mut impl RngCore,
mut state: MigrationState,
) -> Result<(MasWriter, MigrationState), Error> {
let mut devices_stream = pin!(synapse.read_devices());
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
let start = Instant::now();
while let Some(device_res) = devices_stream.next().await {
let SynapseDevice {
user_id: synapse_user_id,
device_id,
display_name,
last_seen,
ip,
user_agent,
} = device_res.into_synapse("reading Synapse device")?;
let (tx, mut rx) = tokio::sync::mpsc::channel(10 * 1024 * 1024);
let username = synapse_user_id
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "devices".to_owned(),
user: synapse_user_id,
});
};
let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
let task = tokio::spawn(
async move {
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
let Some(mas_user_id) = user_infos.mas_user_id else {
continue;
};
while let Some(device) = rx.recv().await {
let SynapseDevice {
user_id: synapse_user_id,
device_id,
display_name,
last_seen,
ip,
user_agent,
} = device;
let username = synapse_user_id
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "devices".to_owned(),
user: synapse_user_id,
});
};
if user_infos.flags.is_deactivated()
|| user_infos.flags.is_guest()
|| user_infos.flags.is_appservice()
{
continue;
}
let Some(mas_user_id) = user_infos.mas_user_id else {
continue;
};
let session_id = *state
.devices_to_compat_sessions
.entry((mas_user_id, CompactString::new(&device_id)))
.or_insert_with(||
if user_infos.flags.is_deactivated()
|| user_infos.flags.is_guest()
|| user_infos.flags.is_appservice()
{
continue;
}
let session_id = *state
.devices_to_compat_sessions
.entry((mas_user_id, CompactString::new(&device_id)))
.or_insert_with(||
// We don't have a creation time for this device (as it has no access token),
// so use now as a least-evil fallback.
Ulid::with_source(rng).into());
let created_at = Ulid::from(session_id).datetime().into();
Ulid::with_source(&mut rng).into());
let created_at = Ulid::from(session_id).datetime().into();
// As we're using a real IP type in the MAS database, it is possible
// that we encounter invalid IP addresses in the Synapse database.
// In that case, we should ignore them, but still log a warning.
// One special case: Synapse will record '-' as IP in some cases, we don't want
// to log about those
let last_active_ip = ip.filter(|ip| ip != "-").and_then(|ip| {
ip.parse()
.map_err(|e| {
tracing::warn!(
error = &e as &dyn std::error::Error,
mxid = %synapse_user_id,
%device_id,
%ip,
"Failed to parse device IP, ignoring"
);
})
.ok()
});
// As we're using a real IP type in the MAS database, it is possible
// that we encounter invalid IP addresses in the Synapse database.
// In that case, we should ignore them, but still log a warning.
// One special case: Synapse will record '-' as IP in some cases, we don't want
// to log about those
let last_active_ip = ip.filter(|ip| ip != "-").and_then(|ip| {
ip.parse()
.map_err(|e| {
tracing::warn!(
error = &e as &dyn std::error::Error,
mxid = %synapse_user_id,
%device_id,
%ip,
"Failed to parse device IP, ignoring"
);
})
.ok()
});
write_buffer
.write(
&mut mas,
MasNewCompatSession {
session_id,
user_id: mas_user_id,
device_id: Some(device_id),
human_name: display_name,
created_at,
is_synapse_admin: user_infos.flags.is_synapse_admin(),
last_active_at: last_seen.map(DateTime::from),
last_active_ip,
user_agent,
},
)
.await
.into_mas("writing compat sessions")?;
}
// TODO skip access tokens for deactivated users
write_buffer
.write(
&mut mas,
MasNewCompatSession {
session_id,
user_id: mas_user_id,
device_id: Some(device_id),
human_name: display_name,
created_at,
is_synapse_admin: user_infos.flags.is_synapse_admin(),
last_active_at: last_seen.map(DateTime::from),
last_active_ip,
user_agent,
},
)
.await
.into_mas("writing compat sessions")?;
}
write_buffer
.finish(&mut mas)
.await
.into_mas("writing compat sessions")?;
write_buffer
.finish(&mut mas)
.await
.into_mas("writing compat sessions")?;
Ok((mas, state))
}
.instrument(tracing::info_span!("ingest_task")),
);
// In case this has an error, we still want to join the task, so we look at the
// error later
let res = synapse
.read_devices()
.map_err(|e| e.into_synapse("reading devices"))
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
.inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
.await;
let (mas, state) = task.await.into_join("device write task")??;
res?;
info!(
"devices migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -520,106 +606,136 @@ async fn migrate_unrefreshable_access_tokens(
rng: &mut impl RngCore,
mut state: MigrationState,
) -> Result<(MasWriter, MigrationState), Error> {
let mut token_stream = pin!(synapse.read_unrefreshable_access_tokens());
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
let mut deviceless_session_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
let start = Instant::now();
while let Some(token_res) = token_stream.next().await {
let SynapseAccessToken {
user_id: synapse_user_id,
device_id,
token,
valid_until_ms,
last_validated,
} = token_res.into_synapse("reading Synapse access token")?;
let (tx, mut rx) = tokio::sync::mpsc::channel(10 * 1024 * 1024);
let username = synapse_user_id
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "access_tokens".to_owned(),
user: synapse_user_id,
});
};
let now = clock.now();
let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
let task = tokio::spawn(
async move {
let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
let mut deviceless_session_write_buffer =
MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions);
let Some(mas_user_id) = user_infos.mas_user_id else {
continue;
};
while let Some(token) = rx.recv().await {
let SynapseAccessToken {
user_id: synapse_user_id,
device_id,
token,
valid_until_ms,
last_validated,
} = token;
let username = synapse_user_id
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "access_tokens".to_owned(),
user: synapse_user_id,
});
};
if user_infos.flags.is_deactivated()
|| user_infos.flags.is_guest()
|| user_infos.flags.is_appservice()
{
continue;
}
let Some(mas_user_id) = user_infos.mas_user_id else {
continue;
};
// It's not always accurate, but last_validated is *often* the creation time of
// the device If we don't have one, then use the current time as a
// fallback.
let created_at = last_validated.map_or_else(|| clock.now(), DateTime::from);
if user_infos.flags.is_deactivated()
|| user_infos.flags.is_guest()
|| user_infos.flags.is_appservice()
{
continue;
}
let session_id = if let Some(device_id) = device_id {
// Use the existing device_id if this is the second token for a device
*state
.devices_to_compat_sessions
.entry((mas_user_id, CompactString::new(&device_id)))
.or_insert_with(|| {
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng))
})
} else {
// If this is a deviceless access token, create a deviceless compat session
// for it (since otherwise we won't create one whilst migrating devices)
let deviceless_session_id =
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
// It's not always accurate, but last_validated is *often* the creation time of
// the device If we don't have one, then use the current time as a
// fallback.
let created_at = last_validated.map_or_else(|| now, DateTime::from);
deviceless_session_write_buffer
.write(
&mut mas,
MasNewCompatSession {
session_id: deviceless_session_id,
user_id: mas_user_id,
device_id: None,
human_name: None,
created_at,
is_synapse_admin: false,
last_active_at: None,
last_active_ip: None,
user_agent: None,
},
)
let session_id = if let Some(device_id) = device_id {
// Use the existing device_id if this is the second token for a device
*state
.devices_to_compat_sessions
.entry((mas_user_id, CompactString::new(&device_id)))
.or_insert_with(|| {
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng))
})
} else {
// If this is a deviceless access token, create a deviceless compat session
// for it (since otherwise we won't create one whilst migrating devices)
let deviceless_session_id =
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
deviceless_session_write_buffer
.write(
&mut mas,
MasNewCompatSession {
session_id: deviceless_session_id,
user_id: mas_user_id,
device_id: None,
human_name: None,
created_at,
is_synapse_admin: false,
last_active_at: None,
last_active_ip: None,
user_agent: None,
},
)
.await
.into_mas("failed to write deviceless compat sessions")?;
deviceless_session_id
};
let token_id =
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
write_buffer
.write(
&mut mas,
MasNewCompatAccessToken {
token_id,
session_id,
access_token: token,
created_at,
expires_at: valid_until_ms.map(DateTime::from),
},
)
.await
.into_mas("writing compat access tokens")?;
}
write_buffer
.finish(&mut mas)
.await
.into_mas("failed to write deviceless compat sessions")?;
.into_mas("writing compat access tokens")?;
deviceless_session_write_buffer
.finish(&mut mas)
.await
.into_mas("writing deviceless compat sessions")?;
deviceless_session_id
};
Ok((mas, state))
}
.instrument(tracing::info_span!("ingest_task")),
);
let token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
// In case this has an error, we still want to join the task, so we look at the
// error later
let res = synapse
.read_unrefreshable_access_tokens()
.map_err(|e| e.into_synapse("reading tokens"))
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
.inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
.await;
write_buffer
.write(
&mut mas,
MasNewCompatAccessToken {
token_id,
session_id,
access_token: token,
created_at,
expires_at: valid_until_ms.map(DateTime::from),
},
)
.await
.into_mas("writing compat access tokens")?;
}
let (mas, state) = task.await.into_join("token write task")??;
write_buffer
.finish(&mut mas)
.await
.into_mas("writing compat access tokens")?;
deviceless_session_write_buffer
.finish(&mut mas)
.await
.into_mas("writing deviceless compat sessions")?;
res?;
info!(
"non-refreshable access tokens migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}
@@ -634,10 +750,13 @@ async fn migrate_refreshable_token_pairs(
rng: &mut impl RngCore,
mut state: MigrationState,
) -> Result<(MasWriter, MigrationState), Error> {
let start = Instant::now();
let mut token_stream = pin!(synapse.read_refreshable_token_pairs());
let mut access_token_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
let mut access_token_write_buffer =
MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
let mut refresh_token_write_buffer =
MasWriteBuffer::new(MasWriter::write_compat_refresh_tokens);
MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens);
while let Some(token_res) = token_stream.next().await {
let SynapseRefreshableTokenPair {
@@ -723,6 +842,11 @@ async fn migrate_refreshable_token_pairs(
.await
.into_mas("writing compat refresh tokens")?;
info!(
"refreshable token pairs migrated in {:.1}s",
Instant::now().duration_since(start).as_secs_f64()
);
Ok((mas, state))
}

View File

@@ -48,21 +48,11 @@ pub enum CheckError {
)]
PasswordSchemeWrongPepper,
#[error(
"Synapse database contains {num_guests} guests which aren't supported by MAS. See https://github.com/element-hq/matrix-authentication-service/issues/1445"
)]
GuestsInDatabase { num_guests: i64 },
#[error(
"Guest support is enabled in the Synapse configuration. Guests aren't supported by MAS, but if you don't have any then you could disable the option. See https://github.com/element-hq/matrix-authentication-service/issues/1445"
)]
GuestsEnabled,
#[error(
"Synapse database contains {num_non_email_3pids} non-email 3PIDs (probably phone numbers), which are not supported by MAS."
)]
NonEmailThreepidsInDatabase { num_non_email_3pids: i64 },
#[error(
"Synapse config has `enable_3pid_changes` explicitly enabled, which must be disabled or removed."
)]
@@ -125,6 +115,16 @@ pub enum CheckWarning {
"Synapse config has a registration CAPTCHA enabled, but no CAPTCHA has been configured in MAS. You may wish to manually configure this."
)]
ShouldPortRegistrationCaptcha,
#[error(
"Synapse database contains {num_guests} guests which will be migrated are not supported by MAS. See https://github.com/element-hq/matrix-authentication-service/issues/1445"
)]
GuestsInDatabase { num_guests: i64 },
#[error(
"Synapse database contains {num_non_email_3pids} non-email 3PIDs (probably phone numbers), which will be migrated but are not supported by MAS."
)]
NonEmailThreepidsInDatabase { num_non_email_3pids: i64 },
}
/// Check that the Synapse configuration is sane for migration.
@@ -140,15 +140,6 @@ pub fn synapse_config_check(synapse_config: &Config) -> (Vec<CheckWarning>, Vec<
warnings.push(CheckWarning::DisableUserConsentAfterMigration);
}
// TODO check the settings directly against the MAS settings
for provider in synapse_config.all_oidc_providers().values() {
if let Some(ref issuer) = provider.issuer {
warnings.push(CheckWarning::UpstreamOidcProvider {
issuer: issuer.clone(),
});
}
}
// TODO provide guidance on migrating these
if synapse_config.cas_config.enabled {
warnings.push(CheckWarning::ExternalAuthSystem("CAS"));
@@ -269,13 +260,13 @@ pub async fn synapse_database_check(
}
let mut errors = Vec::new();
let warnings = Vec::new();
let mut warnings = Vec::new();
let num_guests: i64 = query_scalar("SELECT COUNT(1) FROM users WHERE is_guest <> 0")
.fetch_one(&mut *synapse_connection)
.await?;
if num_guests > 0 {
errors.push(CheckError::GuestsInDatabase { num_guests });
warnings.push(CheckWarning::GuestsInDatabase { num_guests });
}
let num_non_email_3pids: i64 =
@@ -283,7 +274,7 @@ pub async fn synapse_database_check(
.fetch_one(&mut *synapse_connection)
.await?;
if num_non_email_3pids > 0 {
errors.push(CheckError::NonEmailThreepidsInDatabase {
warnings.push(CheckWarning::NonEmailThreepidsInDatabase {
num_non_email_3pids,
});
}

View File

@@ -338,28 +338,31 @@ impl<'conn> SynapseReader<'conn> {
///
/// - An underlying database error
pub async fn count_rows(&mut self) -> Result<SynapseRowCounts, Error> {
let users: usize = sqlx::query_scalar::<_, i64>(
// We don't get to filter out application service users by using this estimate,
// which is a shame, but on a large database this is way faster.
// On matrix.org, counting users and devices properly takes around 1m10s,
// which is unnecessary extra downtime during the migration, just to
// show a more accurate progress bar and size a hash map accurately.
let users = sqlx::query_scalar::<_, i64>(
"
SELECT COUNT(1) FROM users
WHERE appservice_id IS NULL
SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'users'::regclass;
",
)
.fetch_one(&mut *self.txn)
.await
.into_database("counting Synapse users")?
.into_database("estimating count of users")?
.max(0)
.try_into()
.unwrap_or(usize::MAX);
let devices = sqlx::query_scalar::<_, i64>(
"
SELECT COUNT(1) FROM devices
WHERE NOT hidden
SELECT reltuples::bigint AS estimate FROM pg_class WHERE oid = 'devices'::regclass;
",
)
.fetch_one(&mut *self.txn)
.await
.into_database("counting Synapse devices")?
.into_database("estimating count of devices")?
.max(0)
.try_into()
.unwrap_or(usize::MAX);
@@ -429,6 +432,12 @@ impl<'conn> SynapseReader<'conn> {
/// Reads unrefreshable access tokens from the Synapse database.
/// This does not include access tokens used for puppetting users, as those
/// are not supported by MAS.
///
/// This also excludes access tokens whose referenced device ID does not
/// exist, except for deviceless access tokens.
/// (It's unclear what mechanism led to these, but since Synapse has no
/// foreign key constraints and is not consistently atomic about this,
/// it should be no surprise really)
pub fn read_unrefreshable_access_tokens(
&mut self,
) -> impl Stream<Item = Result<SynapseAccessToken, Error>> + '_ {
@@ -437,7 +446,15 @@ impl<'conn> SynapseReader<'conn> {
SELECT
at0.user_id, at0.device_id, at0.token, at0.valid_until_ms, at0.last_validated
FROM access_tokens at0
INNER JOIN devices USING (user_id, device_id)
WHERE at0.puppets_user_id IS NULL AND at0.refresh_token_id IS NULL
UNION
SELECT
at0.user_id, at0.device_id, at0.token, at0.valid_until_ms, at0.last_validated
FROM access_tokens at0
WHERE at0.puppets_user_id IS NULL AND at0.refresh_token_id IS NULL AND at0.device_id IS NULL
",
)
.fetch(&mut *self.txn)
@@ -461,7 +478,8 @@ impl<'conn> SynapseReader<'conn> {
SELECT
rt0.user_id, rt0.device_id, at0.token AS access_token, rt0.token AS refresh_token, at0.valid_until_ms, at0.last_validated
FROM refresh_tokens rt0
LEFT JOIN access_tokens at0 ON at0.refresh_token_id = rt0.id AND at0.user_id = rt0.user_id AND at0.device_id = rt0.device_id
INNER JOIN devices USING (device_id)
INNER JOIN access_tokens at0 ON at0.refresh_token_id = rt0.id AND at0.user_id = rt0.user_id AND at0.device_id = rt0.device_id
LEFT JOIN access_tokens at1 ON at1.refresh_token_id = rt0.next_token_id
WHERE NOT at1.used OR at1.used IS NULL
",
@@ -554,7 +572,10 @@ mod test {
assert_debug_snapshot!(devices);
}
#[sqlx::test(migrator = "MIGRATOR", fixtures("user_alice", "access_token_alice"))]
#[sqlx::test(
migrator = "MIGRATOR",
fixtures("user_alice", "devices_alice", "access_token_alice")
)]
async fn test_read_access_token(pool: PgPool) {
let mut conn = pool.acquire().await.expect("failed to get connection");
let mut reader = SynapseReader::new(&mut conn, false)
@@ -573,7 +594,7 @@ mod test {
/// Tests that puppetting access tokens are ignored.
#[sqlx::test(
migrator = "MIGRATOR",
fixtures("user_alice", "access_token_alice_with_puppet")
fixtures("user_alice", "devices_alice", "access_token_alice_with_puppet")
)]
async fn test_read_access_token_puppet(pool: PgPool) {
let mut conn = pool.acquire().await.expect("failed to get connection");
@@ -592,7 +613,7 @@ mod test {
#[sqlx::test(
migrator = "MIGRATOR",
fixtures("user_alice", "access_token_alice_with_refresh_token")
fixtures("user_alice", "devices_alice", "access_token_alice_with_refresh_token")
)]
async fn test_read_access_and_refresh_tokens(pool: PgPool) {
let mut conn = pool.acquire().await.expect("failed to get connection");
@@ -621,7 +642,11 @@ mod test {
#[sqlx::test(
migrator = "MIGRATOR",
fixtures("user_alice", "access_token_alice_with_unused_refresh_token")
fixtures(
"user_alice",
"devices_alice",
"access_token_alice_with_unused_refresh_token"
)
)]
async fn test_read_access_and_unused_refresh_tokens(pool: PgPool) {
let mut conn = pool.acquire().await.expect("failed to get connection");