diff --git a/crates/syn2mas/src/migration.rs b/crates/syn2mas/src/migration.rs index fb2ea5487..3388ff387 100644 --- a/crates/syn2mas/src/migration.rs +++ b/crates/syn2mas/src/migration.rs @@ -11,7 +11,7 @@ //! This module does not implement any of the safety checks that should be run //! *before* the migration. -use std::{pin::pin, time::Instant}; +use std::time::Instant; use chrono::{DateTime, Utc}; use compact_str::CompactString; @@ -847,99 +847,127 @@ async fn migrate_refreshable_token_pairs( ) -> Result<(MasWriter, MigrationState), Error> { let start = Instant::now(); - let mut token_stream = pin!(synapse.read_refreshable_token_pairs()); - let mut access_token_write_buffer = - MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens); - let mut refresh_token_write_buffer = - MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens); + let (tx, mut rx) = tokio::sync::mpsc::channel::(10 * 1024 * 1024); - while let Some(token_res) = token_stream.next().await { - let SynapseRefreshableTokenPair { - user_id: synapse_user_id, - device_id, - access_token, - refresh_token, - valid_until_ms, - last_validated, - } = token_res.into_synapse("reading Synapse refresh token")?; + // create a new RNG seeded from the passed RNG so that we can move it into the + // spawned task + let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng"); + let now = clock.now(); + let task = tokio::spawn( + async move { + let mut access_token_write_buffer = + MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens); + let mut refresh_token_write_buffer = + MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens); - let username = synapse_user_id - .extract_localpart(&state.server_name) - .into_extract_localpart(synapse_user_id.clone())? - .to_owned(); - let Some(user_infos) = state.users.get(username.as_str()).copied() else { - return Err(Error::MissingUserFromDependentTable { - table: "refresh_tokens".to_owned(), - user: synapse_user_id, - }); - }; - - let Some(mas_user_id) = user_infos.mas_user_id else { - progress_counter.increment_skipped(); - continue; - }; - - if user_infos.flags.is_deactivated() - || user_infos.flags.is_guest() - || user_infos.flags.is_appservice() - { - progress_counter.increment_skipped(); - continue; - } - - // It's not always accurate, but last_validated is *often* the creation time of - // the device If we don't have one, then use the current time as a - // fallback. - let created_at = last_validated.map_or_else(|| clock.now(), DateTime::from); - - // Use the existing device_id if this is the second token for a device - let session_id = *state - .devices_to_compat_sessions - .entry((mas_user_id, CompactString::new(&device_id))) - .or_insert_with(|| Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng))); - - let access_token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)); - let refresh_token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)); - - access_token_write_buffer - .write( - &mut mas, - MasNewCompatAccessToken { - token_id: access_token_id, - session_id, + while let Some(token) = rx.recv().await { + let SynapseRefreshableTokenPair { + user_id: synapse_user_id, + device_id, access_token, - created_at, - expires_at: valid_until_ms.map(DateTime::from), - }, - ) - .await - .into_mas("writing compat access tokens")?; - refresh_token_write_buffer - .write( - &mut mas, - MasNewCompatRefreshToken { - refresh_token_id, - session_id, - access_token_id, refresh_token, - created_at, - }, - ) - .await - .into_mas("writing compat refresh tokens")?; + valid_until_ms, + last_validated, + } = token; - progress_counter.increment_migrated(); - } + let username = synapse_user_id + .extract_localpart(&state.server_name) + .into_extract_localpart(synapse_user_id.clone())? + .to_owned(); + let Some(user_infos) = state.users.get(username.as_str()).copied() else { + return Err(Error::MissingUserFromDependentTable { + table: "refresh_tokens".to_owned(), + user: synapse_user_id, + }); + }; - access_token_write_buffer - .finish(&mut mas) - .await - .into_mas("writing compat access tokens")?; + let Some(mas_user_id) = user_infos.mas_user_id else { + progress_counter.increment_skipped(); + continue; + }; - refresh_token_write_buffer - .finish(&mut mas) - .await - .into_mas("writing compat refresh tokens")?; + if user_infos.flags.is_deactivated() + || user_infos.flags.is_guest() + || user_infos.flags.is_appservice() + { + progress_counter.increment_skipped(); + continue; + } + + // It's not always accurate, but last_validated is *often* the creation time of + // the device If we don't have one, then use the current time as a + // fallback. + let created_at = last_validated.map_or_else(|| now, DateTime::from); + + // Use the existing device_id if this is the second token for a device + let session_id = *state + .devices_to_compat_sessions + .entry((mas_user_id, CompactString::new(&device_id))) + .or_insert_with(|| { + Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng)) + }); + + let access_token_id = + Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng)); + let refresh_token_id = + Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng)); + + access_token_write_buffer + .write( + &mut mas, + MasNewCompatAccessToken { + token_id: access_token_id, + session_id, + access_token, + created_at, + expires_at: valid_until_ms.map(DateTime::from), + }, + ) + .await + .into_mas("writing compat access tokens")?; + refresh_token_write_buffer + .write( + &mut mas, + MasNewCompatRefreshToken { + refresh_token_id, + session_id, + access_token_id, + refresh_token, + created_at, + }, + ) + .await + .into_mas("writing compat refresh tokens")?; + + progress_counter.increment_migrated(); + } + + access_token_write_buffer + .finish(&mut mas) + .await + .into_mas("writing compat access tokens")?; + + refresh_token_write_buffer + .finish(&mut mas) + .await + .into_mas("writing compat refresh tokens")?; + Ok((mas, state)) + } + .instrument(tracing::info_span!("ingest_task")), + ); + + // In case this has an error, we still want to join the task, so we look at the + // error later + let res = synapse + .read_refreshable_token_pairs() + .map_err(|e| e.into_synapse("reading refresh token pairs")) + .forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed)) + .inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error)) + .await; + + let (mas, state) = task.await.into_join("refresh token write task")??; + + res?; info!( "refreshable token pairs migrated in {:.1}s",