From 47009a8800b85fe9403a0908900cadddad54b5ef Mon Sep 17 00:00:00 2001 From: Quentin Gliech Date: Tue, 22 Apr 2025 15:49:17 +0200 Subject: [PATCH] syn2mas: make the MasWriteBuffer use the WriteBatch trait --- crates/syn2mas/src/mas_writer/mod.rs | 25 ++++++++++++------------- crates/syn2mas/src/migration.rs | 24 ++++++++++-------------- 2 files changed, 22 insertions(+), 27 deletions(-) diff --git a/crates/syn2mas/src/mas_writer/mod.rs b/crates/syn2mas/src/mas_writer/mod.rs index 0012ef04a..257e07a11 100644 --- a/crates/syn2mas/src/mas_writer/mod.rs +++ b/crates/syn2mas/src/mas_writer/mod.rs @@ -114,7 +114,6 @@ impl WriterConnectionPool { where F: for<'conn> FnOnce(&'conn mut PgConnection) -> BoxFuture<'conn, Result<(), Error>> + Send - + Sync + 'static, { match self.connection_rx.recv().await { @@ -250,11 +249,11 @@ pub struct MasWriter { write_buffer_finish_checker: FinishChecker, } -trait WriteBatch: Sized { +pub trait WriteBatch: Send + Sync + Sized + 'static { fn write_batch( conn: &mut PgConnection, batch: Vec, - ) -> impl Future>; + ) -> impl Future> + Send; } pub struct MasNewUser { @@ -1167,24 +1166,20 @@ impl MasWriter { // database. const WRITE_BUFFER_BATCH_SIZE: usize = 4096; -/// A function that can accept and flush buffers from a `MasWriteBuffer`. -/// Intended uses are the methods on `MasWriter` such as `write_users`. -type WriteBufferFlusher = - for<'a> fn(&'a mut MasWriter, Vec) -> BoxFuture<'a, Result<(), Error>>; - /// A buffer for writing rows to the MAS database. /// Generic over the type of rows. pub struct MasWriteBuffer { rows: Vec, - flusher: WriteBufferFlusher, finish_checker_handle: FinishCheckerHandle, } -impl MasWriteBuffer { - pub fn new(writer: &MasWriter, flusher: WriteBufferFlusher) -> Self { +impl MasWriteBuffer +where + T: WriteBatch, +{ + pub fn new(writer: &MasWriter) -> Self { MasWriteBuffer { rows: Vec::with_capacity(WRITE_BUFFER_BATCH_SIZE), - flusher, finish_checker_handle: writer.write_buffer_finish_checker.handle(), } } @@ -1201,7 +1196,11 @@ impl MasWriteBuffer { } let rows = std::mem::take(&mut self.rows); self.rows.reserve_exact(WRITE_BUFFER_BATCH_SIZE); - (self.flusher)(writer, rows).await?; + writer + .writer_pool + .spawn_with_connection(move |conn| T::write_batch(conn, rows).boxed()) + .boxed() + .await?; Ok(()) } diff --git a/crates/syn2mas/src/migration.rs b/crates/syn2mas/src/migration.rs index 3388ff387..7b34baf62 100644 --- a/crates/syn2mas/src/migration.rs +++ b/crates/syn2mas/src/migration.rs @@ -220,8 +220,8 @@ async fn migrate_users( let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng"); let task = tokio::spawn( async move { - let mut user_buffer = MasWriteBuffer::new(&mas, MasWriter::write_users); - let mut password_buffer = MasWriteBuffer::new(&mas, MasWriter::write_passwords); + let mut user_buffer = MasWriteBuffer::new(&mas); + let mut password_buffer = MasWriteBuffer::new(&mas); while let Some(user) = rx.recv().await { // Handling an edge case: some AS users may have invalid localparts containing @@ -342,9 +342,8 @@ async fn migrate_threepids( let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng"); let task = tokio::spawn( async move { - let mut email_buffer = MasWriteBuffer::new(&mas, MasWriter::write_email_threepids); - let mut unsupported_buffer = - MasWriteBuffer::new(&mas, MasWriter::write_unsupported_threepids); + let mut email_buffer = MasWriteBuffer::new(&mas); + let mut unsupported_buffer = MasWriteBuffer::new(&mas); while let Some(threepid) = rx.recv().await { let SynapseThreepid { @@ -457,7 +456,7 @@ async fn migrate_external_ids( let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng"); let task = tokio::spawn( async move { - let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_upstream_oauth_links); + let mut write_buffer = MasWriteBuffer::new(&mas); while let Some(extid) = rx.recv().await { let SynapseExternalId { @@ -569,7 +568,7 @@ async fn migrate_devices( let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng"); let task = tokio::spawn( async move { - let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions); + let mut write_buffer = MasWriteBuffer::new(&mas); while let Some(device) = rx.recv().await { let SynapseDevice { @@ -704,9 +703,8 @@ async fn migrate_unrefreshable_access_tokens( let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng"); let task = tokio::spawn( async move { - let mut write_buffer = MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens); - let mut deviceless_session_write_buffer = - MasWriteBuffer::new(&mas, MasWriter::write_compat_sessions); + let mut write_buffer = MasWriteBuffer::new(&mas); + let mut deviceless_session_write_buffer = MasWriteBuffer::new(&mas); while let Some(token) = rx.recv().await { let SynapseAccessToken { @@ -855,10 +853,8 @@ async fn migrate_refreshable_token_pairs( let now = clock.now(); let task = tokio::spawn( async move { - let mut access_token_write_buffer = - MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens); - let mut refresh_token_write_buffer = - MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens); + let mut access_token_write_buffer = MasWriteBuffer::new(&mas); + let mut refresh_token_write_buffer = MasWriteBuffer::new(&mas); while let Some(token) = rx.recv().await { let SynapseRefreshableTokenPair {