syn2mas: add a buffered channel for writing refreshable tokens

This commit is contained in:
Quentin Gliech
2025-04-22 13:48:47 +02:00
parent c292da7ac9
commit b21748c2bd

View File

@@ -11,7 +11,7 @@
//! This module does not implement any of the safety checks that should be run
//! *before* the migration.
use std::{pin::pin, time::Instant};
use std::time::Instant;
use chrono::{DateTime, Utc};
use compact_str::CompactString;
@@ -847,99 +847,127 @@ async fn migrate_refreshable_token_pairs(
) -> Result<(MasWriter, MigrationState), Error> {
let start = Instant::now();
let mut token_stream = pin!(synapse.read_refreshable_token_pairs());
let mut access_token_write_buffer =
MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
let mut refresh_token_write_buffer =
MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens);
let (tx, mut rx) = tokio::sync::mpsc::channel::<SynapseRefreshableTokenPair>(10 * 1024 * 1024);
while let Some(token_res) = token_stream.next().await {
let SynapseRefreshableTokenPair {
user_id: synapse_user_id,
device_id,
access_token,
refresh_token,
valid_until_ms,
last_validated,
} = token_res.into_synapse("reading Synapse refresh token")?;
// create a new RNG seeded from the passed RNG so that we can move it into the
// spawned task
let mut rng = rand_chacha::ChaChaRng::from_rng(rng).expect("failed to seed rng");
let now = clock.now();
let task = tokio::spawn(
async move {
let mut access_token_write_buffer =
MasWriteBuffer::new(&mas, MasWriter::write_compat_access_tokens);
let mut refresh_token_write_buffer =
MasWriteBuffer::new(&mas, MasWriter::write_compat_refresh_tokens);
let username = synapse_user_id
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "refresh_tokens".to_owned(),
user: synapse_user_id,
});
};
let Some(mas_user_id) = user_infos.mas_user_id else {
progress_counter.increment_skipped();
continue;
};
if user_infos.flags.is_deactivated()
|| user_infos.flags.is_guest()
|| user_infos.flags.is_appservice()
{
progress_counter.increment_skipped();
continue;
}
// It's not always accurate, but last_validated is *often* the creation time of
// the device If we don't have one, then use the current time as a
// fallback.
let created_at = last_validated.map_or_else(|| clock.now(), DateTime::from);
// Use the existing device_id if this is the second token for a device
let session_id = *state
.devices_to_compat_sessions
.entry((mas_user_id, CompactString::new(&device_id)))
.or_insert_with(|| Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng)));
let access_token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
let refresh_token_id = Uuid::from(Ulid::from_datetime_with_source(created_at.into(), rng));
access_token_write_buffer
.write(
&mut mas,
MasNewCompatAccessToken {
token_id: access_token_id,
session_id,
while let Some(token) = rx.recv().await {
let SynapseRefreshableTokenPair {
user_id: synapse_user_id,
device_id,
access_token,
created_at,
expires_at: valid_until_ms.map(DateTime::from),
},
)
.await
.into_mas("writing compat access tokens")?;
refresh_token_write_buffer
.write(
&mut mas,
MasNewCompatRefreshToken {
refresh_token_id,
session_id,
access_token_id,
refresh_token,
created_at,
},
)
.await
.into_mas("writing compat refresh tokens")?;
valid_until_ms,
last_validated,
} = token;
progress_counter.increment_migrated();
}
let username = synapse_user_id
.extract_localpart(&state.server_name)
.into_extract_localpart(synapse_user_id.clone())?
.to_owned();
let Some(user_infos) = state.users.get(username.as_str()).copied() else {
return Err(Error::MissingUserFromDependentTable {
table: "refresh_tokens".to_owned(),
user: synapse_user_id,
});
};
access_token_write_buffer
.finish(&mut mas)
.await
.into_mas("writing compat access tokens")?;
let Some(mas_user_id) = user_infos.mas_user_id else {
progress_counter.increment_skipped();
continue;
};
refresh_token_write_buffer
.finish(&mut mas)
.await
.into_mas("writing compat refresh tokens")?;
if user_infos.flags.is_deactivated()
|| user_infos.flags.is_guest()
|| user_infos.flags.is_appservice()
{
progress_counter.increment_skipped();
continue;
}
// It's not always accurate, but last_validated is *often* the creation time of
// the device If we don't have one, then use the current time as a
// fallback.
let created_at = last_validated.map_or_else(|| now, DateTime::from);
// Use the existing device_id if this is the second token for a device
let session_id = *state
.devices_to_compat_sessions
.entry((mas_user_id, CompactString::new(&device_id)))
.or_insert_with(|| {
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng))
});
let access_token_id =
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
let refresh_token_id =
Uuid::from(Ulid::from_datetime_with_source(created_at.into(), &mut rng));
access_token_write_buffer
.write(
&mut mas,
MasNewCompatAccessToken {
token_id: access_token_id,
session_id,
access_token,
created_at,
expires_at: valid_until_ms.map(DateTime::from),
},
)
.await
.into_mas("writing compat access tokens")?;
refresh_token_write_buffer
.write(
&mut mas,
MasNewCompatRefreshToken {
refresh_token_id,
session_id,
access_token_id,
refresh_token,
created_at,
},
)
.await
.into_mas("writing compat refresh tokens")?;
progress_counter.increment_migrated();
}
access_token_write_buffer
.finish(&mut mas)
.await
.into_mas("writing compat access tokens")?;
refresh_token_write_buffer
.finish(&mut mas)
.await
.into_mas("writing compat refresh tokens")?;
Ok((mas, state))
}
.instrument(tracing::info_span!("ingest_task")),
);
// In case this has an error, we still want to join the task, so we look at the
// error later
let res = synapse
.read_refreshable_token_pairs()
.map_err(|e| e.into_synapse("reading refresh token pairs"))
.forward(PollSender::new(tx).sink_map_err(|_| Error::ChannelClosed))
.inspect_err(|e| tracing::error!(error = e as &dyn std::error::Error))
.await;
let (mas, state) = task.await.into_join("refresh token write task")??;
res?;
info!(
"refreshable token pairs migrated in {:.1}s",