Encapsulate migration state in a single structure (#3991)

This commit is contained in:
Quentin Gliech
2025-02-10 10:22:45 +01:00
committed by GitHub
7 changed files with 148 additions and 157 deletions

1
Cargo.lock generated
View File

@@ -6044,6 +6044,7 @@ name = "syn2mas"
version = "0.13.0"
dependencies = [
"anyhow",
"bitflags",
"camino",
"chrono",
"compact_str",

View File

@@ -91,6 +91,10 @@ features = ["cookie-private", "cookie-key-expansion", "typed-header"]
[workspace.dependencies.base64ct]
version = "1.6.0"
# Packed bitfields
[workspace.dependencies.bitflags]
version = "2.6.0"
# Bytes
[workspace.dependencies.bytes]
version = "1.10.0"

View File

@@ -233,10 +233,10 @@ impl Options {
syn2mas::migrate(
&mut reader,
&mut writer,
&mas_matrix.homeserver,
mas_matrix.homeserver,
&clock,
&mut rng,
&provider_id_mappings,
provider_id_mappings,
)
.await?;

View File

@@ -39,7 +39,7 @@ oauth2-types.workspace = true
[dev-dependencies]
assert_matches = "1.5.0"
bitflags = "2.8.0"
bitflags.workspace = true
rand_chacha = "0.3.1"
tokio.workspace = true
wiremock.workspace = true

View File

@@ -11,6 +11,7 @@ repository.workspace = true
[dependencies]
anyhow.workspace = true
bitflags.workspace = true
camino.workspace = true
figment.workspace = true
serde.workspace = true

View File

@@ -11,10 +11,7 @@
//! This module does not implement any of the safety checks that should be run
//! *before* the migration.
use std::{
collections::{HashMap, HashSet},
pin::pin,
};
use std::{collections::HashMap, pin::pin};
use chrono::{DateTime, Utc};
use compact_str::CompactString;
@@ -69,12 +66,48 @@ pub enum Error {
},
}
struct UsersMigrated {
/// Lookup table from user localpart to that user's UUID in MAS.
user_localparts_to_uuid: HashMap<CompactString, Uuid>,
bitflags::bitflags! {
#[derive(Debug, Clone, Copy)]
struct UserFlags: u8 {
const IS_SYNAPSE_ADMIN = 0b0000_0001;
const IS_DEACTIVATED = 0b0000_0010;
const IS_GUEST = 0b0000_0100;
}
}
/// Set of user UUIDs that correspond to Synapse admins
synapse_admins: HashSet<Uuid>,
impl UserFlags {
const fn is_deactivated(self) -> bool {
self.contains(UserFlags::IS_DEACTIVATED)
}
const fn is_guest(self) -> bool {
self.contains(UserFlags::IS_GUEST)
}
const fn is_synapse_admin(self) -> bool {
self.contains(UserFlags::IS_SYNAPSE_ADMIN)
}
}
#[derive(Debug, Clone, Copy)]
struct UserInfo {
mas_user_id: Uuid,
flags: UserFlags,
}
struct MigrationState {
/// The server name we're migrating from
server_name: String,
/// Lookup table from user localpart to that user's infos
users: HashMap<CompactString, UserInfo>,
/// Mapping of MAS user ID + device ID to a MAS compat session ID.
devices_to_compat_sessions: HashMap<(Uuid, CompactString), Uuid>,
/// A mapping of Synapse external ID providers to MAS upstream OAuth 2.0
/// provider ID
provider_id_mapping: HashMap<String, Uuid>,
}
/// Performs a migration from Synapse's database to MAS' database.
@@ -93,85 +126,26 @@ struct UsersMigrated {
pub async fn migrate(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter,
server_name: &str,
server_name: String,
clock: &dyn Clock,
rng: &mut impl RngCore,
provider_id_mapping: &HashMap<String, Uuid>,
provider_id_mapping: HashMap<String, Uuid>,
) -> Result<(), Error> {
let counts = synapse.count_rows().await.into_synapse("counting users")?;
let migrated_users = migrate_users(
synapse,
mas,
counts
.users
.try_into()
.expect("More than usize::MAX users — unable to handle this many!"),
let state = MigrationState {
server_name,
rng,
)
.await?;
migrate_threepids(
synapse,
mas,
server_name,
rng,
&migrated_users.user_localparts_to_uuid,
)
.await?;
migrate_external_ids(
synapse,
mas,
server_name,
rng,
&migrated_users.user_localparts_to_uuid,
users: HashMap::with_capacity(counts.users),
devices_to_compat_sessions: HashMap::with_capacity(counts.devices),
provider_id_mapping,
)
.await?;
};
// `(MAS user_id, device_id)` mapped to `compat_session` ULID
let mut devices_to_compat_sessions: HashMap<(Uuid, CompactString), Uuid> =
HashMap::with_capacity(
counts
.devices
.try_into()
.expect("More than usize::MAX devices — unable to handle this many!"),
);
migrate_unrefreshable_access_tokens(
synapse,
mas,
server_name,
clock,
rng,
&migrated_users.user_localparts_to_uuid,
&mut devices_to_compat_sessions,
)
.await?;
migrate_refreshable_token_pairs(
synapse,
mas,
server_name,
clock,
rng,
&migrated_users.user_localparts_to_uuid,
&mut devices_to_compat_sessions,
)
.await?;
migrate_devices(
synapse,
mas,
server_name,
rng,
&migrated_users.user_localparts_to_uuid,
&mut devices_to_compat_sessions,
&migrated_users.synapse_admins,
)
.await?;
let state = migrate_users(synapse, mas, state, rng).await?;
let state = migrate_threepids(synapse, mas, rng, state).await?;
let state = migrate_external_ids(synapse, mas, rng, state).await?;
let state = migrate_unrefreshable_access_tokens(synapse, mas, clock, rng, state).await?;
let state = migrate_refreshable_token_pairs(synapse, mas, clock, rng, state).await?;
let _state = migrate_devices(synapse, mas, rng, state).await?;
Ok(())
}
@@ -180,29 +154,35 @@ pub async fn migrate(
async fn migrate_users(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter,
user_count_hint: usize,
server_name: &str,
mut state: MigrationState,
rng: &mut impl RngCore,
) -> Result<UsersMigrated, Error> {
) -> Result<MigrationState, Error> {
let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users);
let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords);
let mut users_stream = pin!(synapse.read_users());
// TODO is 1:1 capacity enough for a hashmap?
let mut user_localparts_to_uuid = HashMap::with_capacity(user_count_hint);
let mut synapse_admins = HashSet::new();
while let Some(user_res) = users_stream.next().await {
let user = user_res.into_synapse("reading user")?;
let (mas_user, mas_password_opt) = transform_user(&user, server_name, rng)?;
let (mas_user, mas_password_opt) = transform_user(&user, &state.server_name, rng)?;
let mut flags = UserFlags::empty();
if bool::from(user.admin) {
// Note down the fact that this user is a Synapse admin,
// because we will grant their existing devices the Synapse admin
// flag
synapse_admins.insert(mas_user.user_id);
flags |= UserFlags::IS_SYNAPSE_ADMIN;
}
if bool::from(user.deactivated) {
flags |= UserFlags::IS_DEACTIVATED;
}
if bool::from(user.is_guest) {
flags |= UserFlags::IS_GUEST;
}
user_localparts_to_uuid.insert(CompactString::new(&mas_user.username), mas_user.user_id);
state.users.insert(
CompactString::new(&mas_user.username),
UserInfo {
mas_user_id: mas_user.user_id,
flags,
},
);
user_buffer
.write(mas, mas_user)
@@ -223,20 +203,16 @@ async fn migrate_users(
.await
.into_mas("writing passwords")?;
Ok(UsersMigrated {
user_localparts_to_uuid,
synapse_admins,
})
Ok(state)
}
#[tracing::instrument(skip_all, level = Level::INFO)]
async fn migrate_threepids(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter,
server_name: &str,
rng: &mut impl RngCore,
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
) -> Result<(), Error> {
state: MigrationState,
) -> Result<MigrationState, Error> {
let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids);
let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids);
let mut users_stream = pin!(synapse.read_threepids());
@@ -251,10 +227,10 @@ async fn migrate_threepids(
let created_at: DateTime<Utc> = added_at.into();
let username = synapse_user_id
.extract_localpart(server_name)
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
if is_likely_appservice(&username) {
continue;
}
@@ -269,7 +245,7 @@ async fn migrate_threepids(
.write(
mas,
MasNewEmailThreepid {
user_id,
user_id: user_infos.mas_user_id,
user_email_id: Uuid::from(Ulid::from_datetime_with_source(
created_at.into(),
rng,
@@ -285,7 +261,7 @@ async fn migrate_threepids(
.write(
mas,
MasNewUnsupportedThreepid {
user_id,
user_id: user_infos.mas_user_id,
medium,
address,
created_at,
@@ -305,7 +281,7 @@ async fn migrate_threepids(
.await
.into_mas("writing unsupported threepids")?;
Ok(())
Ok(state)
}
/// # Parameters
@@ -316,11 +292,9 @@ async fn migrate_threepids(
async fn migrate_external_ids(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter,
server_name: &str,
rng: &mut impl RngCore,
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
provider_id_mapping: &HashMap<String, Uuid>,
) -> Result<(), Error> {
state: MigrationState,
) -> Result<MigrationState, Error> {
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links);
let mut extids_stream = pin!(synapse.read_user_external_ids());
@@ -331,10 +305,10 @@ async fn migrate_external_ids(
external_id: subject,
} = extid_res.into_synapse("reading external ID")?;
let username = synapse_user_id
.extract_localpart(server_name)
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
if is_likely_appservice(&username) {
continue;
}
@@ -344,7 +318,7 @@ async fn migrate_external_ids(
});
};
let Some(&upstream_provider_id) = provider_id_mapping.get(&auth_provider) else {
let Some(&upstream_provider_id) = state.provider_id_mapping.get(&auth_provider) else {
return Err(Error::MissingAuthProviderMapping {
synapse_id: auth_provider,
user: synapse_user_id,
@@ -353,7 +327,7 @@ async fn migrate_external_ids(
// To save having to store user creation times, extract it from the ULID
// This gives millisecond precision — good enough.
let user_created_ts = Ulid::from(user_id).datetime();
let user_created_ts = Ulid::from(user_infos.mas_user_id).datetime();
let link_id: Uuid = Ulid::from_datetime_with_source(user_created_ts, rng).into();
@@ -362,7 +336,7 @@ async fn migrate_external_ids(
mas,
MasNewUpstreamOauthLink {
link_id,
user_id,
user_id: user_infos.mas_user_id,
upstream_provider_id,
subject,
created_at: user_created_ts.into(),
@@ -377,7 +351,7 @@ async fn migrate_external_ids(
.await
.into_mas("writing threepids")?;
Ok(())
Ok(state)
}
/// Migrate devices from Synapse to MAS (as compat sessions).
@@ -392,12 +366,9 @@ async fn migrate_external_ids(
async fn migrate_devices(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter,
server_name: &str,
rng: &mut impl RngCore,
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
devices: &mut HashMap<(Uuid, CompactString), Uuid>,
synapse_admins: &HashSet<Uuid>,
) -> Result<(), Error> {
mut state: MigrationState,
) -> Result<MigrationState, Error> {
let mut devices_stream = pin!(synapse.read_devices());
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
@@ -412,10 +383,10 @@ async fn migrate_devices(
} = device_res.into_synapse("reading Synapse device")?;
let username = synapse_user_id
.extract_localpart(server_name)
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
if is_likely_appservice(&username) {
continue;
}
@@ -425,8 +396,13 @@ async fn migrate_devices(
});
};
let session_id = *devices
.entry((user_id, CompactString::new(&device_id)))
if user_infos.flags.is_deactivated() || user_infos.flags.is_guest() {
continue;
}
let session_id = *state
.devices_to_compat_sessions
.entry((user_infos.mas_user_id, CompactString::new(&device_id)))
.or_insert_with(||
// We don't have a creation time for this device (as it has no access token),
// so use now as a least-evil fallback.
@@ -450,17 +426,16 @@ async fn migrate_devices(
.ok()
});
// TODO skip access tokens for deactivated users
write_buffer
.write(
mas,
MasNewCompatSession {
session_id,
user_id,
user_id: user_infos.mas_user_id,
device_id: Some(device_id),
human_name: display_name,
created_at,
is_synapse_admin: synapse_admins.contains(&user_id),
is_synapse_admin: user_infos.flags.is_synapse_admin(),
last_active_at: last_seen.map(DateTime::from),
last_active_ip,
user_agent,
@@ -475,7 +450,7 @@ async fn migrate_devices(
.await
.into_mas("writing compat sessions")?;
Ok(())
Ok(state)
}
/// Migrates unrefreshable access tokens (those without an associated refresh
@@ -484,12 +459,10 @@ async fn migrate_devices(
async fn migrate_unrefreshable_access_tokens(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter,
server_name: &str,
clock: &dyn Clock,
rng: &mut impl RngCore,
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
devices: &mut HashMap<(Uuid, CompactString), Uuid>,
) -> Result<(), Error> {
mut state: MigrationState,
) -> Result<MigrationState, Error> {
let mut token_stream = pin!(synapse.read_unrefreshable_access_tokens());
let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
let mut deviceless_session_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions);
@@ -504,10 +477,10 @@ async fn migrate_unrefreshable_access_tokens(
} = token_res.into_synapse("reading Synapse access token")?;
let username = synapse_user_id
.extract_localpart(server_name)
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
if is_likely_appservice(&username) {
continue;
}
@@ -517,6 +490,10 @@ async fn migrate_unrefreshable_access_tokens(
});
};
if user_infos.flags.is_deactivated() || user_infos.flags.is_guest() {
continue;
}
// It's not always accurate, but last_validated is *often* the creation time of
// the device If we don't have one, then use the current time as a
// fallback.
@@ -524,8 +501,9 @@ async fn migrate_unrefreshable_access_tokens(
let session_id = if let Some(device_id) = device_id {
// Use the existing device_id if this is the second token for a device
*devices
.entry((user_id, CompactString::new(&device_id)))
*state
.devices_to_compat_sessions
.entry((user_infos.mas_user_id, CompactString::new(&device_id)))
.or_insert_with(|| {
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng))
})
@@ -540,7 +518,7 @@ async fn migrate_unrefreshable_access_tokens(
mas,
MasNewCompatSession {
session_id: deviceless_session_id,
user_id,
user_id: user_infos.mas_user_id,
device_id: None,
human_name: None,
created_at,
@@ -558,7 +536,6 @@ async fn migrate_unrefreshable_access_tokens(
let token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
// TODO skip access tokens for deactivated users
write_buffer
.write(
mas,
@@ -583,7 +560,7 @@ async fn migrate_unrefreshable_access_tokens(
.await
.into_mas("writing deviceless compat sessions")?;
Ok(())
Ok(state)
}
/// Migrates (access token, refresh token) pairs.
@@ -592,12 +569,10 @@ async fn migrate_unrefreshable_access_tokens(
async fn migrate_refreshable_token_pairs(
synapse: &mut SynapseReader<'_>,
mas: &mut MasWriter,
server_name: &str,
clock: &dyn Clock,
rng: &mut impl RngCore,
user_localparts_to_uuid: &HashMap<CompactString, Uuid>,
devices: &mut HashMap<(Uuid, CompactString), Uuid>,
) -> Result<(), Error> {
mut state: MigrationState,
) -> Result<MigrationState, Error> {
let mut token_stream = pin!(synapse.read_refreshable_token_pairs());
let mut access_token_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens);
let mut refresh_token_write_buffer =
@@ -614,10 +589,10 @@ async fn migrate_refreshable_token_pairs(
} = token_res.into_synapse("reading Synapse refresh token")?;
let username = synapse_user_id
.extract_localpart(server_name)
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else {
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
if is_likely_appservice(&username) {
continue;
}
@@ -627,20 +602,24 @@ async fn migrate_refreshable_token_pairs(
});
};
if user_infos.flags.is_deactivated() || user_infos.flags.is_guest() {
continue;
}
// It's not always accurate, but last_validated is *often* the creation time of
// the device If we don't have one, then use the current time as a
// fallback.
let created_at = last_validated.map_or_else(|| clock.now(), DateTime::from);
// Use the existing device_id if this is the second token for a device
let session_id = *devices
.entry((user_id, CompactString::new(&device_id)))
let session_id = *state
.devices_to_compat_sessions
.entry((user_infos.mas_user_id, CompactString::new(&device_id)))
.or_insert_with(|| Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)));
let access_token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
let refresh_token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
// TODO skip access tokens for deactivated users
access_token_write_buffer
.write(
mas,
@@ -679,7 +658,7 @@ async fn migrate_refreshable_token_pairs(
.await
.into_mas("writing compat refresh tokens")?;
Ok(())
Ok(state)
}
fn transform_user(

View File

@@ -262,8 +262,8 @@ const TABLES_TO_LOCK: &[&str] = &[
/// Used to estimate progress.
#[derive(Clone, Debug)]
pub struct SynapseRowCounts {
pub users: i64,
pub devices: i64,
pub users: usize,
pub devices: usize,
}
pub struct SynapseReader<'c> {
@@ -334,7 +334,7 @@ impl<'conn> SynapseReader<'conn> {
///
/// - An underlying database error
pub async fn count_rows(&mut self) -> Result<SynapseRowCounts, Error> {
let users: i64 = sqlx::query_scalar(
let users: usize = sqlx::query_scalar::<_, i64>(
"
SELECT COUNT(1) FROM users
WHERE appservice_id IS NULL
@@ -342,9 +342,12 @@ impl<'conn> SynapseReader<'conn> {
)
.fetch_one(&mut *self.txn)
.await
.into_database("counting Synapse users")?;
.into_database("counting Synapse users")?
.max(0)
.try_into()
.unwrap_or(usize::MAX);
let devices = sqlx::query_scalar(
let devices = sqlx::query_scalar::<_, i64>(
"
SELECT COUNT(1) FROM devices
WHERE NOT hidden
@@ -352,7 +355,10 @@ impl<'conn> SynapseReader<'conn> {
)
.fetch_one(&mut *self.txn)
.await
.into_database("counting Synapse devices")?;
.into_database("counting Synapse devices")?
.max(0)
.try_into()
.unwrap_or(usize::MAX);
Ok(SynapseRowCounts { users, devices })
}