diff --git a/Cargo.lock b/Cargo.lock index 120a891ee..677b015d1 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -6044,6 +6044,7 @@ name = "syn2mas" version = "0.13.0" dependencies = [ "anyhow", + "bitflags", "camino", "chrono", "compact_str", diff --git a/Cargo.toml b/Cargo.toml index 3459b7674..75b1ef2c8 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -91,6 +91,10 @@ features = ["cookie-private", "cookie-key-expansion", "typed-header"] [workspace.dependencies.base64ct] version = "1.6.0" +# Packed bitfields +[workspace.dependencies.bitflags] +version = "2.6.0" + # Bytes [workspace.dependencies.bytes] version = "1.10.0" diff --git a/crates/cli/src/commands/syn2mas.rs b/crates/cli/src/commands/syn2mas.rs index e6ab68759..fc145536c 100644 --- a/crates/cli/src/commands/syn2mas.rs +++ b/crates/cli/src/commands/syn2mas.rs @@ -233,10 +233,10 @@ impl Options { syn2mas::migrate( &mut reader, &mut writer, - &mas_matrix.homeserver, + mas_matrix.homeserver, &clock, &mut rng, - &provider_id_mappings, + provider_id_mappings, ) .await?; diff --git a/crates/oidc-client/Cargo.toml b/crates/oidc-client/Cargo.toml index 43ec23714..d738aa669 100644 --- a/crates/oidc-client/Cargo.toml +++ b/crates/oidc-client/Cargo.toml @@ -39,7 +39,7 @@ oauth2-types.workspace = true [dev-dependencies] assert_matches = "1.5.0" -bitflags = "2.8.0" +bitflags.workspace = true rand_chacha = "0.3.1" tokio.workspace = true wiremock.workspace = true diff --git a/crates/syn2mas/Cargo.toml b/crates/syn2mas/Cargo.toml index 40ec05a3e..aef5bdaea 100644 --- a/crates/syn2mas/Cargo.toml +++ b/crates/syn2mas/Cargo.toml @@ -11,6 +11,7 @@ repository.workspace = true [dependencies] anyhow.workspace = true +bitflags.workspace = true camino.workspace = true figment.workspace = true serde.workspace = true diff --git a/crates/syn2mas/src/migration.rs b/crates/syn2mas/src/migration.rs index 18d6746be..36dda3a56 100644 --- a/crates/syn2mas/src/migration.rs +++ b/crates/syn2mas/src/migration.rs @@ -11,10 +11,7 @@ //! This module does not implement any of the safety checks that should be run //! *before* the migration. -use std::{ - collections::{HashMap, HashSet}, - pin::pin, -}; +use std::{collections::HashMap, pin::pin}; use chrono::{DateTime, Utc}; use compact_str::CompactString; @@ -69,12 +66,48 @@ pub enum Error { }, } -struct UsersMigrated { - /// Lookup table from user localpart to that user's UUID in MAS. - user_localparts_to_uuid: HashMap, +bitflags::bitflags! { + #[derive(Debug, Clone, Copy)] + struct UserFlags: u8 { + const IS_SYNAPSE_ADMIN = 0b0000_0001; + const IS_DEACTIVATED = 0b0000_0010; + const IS_GUEST = 0b0000_0100; + } +} - /// Set of user UUIDs that correspond to Synapse admins - synapse_admins: HashSet, +impl UserFlags { + const fn is_deactivated(self) -> bool { + self.contains(UserFlags::IS_DEACTIVATED) + } + + const fn is_guest(self) -> bool { + self.contains(UserFlags::IS_GUEST) + } + + const fn is_synapse_admin(self) -> bool { + self.contains(UserFlags::IS_SYNAPSE_ADMIN) + } +} + +#[derive(Debug, Clone, Copy)] +struct UserInfo { + mas_user_id: Uuid, + flags: UserFlags, +} + +struct MigrationState { + /// The server name we're migrating from + server_name: String, + + /// Lookup table from user localpart to that user's infos + users: HashMap, + + /// Mapping of MAS user ID + device ID to a MAS compat session ID. + devices_to_compat_sessions: HashMap<(Uuid, CompactString), Uuid>, + + /// A mapping of Synapse external ID providers to MAS upstream OAuth 2.0 + /// provider ID + provider_id_mapping: HashMap, } /// Performs a migration from Synapse's database to MAS' database. @@ -93,85 +126,26 @@ struct UsersMigrated { pub async fn migrate( synapse: &mut SynapseReader<'_>, mas: &mut MasWriter, - server_name: &str, + server_name: String, clock: &dyn Clock, rng: &mut impl RngCore, - provider_id_mapping: &HashMap, + provider_id_mapping: HashMap, ) -> Result<(), Error> { let counts = synapse.count_rows().await.into_synapse("counting users")?; - let migrated_users = migrate_users( - synapse, - mas, - counts - .users - .try_into() - .expect("More than usize::MAX users — unable to handle this many!"), + let state = MigrationState { server_name, - rng, - ) - .await?; - - migrate_threepids( - synapse, - mas, - server_name, - rng, - &migrated_users.user_localparts_to_uuid, - ) - .await?; - - migrate_external_ids( - synapse, - mas, - server_name, - rng, - &migrated_users.user_localparts_to_uuid, + users: HashMap::with_capacity(counts.users), + devices_to_compat_sessions: HashMap::with_capacity(counts.devices), provider_id_mapping, - ) - .await?; + }; - // `(MAS user_id, device_id)` mapped to `compat_session` ULID - let mut devices_to_compat_sessions: HashMap<(Uuid, CompactString), Uuid> = - HashMap::with_capacity( - counts - .devices - .try_into() - .expect("More than usize::MAX devices — unable to handle this many!"), - ); - - migrate_unrefreshable_access_tokens( - synapse, - mas, - server_name, - clock, - rng, - &migrated_users.user_localparts_to_uuid, - &mut devices_to_compat_sessions, - ) - .await?; - - migrate_refreshable_token_pairs( - synapse, - mas, - server_name, - clock, - rng, - &migrated_users.user_localparts_to_uuid, - &mut devices_to_compat_sessions, - ) - .await?; - - migrate_devices( - synapse, - mas, - server_name, - rng, - &migrated_users.user_localparts_to_uuid, - &mut devices_to_compat_sessions, - &migrated_users.synapse_admins, - ) - .await?; + let state = migrate_users(synapse, mas, state, rng).await?; + let state = migrate_threepids(synapse, mas, rng, state).await?; + let state = migrate_external_ids(synapse, mas, rng, state).await?; + let state = migrate_unrefreshable_access_tokens(synapse, mas, clock, rng, state).await?; + let state = migrate_refreshable_token_pairs(synapse, mas, clock, rng, state).await?; + let _state = migrate_devices(synapse, mas, rng, state).await?; Ok(()) } @@ -180,29 +154,35 @@ pub async fn migrate( async fn migrate_users( synapse: &mut SynapseReader<'_>, mas: &mut MasWriter, - user_count_hint: usize, - server_name: &str, + mut state: MigrationState, rng: &mut impl RngCore, -) -> Result { +) -> Result { let mut user_buffer = MasWriteBuffer::new(MasWriter::write_users); let mut password_buffer = MasWriteBuffer::new(MasWriter::write_passwords); let mut users_stream = pin!(synapse.read_users()); - // TODO is 1:1 capacity enough for a hashmap? - let mut user_localparts_to_uuid = HashMap::with_capacity(user_count_hint); - let mut synapse_admins = HashSet::new(); while let Some(user_res) = users_stream.next().await { let user = user_res.into_synapse("reading user")?; - let (mas_user, mas_password_opt) = transform_user(&user, server_name, rng)?; + let (mas_user, mas_password_opt) = transform_user(&user, &state.server_name, rng)?; + let mut flags = UserFlags::empty(); if bool::from(user.admin) { - // Note down the fact that this user is a Synapse admin, - // because we will grant their existing devices the Synapse admin - // flag - synapse_admins.insert(mas_user.user_id); + flags |= UserFlags::IS_SYNAPSE_ADMIN; + } + if bool::from(user.deactivated) { + flags |= UserFlags::IS_DEACTIVATED; + } + if bool::from(user.is_guest) { + flags |= UserFlags::IS_GUEST; } - user_localparts_to_uuid.insert(CompactString::new(&mas_user.username), mas_user.user_id); + state.users.insert( + CompactString::new(&mas_user.username), + UserInfo { + mas_user_id: mas_user.user_id, + flags, + }, + ); user_buffer .write(mas, mas_user) @@ -223,20 +203,16 @@ async fn migrate_users( .await .into_mas("writing passwords")?; - Ok(UsersMigrated { - user_localparts_to_uuid, - synapse_admins, - }) + Ok(state) } #[tracing::instrument(skip_all, level = Level::INFO)] async fn migrate_threepids( synapse: &mut SynapseReader<'_>, mas: &mut MasWriter, - server_name: &str, rng: &mut impl RngCore, - user_localparts_to_uuid: &HashMap, -) -> Result<(), Error> { + state: MigrationState, +) -> Result { let mut email_buffer = MasWriteBuffer::new(MasWriter::write_email_threepids); let mut unsupported_buffer = MasWriteBuffer::new(MasWriter::write_unsupported_threepids); let mut users_stream = pin!(synapse.read_threepids()); @@ -251,10 +227,10 @@ async fn migrate_threepids( let created_at: DateTime = added_at.into(); let username = synapse_user_id - .extract_localpart(server_name) + .extract_localpart(&state.server_name) .into_extract_localpart(synapse_user_id.clone())? .to_owned(); - let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else { + let Some(user_infos) = state.users.get(username.as_str()).copied() else { if is_likely_appservice(&username) { continue; } @@ -269,7 +245,7 @@ async fn migrate_threepids( .write( mas, MasNewEmailThreepid { - user_id, + user_id: user_infos.mas_user_id, user_email_id: Uuid::from(Ulid::from_datetime_with_source( created_at.into(), rng, @@ -285,7 +261,7 @@ async fn migrate_threepids( .write( mas, MasNewUnsupportedThreepid { - user_id, + user_id: user_infos.mas_user_id, medium, address, created_at, @@ -305,7 +281,7 @@ async fn migrate_threepids( .await .into_mas("writing unsupported threepids")?; - Ok(()) + Ok(state) } /// # Parameters @@ -316,11 +292,9 @@ async fn migrate_threepids( async fn migrate_external_ids( synapse: &mut SynapseReader<'_>, mas: &mut MasWriter, - server_name: &str, rng: &mut impl RngCore, - user_localparts_to_uuid: &HashMap, - provider_id_mapping: &HashMap, -) -> Result<(), Error> { + state: MigrationState, +) -> Result { let mut write_buffer = MasWriteBuffer::new(MasWriter::write_upstream_oauth_links); let mut extids_stream = pin!(synapse.read_user_external_ids()); @@ -331,10 +305,10 @@ async fn migrate_external_ids( external_id: subject, } = extid_res.into_synapse("reading external ID")?; let username = synapse_user_id - .extract_localpart(server_name) + .extract_localpart(&state.server_name) .into_extract_localpart(synapse_user_id.clone())? .to_owned(); - let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else { + let Some(user_infos) = state.users.get(username.as_str()).copied() else { if is_likely_appservice(&username) { continue; } @@ -344,7 +318,7 @@ async fn migrate_external_ids( }); }; - let Some(&upstream_provider_id) = provider_id_mapping.get(&auth_provider) else { + let Some(&upstream_provider_id) = state.provider_id_mapping.get(&auth_provider) else { return Err(Error::MissingAuthProviderMapping { synapse_id: auth_provider, user: synapse_user_id, @@ -353,7 +327,7 @@ async fn migrate_external_ids( // To save having to store user creation times, extract it from the ULID // This gives millisecond precision — good enough. - let user_created_ts = Ulid::from(user_id).datetime(); + let user_created_ts = Ulid::from(user_infos.mas_user_id).datetime(); let link_id: Uuid = Ulid::from_datetime_with_source(user_created_ts, rng).into(); @@ -362,7 +336,7 @@ async fn migrate_external_ids( mas, MasNewUpstreamOauthLink { link_id, - user_id, + user_id: user_infos.mas_user_id, upstream_provider_id, subject, created_at: user_created_ts.into(), @@ -377,7 +351,7 @@ async fn migrate_external_ids( .await .into_mas("writing threepids")?; - Ok(()) + Ok(state) } /// Migrate devices from Synapse to MAS (as compat sessions). @@ -392,12 +366,9 @@ async fn migrate_external_ids( async fn migrate_devices( synapse: &mut SynapseReader<'_>, mas: &mut MasWriter, - server_name: &str, rng: &mut impl RngCore, - user_localparts_to_uuid: &HashMap, - devices: &mut HashMap<(Uuid, CompactString), Uuid>, - synapse_admins: &HashSet, -) -> Result<(), Error> { + mut state: MigrationState, +) -> Result { let mut devices_stream = pin!(synapse.read_devices()); let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions); @@ -412,10 +383,10 @@ async fn migrate_devices( } = device_res.into_synapse("reading Synapse device")?; let username = synapse_user_id - .extract_localpart(server_name) + .extract_localpart(&state.server_name) .into_extract_localpart(synapse_user_id.clone())? .to_owned(); - let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else { + let Some(user_infos) = state.users.get(username.as_str()).copied() else { if is_likely_appservice(&username) { continue; } @@ -425,8 +396,13 @@ async fn migrate_devices( }); }; - let session_id = *devices - .entry((user_id, CompactString::new(&device_id))) + if user_infos.flags.is_deactivated() || user_infos.flags.is_guest() { + continue; + } + + let session_id = *state + .devices_to_compat_sessions + .entry((user_infos.mas_user_id, CompactString::new(&device_id))) .or_insert_with(|| // We don't have a creation time for this device (as it has no access token), // so use now as a least-evil fallback. @@ -450,17 +426,16 @@ async fn migrate_devices( .ok() }); - // TODO skip access tokens for deactivated users write_buffer .write( mas, MasNewCompatSession { session_id, - user_id, + user_id: user_infos.mas_user_id, device_id: Some(device_id), human_name: display_name, created_at, - is_synapse_admin: synapse_admins.contains(&user_id), + is_synapse_admin: user_infos.flags.is_synapse_admin(), last_active_at: last_seen.map(DateTime::from), last_active_ip, user_agent, @@ -475,7 +450,7 @@ async fn migrate_devices( .await .into_mas("writing compat sessions")?; - Ok(()) + Ok(state) } /// Migrates unrefreshable access tokens (those without an associated refresh @@ -484,12 +459,10 @@ async fn migrate_devices( async fn migrate_unrefreshable_access_tokens( synapse: &mut SynapseReader<'_>, mas: &mut MasWriter, - server_name: &str, clock: &dyn Clock, rng: &mut impl RngCore, - user_localparts_to_uuid: &HashMap, - devices: &mut HashMap<(Uuid, CompactString), Uuid>, -) -> Result<(), Error> { + mut state: MigrationState, +) -> Result { let mut token_stream = pin!(synapse.read_unrefreshable_access_tokens()); let mut write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens); let mut deviceless_session_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_sessions); @@ -504,10 +477,10 @@ async fn migrate_unrefreshable_access_tokens( } = token_res.into_synapse("reading Synapse access token")?; let username = synapse_user_id - .extract_localpart(server_name) + .extract_localpart(&state.server_name) .into_extract_localpart(synapse_user_id.clone())? .to_owned(); - let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else { + let Some(user_infos) = state.users.get(username.as_str()).copied() else { if is_likely_appservice(&username) { continue; } @@ -517,6 +490,10 @@ async fn migrate_unrefreshable_access_tokens( }); }; + if user_infos.flags.is_deactivated() || user_infos.flags.is_guest() { + continue; + } + // It's not always accurate, but last_validated is *often* the creation time of // the device If we don't have one, then use the current time as a // fallback. @@ -524,8 +501,9 @@ async fn migrate_unrefreshable_access_tokens( let session_id = if let Some(device_id) = device_id { // Use the existing device_id if this is the second token for a device - *devices - .entry((user_id, CompactString::new(&device_id))) + *state + .devices_to_compat_sessions + .entry((user_infos.mas_user_id, CompactString::new(&device_id))) .or_insert_with(|| { Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)) }) @@ -540,7 +518,7 @@ async fn migrate_unrefreshable_access_tokens( mas, MasNewCompatSession { session_id: deviceless_session_id, - user_id, + user_id: user_infos.mas_user_id, device_id: None, human_name: None, created_at, @@ -558,7 +536,6 @@ async fn migrate_unrefreshable_access_tokens( let token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)); - // TODO skip access tokens for deactivated users write_buffer .write( mas, @@ -583,7 +560,7 @@ async fn migrate_unrefreshable_access_tokens( .await .into_mas("writing deviceless compat sessions")?; - Ok(()) + Ok(state) } /// Migrates (access token, refresh token) pairs. @@ -592,12 +569,10 @@ async fn migrate_unrefreshable_access_tokens( async fn migrate_refreshable_token_pairs( synapse: &mut SynapseReader<'_>, mas: &mut MasWriter, - server_name: &str, clock: &dyn Clock, rng: &mut impl RngCore, - user_localparts_to_uuid: &HashMap, - devices: &mut HashMap<(Uuid, CompactString), Uuid>, -) -> Result<(), Error> { + mut state: MigrationState, +) -> Result { let mut token_stream = pin!(synapse.read_refreshable_token_pairs()); let mut access_token_write_buffer = MasWriteBuffer::new(MasWriter::write_compat_access_tokens); let mut refresh_token_write_buffer = @@ -614,10 +589,10 @@ async fn migrate_refreshable_token_pairs( } = token_res.into_synapse("reading Synapse refresh token")?; let username = synapse_user_id - .extract_localpart(server_name) + .extract_localpart(&state.server_name) .into_extract_localpart(synapse_user_id.clone())? .to_owned(); - let Some(user_id) = user_localparts_to_uuid.get(username.as_str()).copied() else { + let Some(user_infos) = state.users.get(username.as_str()).copied() else { if is_likely_appservice(&username) { continue; } @@ -627,20 +602,24 @@ async fn migrate_refreshable_token_pairs( }); }; + if user_infos.flags.is_deactivated() || user_infos.flags.is_guest() { + continue; + } + // It's not always accurate, but last_validated is *often* the creation time of // the device If we don't have one, then use the current time as a // fallback. let created_at = last_validated.map_or_else(|| clock.now(), DateTime::from); // Use the existing device_id if this is the second token for a device - let session_id = *devices - .entry((user_id, CompactString::new(&device_id))) + let session_id = *state + .devices_to_compat_sessions + .entry((user_infos.mas_user_id, CompactString::new(&device_id))) .or_insert_with(|| Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng))); let access_token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)); let refresh_token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)); - // TODO skip access tokens for deactivated users access_token_write_buffer .write( mas, @@ -679,7 +658,7 @@ async fn migrate_refreshable_token_pairs( .await .into_mas("writing compat refresh tokens")?; - Ok(()) + Ok(state) } fn transform_user( diff --git a/crates/syn2mas/src/synapse_reader/mod.rs b/crates/syn2mas/src/synapse_reader/mod.rs index 5d5d3303f..fb145af7c 100644 --- a/crates/syn2mas/src/synapse_reader/mod.rs +++ b/crates/syn2mas/src/synapse_reader/mod.rs @@ -262,8 +262,8 @@ const TABLES_TO_LOCK: &[&str] = &[ /// Used to estimate progress. #[derive(Clone, Debug)] pub struct SynapseRowCounts { - pub users: i64, - pub devices: i64, + pub users: usize, + pub devices: usize, } pub struct SynapseReader<'c> { @@ -334,7 +334,7 @@ impl<'conn> SynapseReader<'conn> { /// /// - An underlying database error pub async fn count_rows(&mut self) -> Result { - let users: i64 = sqlx::query_scalar( + let users: usize = sqlx::query_scalar::<_, i64>( " SELECT COUNT(1) FROM users WHERE appservice_id IS NULL @@ -342,9 +342,12 @@ impl<'conn> SynapseReader<'conn> { ) .fetch_one(&mut *self.txn) .await - .into_database("counting Synapse users")?; + .into_database("counting Synapse users")? + .max(0) + .try_into() + .unwrap_or(usize::MAX); - let devices = sqlx::query_scalar( + let devices = sqlx::query_scalar::<_, i64>( " SELECT COUNT(1) FROM devices WHERE NOT hidden @@ -352,7 +355,10 @@ impl<'conn> SynapseReader<'conn> { ) .fetch_one(&mut *self.txn) .await - .into_database("counting Synapse devices")?; + .into_database("counting Synapse devices")? + .max(0) + .try_into() + .unwrap_or(usize::MAX); Ok(SynapseRowCounts { users, devices }) }