Actually consume jobs
This commit is contained in:
43
crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json
generated
Normal file
43
crates/storage-pg/.sqlx/query-9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061.json
generated
Normal file
@@ -0,0 +1,43 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n -- We first grab a few jobs that are available,\n -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently\n -- and we don't get multiple workers grabbing the same jobs\n WITH locked_jobs AS (\n SELECT queue_job_id\n FROM queue_jobs\n WHERE\n status = 'available'\n AND queue_name = ANY($1)\n ORDER BY queue_job_id ASC\n LIMIT $2\n FOR UPDATE\n SKIP LOCKED\n )\n -- then we update the status of those jobs to 'running', returning the job details\n UPDATE queue_jobs\n SET status = 'running', started_at = $3, started_by = $4\n FROM locked_jobs\n WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id\n RETURNING\n queue_jobs.queue_job_id,\n queue_jobs.queue_name,\n queue_jobs.payload,\n queue_jobs.metadata\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "queue_job_id",
|
||||
"type_info": "Uuid"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "queue_name",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "payload",
|
||||
"type_info": "Jsonb"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "metadata",
|
||||
"type_info": "Jsonb"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"TextArray",
|
||||
"Int8",
|
||||
"Timestamptz",
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
},
|
||||
"hash": "9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061"
|
||||
}
|
||||
15
crates/storage-pg/.sqlx/query-a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json
generated
Normal file
15
crates/storage-pg/.sqlx/query-a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27.json
generated
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"db_name": "PostgreSQL",
|
||||
"query": "\n UPDATE queue_jobs\n SET status = 'completed', completed_at = $1\n WHERE queue_job_id = $2 AND status = 'running'\n ",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Timestamptz",
|
||||
"Uuid"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
},
|
||||
"hash": "a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27"
|
||||
}
|
||||
@@ -7,13 +7,16 @@
|
||||
//! [`QueueJobRepository`].
|
||||
|
||||
use async_trait::async_trait;
|
||||
use mas_storage::{queue::QueueJobRepository, Clock};
|
||||
use mas_storage::{
|
||||
queue::{Job, QueueJobRepository, Worker},
|
||||
Clock,
|
||||
};
|
||||
use rand::RngCore;
|
||||
use sqlx::PgConnection;
|
||||
use ulid::Ulid;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{DatabaseError, ExecuteExt};
|
||||
use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt};
|
||||
|
||||
/// An implementation of [`QueueJobRepository`] for a PostgreSQL connection.
|
||||
pub struct PgQueueJobRepository<'c> {
|
||||
@@ -29,6 +32,37 @@ impl<'c> PgQueueJobRepository<'c> {
|
||||
}
|
||||
}
|
||||
|
||||
/// Row shape produced by the job reservation query (the `RETURNING` clause
/// of the `UPDATE queue_jobs` statement), before conversion into a `Job`.
struct JobReservationResult {
|
||||
/// Database identifier of the job; converted into the job ID.
queue_job_id: Uuid,
|
||||
/// Name of the queue the job belongs to.
queue_name: String,
|
||||
/// Job payload, kept as raw JSON.
payload: serde_json::Value,
|
||||
/// Job metadata as raw JSON; deserialized when converting into a `Job`.
metadata: serde_json::Value,
|
||||
}
|
||||
|
||||
impl TryFrom<JobReservationResult> for Job {
|
||||
type Error = DatabaseInconsistencyError;
|
||||
|
||||
fn try_from(value: JobReservationResult) -> Result<Self, Self::Error> {
|
||||
let id = value.queue_job_id.into();
|
||||
let queue_name = value.queue_name;
|
||||
let payload = value.payload;
|
||||
|
||||
let metadata = serde_json::from_value(value.metadata).map_err(|e| {
|
||||
DatabaseInconsistencyError::on("queue_jobs")
|
||||
.column("metadata")
|
||||
.row(id)
|
||||
.source(e)
|
||||
})?;
|
||||
|
||||
Ok(Self {
|
||||
id,
|
||||
queue_name,
|
||||
payload,
|
||||
metadata,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
impl QueueJobRepository for PgQueueJobRepository<'_> {
|
||||
type Error = DatabaseError;
|
||||
@@ -73,4 +107,96 @@ impl QueueJobRepository for PgQueueJobRepository<'_> {
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.queue_job.reserve",
|
||||
skip_all,
|
||||
fields(
|
||||
db.query.text,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn reserve(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
worker: &Worker,
|
||||
queues: &[&str],
|
||||
count: usize,
|
||||
) -> Result<Vec<Job>, Self::Error> {
|
||||
let now = clock.now();
|
||||
let max_count = i64::try_from(count).unwrap_or(i64::MAX);
|
||||
let queues: Vec<String> = queues.iter().map(|&s| s.to_owned()).collect();
|
||||
let results = sqlx::query_as!(
|
||||
JobReservationResult,
|
||||
r#"
|
||||
-- We first grab a few jobs that are available,
|
||||
-- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently
|
||||
-- and we don't get multiple workers grabbing the same jobs
|
||||
WITH locked_jobs AS (
|
||||
SELECT queue_job_id
|
||||
FROM queue_jobs
|
||||
WHERE
|
||||
status = 'available'
|
||||
AND queue_name = ANY($1)
|
||||
ORDER BY queue_job_id ASC
|
||||
LIMIT $2
|
||||
FOR UPDATE
|
||||
SKIP LOCKED
|
||||
)
|
||||
-- then we update the status of those jobs to 'running', returning the job details
|
||||
UPDATE queue_jobs
|
||||
SET status = 'running', started_at = $3, started_by = $4
|
||||
FROM locked_jobs
|
||||
WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id
|
||||
RETURNING
|
||||
queue_jobs.queue_job_id,
|
||||
queue_jobs.queue_name,
|
||||
queue_jobs.payload,
|
||||
queue_jobs.metadata
|
||||
"#,
|
||||
&queues,
|
||||
max_count,
|
||||
now,
|
||||
Uuid::from(worker.id),
|
||||
)
|
||||
.traced()
|
||||
.fetch_all(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
let jobs = results
|
||||
.into_iter()
|
||||
.map(TryFrom::try_from)
|
||||
.collect::<Result<Vec<_>, _>>()?;
|
||||
|
||||
Ok(jobs)
|
||||
}
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "db.queue_job.mark_as_completed",
|
||||
skip_all,
|
||||
fields(
|
||||
db.query.text,
|
||||
job.id = %id,
|
||||
),
|
||||
err,
|
||||
)]
|
||||
async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error> {
|
||||
let now = clock.now();
|
||||
let res = sqlx::query!(
|
||||
r#"
|
||||
UPDATE queue_jobs
|
||||
SET status = 'completed', completed_at = $1
|
||||
WHERE queue_job_id = $2 AND status = 'running'
|
||||
"#,
|
||||
now,
|
||||
Uuid::from(id),
|
||||
)
|
||||
.traced()
|
||||
.execute(&mut *self.conn)
|
||||
.await?;
|
||||
|
||||
DatabaseError::ensure_affected_rows(&res, 1)?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
use ulid::Ulid;
|
||||
|
||||
use super::Worker;
|
||||
use crate::{repository_impl, Clock};
|
||||
|
||||
/// Represents a job in the job queue
|
||||
@@ -19,6 +20,9 @@ pub struct Job {
|
||||
/// The ID of the job
|
||||
pub id: Ulid,
|
||||
|
||||
/// The queue on which the job was placed
|
||||
pub queue_name: String,
|
||||
|
||||
/// The payload of the job
|
||||
pub payload: serde_json::Value,
|
||||
|
||||
@@ -27,7 +31,7 @@ pub struct Job {
|
||||
}
|
||||
|
||||
/// Metadata stored alongside the job
|
||||
#[derive(Serialize, Deserialize, Default)]
|
||||
#[derive(Serialize, Deserialize, Default, Clone, Debug)]
|
||||
pub struct JobMetadata {
|
||||
#[serde(default)]
|
||||
trace_id: String,
|
||||
@@ -97,6 +101,38 @@ pub trait QueueJobRepository: Send + Sync {
|
||||
payload: serde_json::Value,
|
||||
metadata: serde_json::Value,
|
||||
) -> Result<(), Self::Error>;
|
||||
|
||||
/// Reserve multiple jobs from multiple queues
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `clock` - The clock used to generate timestamps
|
||||
/// * `worker` - The worker that is reserving the jobs
|
||||
/// * `queues` - The queues to reserve jobs from
|
||||
/// * `count` - The number of jobs to reserve
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the underlying repository fails.
|
||||
async fn reserve(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
worker: &Worker,
|
||||
queues: &[&str],
|
||||
count: usize,
|
||||
) -> Result<Vec<Job>, Self::Error>;
|
||||
|
||||
/// Mark a job as completed
|
||||
///
|
||||
/// # Parameters
|
||||
///
|
||||
/// * `clock` - The clock used to generate timestamps
|
||||
/// * `id` - The ID of the job to mark as completed
|
||||
///
|
||||
/// # Errors
|
||||
///
|
||||
/// Returns an error if the underlying repository fails.
|
||||
async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>;
|
||||
}
|
||||
|
||||
repository_impl!(QueueJobRepository:
|
||||
@@ -108,6 +144,16 @@ repository_impl!(QueueJobRepository:
|
||||
payload: serde_json::Value,
|
||||
metadata: serde_json::Value,
|
||||
) -> Result<(), Self::Error>;
|
||||
|
||||
async fn reserve(
|
||||
&mut self,
|
||||
clock: &dyn Clock,
|
||||
worker: &Worker,
|
||||
queues: &[&str],
|
||||
count: usize,
|
||||
) -> Result<Vec<Job>, Self::Error>;
|
||||
|
||||
async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>;
|
||||
);
|
||||
|
||||
/// Extension trait for [`QueueJobRepository`] to help adding a job to the queue
|
||||
|
||||
@@ -5,97 +5,87 @@
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use anyhow::Context;
|
||||
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
|
||||
use async_trait::async_trait;
|
||||
use chrono::Duration;
|
||||
use mas_email::{Address, Mailbox};
|
||||
use mas_i18n::locale;
|
||||
use mas_storage::{job::JobWithSpanContext, queue::VerifyEmailJob};
|
||||
use mas_storage::queue::VerifyEmailJob;
|
||||
use mas_templates::{EmailVerificationContext, TemplateContext};
|
||||
use rand::{distributions::Uniform, Rng};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
|
||||
use crate::{
|
||||
new_queue::{JobContext, RunnableJob},
|
||||
State,
|
||||
};
|
||||
|
||||
#[tracing::instrument(
|
||||
name = "job.verify_email",
|
||||
fields(user_email.id = %job.user_email_id()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn verify_email(
|
||||
job: JobWithSpanContext<VerifyEmailJob>,
|
||||
ctx: JobContext,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let state = ctx.state();
|
||||
let mut repo = state.repository().await?;
|
||||
let mut rng = state.rng();
|
||||
let mailer = state.mailer();
|
||||
let clock = state.clock();
|
||||
#[async_trait]
|
||||
impl RunnableJob for VerifyEmailJob {
|
||||
#[tracing::instrument(
|
||||
name = "job.verify_email",
|
||||
fields(user_email.id = %self.user_email_id()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
|
||||
let mut repo = state.repository().await?;
|
||||
let mut rng = state.rng();
|
||||
let mailer = state.mailer();
|
||||
let clock = state.clock();
|
||||
|
||||
let language = job
|
||||
.language()
|
||||
.and_then(|l| l.parse().ok())
|
||||
.unwrap_or(locale!("en").into());
|
||||
let language = self
|
||||
.language()
|
||||
.and_then(|l| l.parse().ok())
|
||||
.unwrap_or(locale!("en").into());
|
||||
|
||||
// Lookup the user email
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.lookup(job.user_email_id())
|
||||
.await?
|
||||
.context("User email not found")?;
|
||||
// Lookup the user email
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.lookup(self.user_email_id())
|
||||
.await?
|
||||
.context("User email not found")?;
|
||||
|
||||
// Lookup the user associated with the email
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(user_email.user_id)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
// Lookup the user associated with the email
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(user_email.user_id)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
// Generate a verification code
|
||||
let range = Uniform::<u32>::from(0..1_000_000);
|
||||
let code = rng.sample(range);
|
||||
let code = format!("{code:06}");
|
||||
// Generate a verification code
|
||||
let range = Uniform::<u32>::from(0..1_000_000);
|
||||
let code = rng.sample(range);
|
||||
let code = format!("{code:06}");
|
||||
|
||||
let address: Address = user_email.email.parse()?;
|
||||
let address: Address = user_email.email.parse()?;
|
||||
|
||||
// Save the verification code in the database
|
||||
let verification = repo
|
||||
.user_email()
|
||||
.add_verification_code(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&user_email,
|
||||
Duration::try_hours(8).unwrap(),
|
||||
code,
|
||||
)
|
||||
.await?;
|
||||
// Save the verification code in the database
|
||||
let verification = repo
|
||||
.user_email()
|
||||
.add_verification_code(
|
||||
&mut rng,
|
||||
&clock,
|
||||
&user_email,
|
||||
Duration::try_hours(8).unwrap(),
|
||||
code,
|
||||
)
|
||||
.await?;
|
||||
|
||||
// And send the verification email
|
||||
let mailbox = Mailbox::new(Some(user.username.clone()), address);
|
||||
// And send the verification email
|
||||
let mailbox = Mailbox::new(Some(user.username.clone()), address);
|
||||
|
||||
let context =
|
||||
EmailVerificationContext::new(user.clone(), verification.clone()).with_language(language);
|
||||
let context = EmailVerificationContext::new(user.clone(), verification.clone())
|
||||
.with_language(language);
|
||||
|
||||
mailer.send_verification_email(mailbox, &context).await?;
|
||||
mailer.send_verification_email(mailbox, &context).await?;
|
||||
|
||||
info!(
|
||||
email.id = %user_email.id,
|
||||
"Verification email sent"
|
||||
);
|
||||
info!(
|
||||
email.id = %user_email.id,
|
||||
"Verification email sent"
|
||||
);
|
||||
|
||||
repo.save().await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn register(
|
||||
suffix: &str,
|
||||
monitor: Monitor<TokioExecutor>,
|
||||
state: &State,
|
||||
storage_factory: &PostgresStorageFactory,
|
||||
) -> Monitor<TokioExecutor> {
|
||||
let verify_email_worker =
|
||||
crate::build!(VerifyEmailJob => verify_email, suffix, state, storage_factory);
|
||||
|
||||
monitor.register(verify_email_worker)
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -18,14 +18,13 @@ use rand::SeedableRng;
|
||||
use sqlx::{Pool, Postgres};
|
||||
use tokio_util::{sync::CancellationToken, task::TaskTracker};
|
||||
|
||||
// TODO: we need to have a way to schedule recurring tasks
|
||||
// mod database;
|
||||
// mod email;
|
||||
// mod matrix;
|
||||
mod email;
|
||||
mod matrix;
|
||||
mod new_queue;
|
||||
// mod recovery;
|
||||
// mod storage;
|
||||
// mod user;
|
||||
// mod utils;
|
||||
mod recovery;
|
||||
mod user;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct State {
|
||||
@@ -111,6 +110,15 @@ pub async fn init(
|
||||
);
|
||||
let mut worker = self::new_queue::QueueWorker::new(state, cancellation_token).await?;
|
||||
|
||||
worker.register_handler::<mas_storage::queue::DeactivateUserJob>();
|
||||
worker.register_handler::<mas_storage::queue::DeleteDeviceJob>();
|
||||
worker.register_handler::<mas_storage::queue::ProvisionDeviceJob>();
|
||||
worker.register_handler::<mas_storage::queue::ProvisionUserJob>();
|
||||
worker.register_handler::<mas_storage::queue::ReactivateUserJob>();
|
||||
worker.register_handler::<mas_storage::queue::SendAccountRecoveryEmailsJob>();
|
||||
worker.register_handler::<mas_storage::queue::SyncDevicesJob>();
|
||||
worker.register_handler::<mas_storage::queue::VerifyEmailJob>();
|
||||
|
||||
task_tracker.spawn(async move {
|
||||
if let Err(e) = worker.run().await {
|
||||
tracing::error!(
|
||||
|
||||
@@ -7,239 +7,239 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use anyhow::Context;
|
||||
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
|
||||
use async_trait::async_trait;
|
||||
use mas_data_model::Device;
|
||||
use mas_matrix::ProvisionRequest;
|
||||
use mas_storage::{
|
||||
compat::CompatSessionFilter,
|
||||
job::{JobRepositoryExt as _, JobWithSpanContext},
|
||||
oauth2::OAuth2SessionFilter,
|
||||
queue::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, SyncDevicesJob},
|
||||
queue::{
|
||||
DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, QueueJobRepositoryExt as _,
|
||||
SyncDevicesJob,
|
||||
},
|
||||
user::{UserEmailRepository, UserRepository},
|
||||
Pagination, RepositoryAccess,
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
|
||||
use crate::{
|
||||
new_queue::{JobContext, RunnableJob},
|
||||
State,
|
||||
};
|
||||
|
||||
/// Job to provision a user on the Matrix homeserver.
|
||||
/// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id}
|
||||
/// endpoint.
|
||||
#[tracing::instrument(
|
||||
name = "job.provision_user"
|
||||
fields(user.id = %job.user_id()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn provision_user(
|
||||
job: JobWithSpanContext<ProvisionUserJob>,
|
||||
ctx: JobContext,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let state = ctx.state();
|
||||
let matrix = state.matrix_connection();
|
||||
let mut repo = state.repository().await?;
|
||||
/// This works by doing a PUT request to the
|
||||
/// /_synapse/admin/v2/users/{user_id} endpoint.
|
||||
#[async_trait]
|
||||
impl RunnableJob for ProvisionUserJob {
|
||||
#[tracing::instrument(
|
||||
name = "job.provision_user"
|
||||
fields(user.id = %self.user_id()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
|
||||
let matrix = state.matrix_connection();
|
||||
let mut repo = state.repository().await?;
|
||||
let mut rng = state.rng();
|
||||
let clock = state.clock();
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(job.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(self.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
let mxid = matrix.mxid(&user.username);
|
||||
let emails = repo
|
||||
.user_email()
|
||||
.all(&user)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(|email| email.confirmed_at.is_some())
|
||||
.map(|email| email.email)
|
||||
.collect();
|
||||
let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails);
|
||||
let mxid = matrix.mxid(&user.username);
|
||||
let emails = repo
|
||||
.user_email()
|
||||
.all(&user)
|
||||
.await?
|
||||
.into_iter()
|
||||
.filter(|email| email.confirmed_at.is_some())
|
||||
.map(|email| email.email)
|
||||
.collect();
|
||||
let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails);
|
||||
|
||||
if let Some(display_name) = job.display_name_to_set() {
|
||||
request = request.set_displayname(display_name.to_owned());
|
||||
if let Some(display_name) = self.display_name_to_set() {
|
||||
request = request.set_displayname(display_name.to_owned());
|
||||
}
|
||||
|
||||
let created = matrix.provision_user(&request).await?;
|
||||
|
||||
if created {
|
||||
info!(%user.id, %mxid, "User created");
|
||||
} else {
|
||||
info!(%user.id, %mxid, "User updated");
|
||||
}
|
||||
|
||||
// Schedule a device sync job
|
||||
let sync_device_job = SyncDevicesJob::new(&user);
|
||||
repo.queue_job()
|
||||
.schedule_job(&mut rng, &clock, sync_device_job)
|
||||
.await?;
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let created = matrix.provision_user(&request).await?;
|
||||
|
||||
if created {
|
||||
info!(%user.id, %mxid, "User created");
|
||||
} else {
|
||||
info!(%user.id, %mxid, "User updated");
|
||||
}
|
||||
|
||||
// Schedule a device sync job
|
||||
let sync_device_job = SyncDevicesJob::new(&user);
|
||||
repo.job().schedule_job(sync_device_job).await?;
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Job to provision a device on the Matrix homeserver.
|
||||
///
|
||||
/// This job is deprecated and therefore just schedules a [`SyncDevicesJob`]
|
||||
#[tracing::instrument(
|
||||
name = "job.provision_device"
|
||||
fields(
|
||||
user.id = %job.user_id(),
|
||||
device.id = %job.device_id(),
|
||||
),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn provision_device(
|
||||
job: JobWithSpanContext<ProvisionDeviceJob>,
|
||||
ctx: JobContext,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let state = ctx.state();
|
||||
let mut repo = state.repository().await?;
|
||||
#[async_trait]
|
||||
impl RunnableJob for ProvisionDeviceJob {
|
||||
#[tracing::instrument(
|
||||
name = "job.provision_device"
|
||||
fields(
|
||||
user.id = %self.user_id(),
|
||||
device.id = %self.device_id(),
|
||||
),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
|
||||
let mut repo = state.repository().await?;
|
||||
let mut rng = state.rng();
|
||||
let clock = state.clock();
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(job.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(self.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
// Schedule a device sync job
|
||||
repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;
|
||||
// Schedule a device sync job
|
||||
repo.queue_job()
|
||||
.schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Job to delete a device from a user's account.
|
||||
///
|
||||
/// This job is deprecated and therefore just schedules a [`SyncDevicesJob`]
|
||||
#[tracing::instrument(
|
||||
name = "job.delete_device"
|
||||
fields(
|
||||
user.id = %job.user_id(),
|
||||
device.id = %job.device_id(),
|
||||
),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn delete_device(
|
||||
job: JobWithSpanContext<DeleteDeviceJob>,
|
||||
ctx: JobContext,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let state = ctx.state();
|
||||
let mut repo = state.repository().await?;
|
||||
#[async_trait]
|
||||
impl RunnableJob for DeleteDeviceJob {
|
||||
#[tracing::instrument(
|
||||
name = "job.delete_device"
|
||||
fields(
|
||||
user.id = %self.user_id(),
|
||||
device.id = %self.device_id(),
|
||||
),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
#[tracing::instrument(
|
||||
name = "job.delete_device"
|
||||
fields(
|
||||
user.id = %self.user_id(),
|
||||
device.id = %self.device_id(),
|
||||
),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
|
||||
let mut rng = state.rng();
|
||||
let clock = state.clock();
|
||||
let mut repo = state.repository().await?;
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(job.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(self.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
// Schedule a device sync job
|
||||
repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;
|
||||
// Schedule a device sync job
|
||||
repo.queue_job()
|
||||
.schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
|
||||
.await?;
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Job to sync the list of devices of a user with the homeserver.
|
||||
#[tracing::instrument(
|
||||
name = "job.sync_devices",
|
||||
fields(user.id = %job.user_id()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn sync_devices(
|
||||
job: JobWithSpanContext<SyncDevicesJob>,
|
||||
ctx: JobContext,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let state = ctx.state();
|
||||
let matrix = state.matrix_connection();
|
||||
let mut repo = state.repository().await?;
|
||||
#[async_trait]
|
||||
impl RunnableJob for SyncDevicesJob {
|
||||
#[tracing::instrument(
|
||||
name = "job.sync_devices",
|
||||
fields(user.id = %self.user_id()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
|
||||
let matrix = state.matrix_connection();
|
||||
let mut repo = state.repository().await?;
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(job.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(self.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
// Lock the user sync to make sure we don't get into a race condition
|
||||
repo.user().acquire_lock_for_sync(&user).await?;
|
||||
// Lock the user sync to make sure we don't get into a race condition
|
||||
repo.user().acquire_lock_for_sync(&user).await?;
|
||||
|
||||
let mut devices = HashSet::new();
|
||||
let mut devices = HashSet::new();
|
||||
|
||||
// Cycle through all the compat sessions of the user, and grab the devices
|
||||
let mut cursor = Pagination::first(100);
|
||||
loop {
|
||||
let page = repo
|
||||
.compat_session()
|
||||
.list(
|
||||
CompatSessionFilter::new().for_user(&user).active_only(),
|
||||
cursor,
|
||||
)
|
||||
.await?;
|
||||
// Cycle through all the compat sessions of the user, and grab the devices
|
||||
let mut cursor = Pagination::first(100);
|
||||
loop {
|
||||
let page = repo
|
||||
.compat_session()
|
||||
.list(
|
||||
CompatSessionFilter::new().for_user(&user).active_only(),
|
||||
cursor,
|
||||
)
|
||||
.await?;
|
||||
|
||||
for (compat_session, _) in page.edges {
|
||||
devices.insert(compat_session.device.as_str().to_owned());
|
||||
cursor = cursor.after(compat_session.id);
|
||||
}
|
||||
|
||||
if !page.has_next_page {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Cycle though all the oauth2 sessions of the user, and grab the devices
|
||||
let mut cursor = Pagination::first(100);
|
||||
loop {
|
||||
let page = repo
|
||||
.oauth2_session()
|
||||
.list(
|
||||
OAuth2SessionFilter::new().for_user(&user).active_only(),
|
||||
cursor,
|
||||
)
|
||||
.await?;
|
||||
|
||||
for oauth2_session in page.edges {
|
||||
for scope in &*oauth2_session.scope {
|
||||
if let Some(device) = Device::from_scope_token(scope) {
|
||||
devices.insert(device.as_str().to_owned());
|
||||
}
|
||||
for (compat_session, _) in page.edges {
|
||||
devices.insert(compat_session.device.as_str().to_owned());
|
||||
cursor = cursor.after(compat_session.id);
|
||||
}
|
||||
|
||||
cursor = cursor.after(oauth2_session.id);
|
||||
if !page.has_next_page {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !page.has_next_page {
|
||||
break;
|
||||
// Cycle though all the oauth2 sessions of the user, and grab the devices
|
||||
let mut cursor = Pagination::first(100);
|
||||
loop {
|
||||
let page = repo
|
||||
.oauth2_session()
|
||||
.list(
|
||||
OAuth2SessionFilter::new().for_user(&user).active_only(),
|
||||
cursor,
|
||||
)
|
||||
.await?;
|
||||
|
||||
for oauth2_session in page.edges {
|
||||
for scope in &*oauth2_session.scope {
|
||||
if let Some(device) = Device::from_scope_token(scope) {
|
||||
devices.insert(device.as_str().to_owned());
|
||||
}
|
||||
}
|
||||
|
||||
cursor = cursor.after(oauth2_session.id);
|
||||
}
|
||||
|
||||
if !page.has_next_page {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
let mxid = matrix.mxid(&user.username);
|
||||
matrix.sync_devices(&mxid, devices).await?;
|
||||
|
||||
// We kept the connection until now, so that we still hold the lock on the user
|
||||
// throughout the sync
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
let mxid = matrix.mxid(&user.username);
|
||||
matrix.sync_devices(&mxid, devices).await?;
|
||||
|
||||
// We kept the connection until now, so that we still hold the lock on the user
|
||||
// throughout the sync
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn register(
|
||||
suffix: &str,
|
||||
monitor: Monitor<TokioExecutor>,
|
||||
state: &State,
|
||||
storage_factory: &PostgresStorageFactory,
|
||||
) -> Monitor<TokioExecutor> {
|
||||
let provision_user_worker =
|
||||
crate::build!(ProvisionUserJob => provision_user, suffix, state, storage_factory);
|
||||
let provision_device_worker =
|
||||
crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory);
|
||||
let delete_device_worker =
|
||||
crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory);
|
||||
let sync_devices_worker =
|
||||
crate::build!(SyncDevicesJob => sync_devices, suffix, state, storage_factory);
|
||||
|
||||
monitor
|
||||
.register(provision_user_worker)
|
||||
.register(provision_device_worker)
|
||||
.register(delete_device_worker)
|
||||
.register(sync_devices_worker)
|
||||
}
|
||||
|
||||
@@ -3,12 +3,12 @@
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::{collections::HashMap, sync::Arc};
|
||||
|
||||
use async_trait::async_trait;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
use mas_storage::{
|
||||
queue::{InsertableJob, Job, Worker},
|
||||
queue::{InsertableJob, Job, JobMetadata, Worker},
|
||||
Clock, RepositoryAccess, RepositoryError,
|
||||
};
|
||||
use mas_storage_pg::{DatabaseError, PgRepository};
|
||||
@@ -20,12 +20,42 @@ use sqlx::{
|
||||
Acquire, Either,
|
||||
};
|
||||
use thiserror::Error;
|
||||
use tokio::task::JoinSet;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tracing::{Instrument as _, Span};
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt as _;
|
||||
use ulid::Ulid;
|
||||
|
||||
use crate::State;
|
||||
|
||||
type JobPayload = serde_json::Value;
|
||||
|
||||
/// Per-job context passed to `RunnableJob::run` alongside the shared state.
#[derive(Clone)]
|
||||
pub struct JobContext {
|
||||
/// ID of the job; recorded as `job.id` on the job's span.
pub id: Ulid,
|
||||
/// Metadata stored alongside the job; its span context is linked to the
/// job's span.
pub metadata: JobMetadata,
|
||||
/// Name of the queue; recorded as `job.queue_name` on the job's span.
pub queue_name: String,
|
||||
/// Cancellation token made available to the job.
/// NOTE(review): presumably tied to worker shutdown — confirm where the
/// context is constructed.
pub cancellation_token: CancellationToken,
|
||||
}
|
||||
|
||||
impl JobContext {
|
||||
pub fn span(&self) -> Span {
|
||||
let span = tracing::info_span!(
|
||||
parent: Span::none(),
|
||||
"job.run",
|
||||
job.id = %self.id,
|
||||
job.queue_name = self.queue_name,
|
||||
job.attempt = self.attempt,
|
||||
);
|
||||
|
||||
span.add_link(self.metadata.span_context());
|
||||
|
||||
span
|
||||
}
|
||||
}
|
||||
|
||||
pub trait FromJob {
|
||||
fn from_job(job: &Job) -> Result<Self, anyhow::Error>
|
||||
fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
|
||||
where
|
||||
Self: Sized;
|
||||
}
|
||||
@@ -34,14 +64,14 @@ impl<T> FromJob for T
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
fn from_job(job: &Job) -> Result<Self, anyhow::Error> {
|
||||
serde_json::from_value(job.payload.clone()).map_err(Into::into)
|
||||
fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error> {
|
||||
serde_json::from_value(payload).map_err(Into::into)
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait]
|
||||
pub trait RunnableJob: FromJob + Send + 'static {
|
||||
async fn run(&self, state: &State) -> Result<(), anyhow::Error>;
|
||||
async fn run(&self, state: &State, context: JobContext) -> Result<(), anyhow::Error>;
|
||||
}
|
||||
|
||||
fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
|
||||
@@ -79,7 +109,13 @@ pub enum QueueRunnerError {
|
||||
const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
|
||||
const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);
|
||||
|
||||
type JobFactory = Box<dyn FnMut(&Job) -> Box<dyn RunnableJob> + Send>;
|
||||
// How many jobs can we run concurrently
|
||||
const MAX_CONCURRENT_JOBS: usize = 10;
|
||||
|
||||
// How many jobs can we fetch at once
|
||||
const MAX_JOBS_TO_FETCH: usize = 5;
|
||||
|
||||
type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;
|
||||
|
||||
pub struct QueueWorker {
|
||||
rng: ChaChaRng,
|
||||
@@ -89,7 +125,14 @@ pub struct QueueWorker {
|
||||
am_i_leader: bool,
|
||||
last_heartbeat: DateTime<Utc>,
|
||||
cancellation_token: CancellationToken,
|
||||
state: State,
|
||||
running_jobs: JoinSet<Result<(), anyhow::Error>>,
|
||||
job_contexts: HashMap<tokio::task::Id, JobContext>,
|
||||
factories: HashMap<&'static str, JobFactory>,
|
||||
|
||||
#[allow(clippy::type_complexity)]
|
||||
last_join_result:
|
||||
Option<Result<(tokio::task::Id, Result<(), anyhow::Error>), tokio::task::JoinError>>,
|
||||
}
|
||||
|
||||
impl QueueWorker {
|
||||
@@ -115,6 +158,12 @@ impl QueueWorker {
|
||||
.await
|
||||
.map_err(QueueRunnerError::SetupListener)?;
|
||||
|
||||
// We get notifications when a job is available on this channel
|
||||
listener
|
||||
.listen("queue_available")
|
||||
.await
|
||||
.map_err(QueueRunnerError::SetupListener)?;
|
||||
|
||||
let txn = listener
|
||||
.begin()
|
||||
.await
|
||||
@@ -139,14 +188,22 @@ impl QueueWorker {
|
||||
am_i_leader: false,
|
||||
last_heartbeat: now,
|
||||
cancellation_token,
|
||||
state,
|
||||
job_contexts: HashMap::new(),
|
||||
running_jobs: JoinSet::new(),
|
||||
factories: HashMap::new(),
|
||||
last_join_result: None,
|
||||
})
|
||||
}
|
||||
|
||||
pub fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
|
||||
// TODO: error handling
|
||||
let factory = |job: &Job| box_runnable_job(T::from_job(job).unwrap());
|
||||
self.factories.insert(T::QUEUE_NAME, Box::new(factory));
|
||||
// There is a potential panic here, which is fine as it's going to be caught
|
||||
// within the job task
|
||||
let factory = |payload: JobPayload| {
|
||||
box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
|
||||
};
|
||||
|
||||
self.factories.insert(T::QUEUE_NAME, Arc::new(factory));
|
||||
self
|
||||
}
|
||||
|
||||
@@ -164,6 +221,7 @@ impl QueueWorker {
|
||||
async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
|
||||
self.wait_until_wakeup().await?;
|
||||
|
||||
// TODO: join all the jobs handles when shutting down
|
||||
if self.cancellation_token.is_cancelled() {
|
||||
return Ok(());
|
||||
}
|
||||
@@ -214,6 +272,8 @@ impl QueueWorker {
|
||||
.sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
|
||||
let wakeup_sleep = tokio::time::sleep(sleep_duration);
|
||||
|
||||
// TODO: add metrics to track the wake up reasons
|
||||
|
||||
tokio::select! {
|
||||
() = self.cancellation_token.cancelled() => {
|
||||
tracing::debug!("Woke up from cancellation");
|
||||
@@ -223,6 +283,11 @@ impl QueueWorker {
|
||||
tracing::debug!("Woke up from sleep");
|
||||
},
|
||||
|
||||
Some(result) = self.running_jobs.join_next_with_id() => {
|
||||
tracing::debug!("Joined job task");
|
||||
self.last_join_result = Some(result);
|
||||
},
|
||||
|
||||
notification = self.listener.recv() => {
|
||||
match notification {
|
||||
Ok(notification) => {
|
||||
@@ -281,6 +346,127 @@ impl QueueWorker {
|
||||
.try_get_leader_lease(&self.clock, &self.registration)
|
||||
.await?;
|
||||
|
||||
// Find any job task which finished
|
||||
// If we got woken up by a join on the joinset, it will be stored in the
|
||||
// last_join_result so that we don't lose it
|
||||
|
||||
if self.last_join_result.is_none() {
|
||||
self.last_join_result = self.running_jobs.try_join_next_with_id();
|
||||
}
|
||||
|
||||
while let Some(result) = self.last_join_result.take() {
|
||||
// TODO: add metrics to track the job status and the time it took
|
||||
let context = match result {
|
||||
Ok((id, Ok(()))) => {
|
||||
// The job succeeded
|
||||
let context = self
|
||||
.job_contexts
|
||||
.remove(&id)
|
||||
.expect("Job context not found");
|
||||
|
||||
tracing::info!(
|
||||
job.id = %context.id,
|
||||
job.queue_name = %context.queue_name,
|
||||
"Job completed"
|
||||
);
|
||||
|
||||
context
|
||||
}
|
||||
Ok((id, Err(e))) => {
|
||||
// The job failed
|
||||
let context = self
|
||||
.job_contexts
|
||||
.remove(&id)
|
||||
.expect("Job context not found");
|
||||
|
||||
tracing::error!(
|
||||
error = ?e,
|
||||
job.id = %context.id,
|
||||
job.queue_name = %context.queue_name,
|
||||
"Job failed"
|
||||
);
|
||||
|
||||
// TODO: reschedule the job
|
||||
|
||||
context
|
||||
}
|
||||
Err(e) => {
|
||||
// The job crashed (or was cancelled)
|
||||
let id = e.id();
|
||||
let context = self
|
||||
.job_contexts
|
||||
.remove(&id)
|
||||
.expect("Job context not found");
|
||||
|
||||
tracing::error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
job.id = %context.id,
|
||||
job.queue_name = %context.queue_name,
|
||||
"Job crashed"
|
||||
);
|
||||
|
||||
// TODO: reschedule the job
|
||||
|
||||
context
|
||||
}
|
||||
};
|
||||
|
||||
repo.queue_job()
|
||||
.mark_as_completed(&self.clock, context.id)
|
||||
.await?;
|
||||
|
||||
self.last_join_result = self.running_jobs.try_join_next_with_id();
|
||||
}
|
||||
|
||||
// Compute how many jobs we should fetch at most
|
||||
let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
|
||||
.saturating_sub(self.running_jobs.len())
|
||||
.max(MAX_JOBS_TO_FETCH);
|
||||
|
||||
if max_jobs_to_fetch == 0 {
|
||||
tracing::warn!("Internal job queue is full, not fetching any new jobs");
|
||||
} else {
|
||||
// Grab a few jobs in the queue
|
||||
let queues = self.factories.keys().copied().collect::<Vec<_>>();
|
||||
let jobs = repo
|
||||
.queue_job()
|
||||
.reserve(&self.clock, &self.registration, &queues, max_jobs_to_fetch)
|
||||
.await?;
|
||||
|
||||
for Job {
|
||||
id,
|
||||
queue_name,
|
||||
payload,
|
||||
metadata,
|
||||
} in jobs
|
||||
{
|
||||
let cancellation_token = self.cancellation_token.child_token();
|
||||
let factory = self.factories.get(queue_name.as_str()).cloned();
|
||||
let context = JobContext {
|
||||
id,
|
||||
metadata,
|
||||
queue_name,
|
||||
cancellation_token,
|
||||
};
|
||||
|
||||
let task = {
|
||||
let context = context.clone();
|
||||
let span = context.span();
|
||||
let state = self.state.clone();
|
||||
async move {
|
||||
// We should never crash, but in case we do, we do that in the task and
|
||||
// don't crash the worker
|
||||
let job = factory.expect("unknown job factory")(payload);
|
||||
job.run(&state, context).await
|
||||
}
|
||||
.instrument(span)
|
||||
};
|
||||
|
||||
let handle = self.running_jobs.spawn(task);
|
||||
self.job_contexts.insert(handle.id(), context);
|
||||
}
|
||||
}
|
||||
|
||||
// After this point, we are locking the leader table, so it's important that we
|
||||
// commit as soon as possible to not block the other workers for too long
|
||||
repo.into_inner()
|
||||
@@ -353,6 +539,8 @@ impl QueueWorker {
|
||||
.shutdown_dead_workers(&self.clock, Duration::minutes(2))
|
||||
.await?;
|
||||
|
||||
// TODO: mark tasks those workers had as lost
|
||||
|
||||
// Release the leader lock
|
||||
let txn = repo
|
||||
.into_inner()
|
||||
|
||||
@@ -5,11 +5,10 @@
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use anyhow::Context;
|
||||
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
|
||||
use async_trait::async_trait;
|
||||
use mas_email::{Address, Mailbox};
|
||||
use mas_i18n::DataLocale;
|
||||
use mas_storage::{
|
||||
job::JobWithSpanContext,
|
||||
queue::SendAccountRecoveryEmailsJob,
|
||||
user::{UserEmailFilter, UserRecoveryRepository},
|
||||
Pagination, RepositoryAccess,
|
||||
@@ -18,117 +17,108 @@ use mas_templates::{EmailRecoveryContext, TemplateContext};
|
||||
use rand::distributions::{Alphanumeric, DistString};
|
||||
use tracing::{error, info};
|
||||
|
||||
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
|
||||
use crate::{
|
||||
new_queue::{JobContext, RunnableJob},
|
||||
State,
|
||||
};
|
||||
|
||||
/// Job to send account recovery emails for a given recovery session.
|
||||
#[tracing::instrument(
|
||||
name = "job.send_account_recovery_email",
|
||||
fields(
|
||||
user_recovery_session.id = %job.user_recovery_session_id(),
|
||||
user_recovery_session.email,
|
||||
),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn send_account_recovery_email_job(
|
||||
job: JobWithSpanContext<SendAccountRecoveryEmailsJob>,
|
||||
ctx: JobContext,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let state = ctx.state();
|
||||
let clock = state.clock();
|
||||
let mailer = state.mailer();
|
||||
let url_builder = state.url_builder();
|
||||
let mut rng = state.rng();
|
||||
let mut repo = state.repository().await?;
|
||||
#[async_trait]
|
||||
impl RunnableJob for SendAccountRecoveryEmailsJob {
|
||||
#[tracing::instrument(
|
||||
name = "job.send_account_recovery_email",
|
||||
fields(
|
||||
user_recovery_session.id = %self.user_recovery_session_id(),
|
||||
user_recovery_session.email,
|
||||
),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
|
||||
let clock = state.clock();
|
||||
let mailer = state.mailer();
|
||||
let url_builder = state.url_builder();
|
||||
let mut rng = state.rng();
|
||||
let mut repo = state.repository().await?;
|
||||
|
||||
let session = repo
|
||||
.user_recovery()
|
||||
.lookup_session(job.user_recovery_session_id())
|
||||
.await?
|
||||
.context("User recovery session not found")?;
|
||||
let session = repo
|
||||
.user_recovery()
|
||||
.lookup_session(self.user_recovery_session_id())
|
||||
.await?
|
||||
.context("User recovery session not found")?;
|
||||
|
||||
tracing::Span::current().record("user_recovery_session.email", &session.email);
|
||||
tracing::Span::current().record("user_recovery_session.email", &session.email);
|
||||
|
||||
if session.consumed_at.is_some() {
|
||||
info!("Recovery session already consumed, not sending email");
|
||||
return Ok(());
|
||||
}
|
||||
if session.consumed_at.is_some() {
|
||||
info!("Recovery session already consumed, not sending email");
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let mut cursor = Pagination::first(50);
|
||||
let mut cursor = Pagination::first(50);
|
||||
|
||||
let lang: DataLocale = session
|
||||
.locale
|
||||
.parse()
|
||||
.context("Invalid locale in database on recovery session")?;
|
||||
let lang: DataLocale = session
|
||||
.locale
|
||||
.parse()
|
||||
.context("Invalid locale in database on recovery session")?;
|
||||
|
||||
loop {
|
||||
let page = repo
|
||||
.user_email()
|
||||
.list(
|
||||
UserEmailFilter::new()
|
||||
.for_email(&session.email)
|
||||
.verified_only(),
|
||||
cursor,
|
||||
)
|
||||
.await?;
|
||||
|
||||
for email in page.edges {
|
||||
let ticket = Alphanumeric.sample_string(&mut rng, 32);
|
||||
|
||||
let ticket = repo
|
||||
.user_recovery()
|
||||
.add_ticket(&mut rng, &clock, &session, &email, ticket)
|
||||
loop {
|
||||
let page = repo
|
||||
.user_email()
|
||||
.list(
|
||||
UserEmailFilter::new()
|
||||
.for_email(&session.email)
|
||||
.verified_only(),
|
||||
cursor,
|
||||
)
|
||||
.await?;
|
||||
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.lookup(email.id)
|
||||
.await?
|
||||
.context("User email not found")?;
|
||||
for email in page.edges {
|
||||
let ticket = Alphanumeric.sample_string(&mut rng, 32);
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(user_email.user_id)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
let ticket = repo
|
||||
.user_recovery()
|
||||
.add_ticket(&mut rng, &clock, &session, &email, ticket)
|
||||
.await?;
|
||||
|
||||
let url = url_builder.account_recovery_link(ticket.ticket);
|
||||
let user_email = repo
|
||||
.user_email()
|
||||
.lookup(email.id)
|
||||
.await?
|
||||
.context("User email not found")?;
|
||||
|
||||
let address: Address = user_email.email.parse()?;
|
||||
let mailbox = Mailbox::new(Some(user.username.clone()), address);
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(user_email.user_id)
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
info!("Sending recovery email to {}", mailbox);
|
||||
let context =
|
||||
EmailRecoveryContext::new(user, session.clone(), url).with_language(lang.clone());
|
||||
let url = url_builder.account_recovery_link(ticket.ticket);
|
||||
|
||||
// XXX: we only log if the email fails to send, to avoid stopping the loop
|
||||
if let Err(e) = mailer.send_recovery_email(mailbox, &context).await {
|
||||
error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
"Failed to send recovery email"
|
||||
);
|
||||
let address: Address = user_email.email.parse()?;
|
||||
let mailbox = Mailbox::new(Some(user.username.clone()), address);
|
||||
|
||||
info!("Sending recovery email to {}", mailbox);
|
||||
let context = EmailRecoveryContext::new(user, session.clone(), url)
|
||||
.with_language(lang.clone());
|
||||
|
||||
// XXX: we only log if the email fails to send, to avoid stopping the loop
|
||||
if let Err(e) = mailer.send_recovery_email(mailbox, &context).await {
|
||||
error!(
|
||||
error = &e as &dyn std::error::Error,
|
||||
"Failed to send recovery email"
|
||||
);
|
||||
}
|
||||
|
||||
cursor = cursor.after(email.id);
|
||||
}
|
||||
|
||||
cursor = cursor.after(email.id);
|
||||
if !page.has_next_page {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
if !page.has_next_page {
|
||||
break;
|
||||
}
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn register(
|
||||
suffix: &str,
|
||||
monitor: Monitor<TokioExecutor>,
|
||||
state: &State,
|
||||
storage_factory: &PostgresStorageFactory,
|
||||
) -> Monitor<TokioExecutor> {
|
||||
let send_user_recovery_email_worker = crate::build!(SendAccountRecoveryEmailsJob => send_account_recovery_email_job, suffix, state, storage_factory);
|
||||
|
||||
monitor.register(send_user_recovery_email_worker)
|
||||
}
|
||||
|
||||
@@ -1,70 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::str::FromStr;
|
||||
|
||||
use apalis_core::{context::JobContext, job::JobId, request::JobRequest, worker::WorkerId};
|
||||
use chrono::{DateTime, Utc};
|
||||
use serde_json::Value;
|
||||
use sqlx::Row;
|
||||
|
||||
/// Wrapper for [`JobRequest`]
|
||||
pub(crate) struct SqlJobRequest<T>(JobRequest<T>);
|
||||
|
||||
impl<T> From<SqlJobRequest<T>> for JobRequest<T> {
|
||||
fn from(val: SqlJobRequest<T>) -> Self {
|
||||
val.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<'r, T: serde::de::DeserializeOwned> sqlx::FromRow<'r, sqlx::postgres::PgRow>
|
||||
for SqlJobRequest<T>
|
||||
{
|
||||
fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
|
||||
let job: Value = row.try_get("job")?;
|
||||
let id: JobId =
|
||||
JobId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
|
||||
index: "id".to_owned(),
|
||||
source: Box::new(e),
|
||||
})?;
|
||||
let mut context = JobContext::new(id);
|
||||
|
||||
let run_at = row.try_get("run_at")?;
|
||||
context.set_run_at(run_at);
|
||||
|
||||
let attempts = row.try_get("attempts").unwrap_or(0);
|
||||
context.set_attempts(attempts);
|
||||
|
||||
let max_attempts = row.try_get("max_attempts").unwrap_or(25);
|
||||
context.set_max_attempts(max_attempts);
|
||||
|
||||
let done_at: Option<DateTime<Utc>> = row.try_get("done_at").unwrap_or_default();
|
||||
context.set_done_at(done_at);
|
||||
|
||||
let lock_at: Option<DateTime<Utc>> = row.try_get("lock_at").unwrap_or_default();
|
||||
context.set_lock_at(lock_at);
|
||||
|
||||
let last_error = row.try_get("last_error").unwrap_or_default();
|
||||
context.set_last_error(last_error);
|
||||
|
||||
let status: String = row.try_get("status")?;
|
||||
context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
|
||||
index: "job".to_owned(),
|
||||
source: Box::new(e),
|
||||
})?);
|
||||
|
||||
let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
|
||||
context.set_lock_by(lock_by.map(WorkerId::new));
|
||||
|
||||
Ok(SqlJobRequest(JobRequest::new_with_context(
|
||||
serde_json::from_value(job).map_err(|e| sqlx::Error::ColumnDecode {
|
||||
index: "job".to_owned(),
|
||||
source: Box::new(e),
|
||||
})?,
|
||||
context,
|
||||
)))
|
||||
}
|
||||
}
|
||||
@@ -1,14 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
//! Reimplementation of the [`apalis_sql::storage::PostgresStorage`] using a
|
||||
//! shared connection for the [`PgListener`]
|
||||
|
||||
mod from_row;
|
||||
mod postgres;
|
||||
|
||||
use self::from_row::SqlJobRequest;
|
||||
pub(crate) use self::postgres::StorageFactory as PostgresStorageFactory;
|
||||
@@ -1,391 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use std::{convert::TryInto, marker::PhantomData, ops::Add, sync::Arc, time::Duration};
|
||||
|
||||
use apalis_core::{
|
||||
error::JobStreamError,
|
||||
job::{Job, JobId, JobStreamResult},
|
||||
request::JobRequest,
|
||||
storage::{StorageError, StorageResult, StorageWorkerPulse},
|
||||
utils::Timer,
|
||||
worker::WorkerId,
|
||||
};
|
||||
use async_stream::try_stream;
|
||||
use chrono::{DateTime, Utc};
|
||||
use event_listener::Event;
|
||||
use futures_lite::{Stream, StreamExt};
|
||||
use serde::{de::DeserializeOwned, Serialize};
|
||||
use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row};
|
||||
use tokio::task::JoinHandle;
|
||||
|
||||
use super::SqlJobRequest;
|
||||
|
||||
pub struct StorageFactory {
|
||||
pool: PgPool,
|
||||
event: Arc<Event>,
|
||||
}
|
||||
|
||||
impl StorageFactory {
|
||||
pub fn new(pool: Pool<Postgres>) -> Self {
|
||||
StorageFactory {
|
||||
pool,
|
||||
event: Arc::new(Event::new()),
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn listen(self) -> Result<JoinHandle<()>, sqlx::Error> {
|
||||
let mut listener = PgListener::connect_with(&self.pool).await?;
|
||||
listener.listen("apalis::job").await?;
|
||||
|
||||
let handle = tokio::spawn(async move {
|
||||
loop {
|
||||
let notification = listener.recv().await.expect("Failed to poll notification");
|
||||
self.event.notify(usize::MAX);
|
||||
tracing::debug!(?notification, "Broadcast notification");
|
||||
}
|
||||
});
|
||||
|
||||
Ok(handle)
|
||||
}
|
||||
|
||||
pub fn build<T>(&self) -> Storage<T> {
|
||||
Storage {
|
||||
pool: self.pool.clone(),
|
||||
event: self.event.clone(),
|
||||
job_type: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Represents a [`apalis_core::storage::Storage`] that persists to Postgres
|
||||
#[derive(Debug)]
|
||||
pub struct Storage<T> {
|
||||
pool: PgPool,
|
||||
event: Arc<Event>,
|
||||
job_type: PhantomData<T>,
|
||||
}
|
||||
|
||||
impl<T> Clone for Storage<T> {
|
||||
fn clone(&self) -> Self {
|
||||
Storage {
|
||||
pool: self.pool.clone(),
|
||||
event: self.event.clone(),
|
||||
job_type: PhantomData,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: DeserializeOwned + Send + Unpin + Job> Storage<T> {
|
||||
fn stream_jobs(
|
||||
&self,
|
||||
worker_id: &WorkerId,
|
||||
interval: Duration,
|
||||
buffer_size: usize,
|
||||
) -> impl Stream<Item = Result<JobRequest<T>, JobStreamError>> {
|
||||
let pool = self.pool.clone();
|
||||
let sleeper = apalis_core::utils::timer::TokioTimer;
|
||||
let worker_id = worker_id.clone();
|
||||
let event = self.event.clone();
|
||||
try_stream! {
|
||||
loop {
|
||||
// Wait for a notification or a timeout
|
||||
let listener = event.listen();
|
||||
let interval = sleeper.sleep(interval);
|
||||
futures_lite::future::race(interval, listener).await;
|
||||
|
||||
let tx = pool.clone();
|
||||
let job_type = T::NAME;
|
||||
let fetch_query = "SELECT * FROM apalis.get_jobs($1, $2, $3);";
|
||||
let jobs: Vec<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
|
||||
.bind(worker_id.name())
|
||||
.bind(job_type)
|
||||
// https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html
|
||||
.bind(i32::try_from(buffer_size).map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?)
|
||||
.fetch_all(&tx)
|
||||
.await.map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?;
|
||||
for job in jobs {
|
||||
yield job.into()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn keep_alive_at<Service>(
|
||||
&mut self,
|
||||
worker_id: &WorkerId,
|
||||
last_seen: DateTime<Utc>,
|
||||
) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
|
||||
let worker_type = T::NAME;
|
||||
let storage_name = std::any::type_name::<Self>();
|
||||
let query = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (id) DO
|
||||
UPDATE SET last_seen = EXCLUDED.last_seen";
|
||||
sqlx::query(query)
|
||||
.bind(worker_id.name())
|
||||
.bind(worker_type)
|
||||
.bind(storage_name)
|
||||
.bind(std::any::type_name::<Service>())
|
||||
.bind(last_seen)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
#[async_trait::async_trait]
|
||||
impl<T> apalis_core::storage::Storage for Storage<T>
|
||||
where
|
||||
T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
|
||||
{
|
||||
type Output = T;
|
||||
|
||||
/// Push a job to Postgres [Storage]
|
||||
///
|
||||
/// # SQL Example
|
||||
///
|
||||
/// ```sql
|
||||
/// SELECT apalis.push_job(job_type::text, job::json);
|
||||
/// ```
|
||||
async fn push(&mut self, job: Self::Output) -> StorageResult<JobId> {
|
||||
let id = JobId::new();
|
||||
let query = "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)";
|
||||
let pool = self.pool.clone();
|
||||
let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;
|
||||
let job_type = T::NAME;
|
||||
sqlx::query(query)
|
||||
.bind(job)
|
||||
.bind(id.to_string())
|
||||
.bind(job_type)
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
async fn schedule(
|
||||
&mut self,
|
||||
job: Self::Output,
|
||||
on: chrono::DateTime<Utc>,
|
||||
) -> StorageResult<JobId> {
|
||||
let query =
|
||||
"INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)";
|
||||
|
||||
let mut conn = self
|
||||
.pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Connection(Box::from(e)))?;
|
||||
|
||||
let id = JobId::new();
|
||||
let job = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;
|
||||
let job_type = T::NAME;
|
||||
sqlx::query(query)
|
||||
.bind(job)
|
||||
.bind(id.to_string())
|
||||
.bind(job_type)
|
||||
.bind(on)
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(id)
|
||||
}
|
||||
|
||||
async fn fetch_by_id(&self, job_id: &JobId) -> StorageResult<Option<JobRequest<Self::Output>>> {
|
||||
let mut conn = self
|
||||
.pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Connection(Box::from(e)))?;
|
||||
|
||||
let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1";
|
||||
let res: Option<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
|
||||
.bind(job_id.to_string())
|
||||
.fetch_optional(&mut *conn)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(res.map(Into::into))
|
||||
}
|
||||
|
||||
async fn heartbeat(&mut self, pulse: StorageWorkerPulse) -> StorageResult<bool> {
|
||||
match pulse {
|
||||
StorageWorkerPulse::EnqueueScheduled { count: _ } => {
|
||||
// Ideally jobs are queue via run_at. So this is not necessary
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
// Worker not seen in 5 minutes yet has running jobs
|
||||
StorageWorkerPulse::ReenqueueOrphaned { count, .. } => {
|
||||
let job_type = T::NAME;
|
||||
let mut conn = self
|
||||
.pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
let query = "UPDATE apalis.jobs
|
||||
SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ='Job was abandoned'
|
||||
WHERE id in
|
||||
(SELECT jobs.id from apalis.jobs INNER join apalis.workers ON lock_by = workers.id
|
||||
WHERE status = 'Running' AND workers.last_seen < NOW() - INTERVAL '5 minutes'
|
||||
AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);";
|
||||
sqlx::query(query)
|
||||
.bind(job_type)
|
||||
.bind(count)
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
_ => unimplemented!(),
|
||||
}
|
||||
}
|
||||
|
||||
async fn kill(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
|
||||
let mut conn = pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Connection(Box::from(e)))?;
|
||||
let query =
|
||||
"UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2";
|
||||
sqlx::query(query)
|
||||
.bind(job_id.to_string())
|
||||
.bind(worker_id.name())
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Puts the job instantly back into the queue
|
||||
/// Another [Worker] may consume
|
||||
async fn retry(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
|
||||
let mut conn = pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Connection(Box::from(e)))?;
|
||||
let query =
|
||||
"UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2";
|
||||
sqlx::query(query)
|
||||
.bind(job_id.to_string())
|
||||
.bind(worker_id.name())
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn consume(
|
||||
&mut self,
|
||||
worker_id: &WorkerId,
|
||||
interval: Duration,
|
||||
buffer_size: usize,
|
||||
) -> JobStreamResult<T> {
|
||||
Box::pin(
|
||||
self.stream_jobs(worker_id, interval, buffer_size)
|
||||
.map(|r| r.map(Some)),
|
||||
)
|
||||
}
|
||||
async fn len(&self) -> StorageResult<i64> {
|
||||
let pool = self.pool.clone();
|
||||
let query = "SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending'";
|
||||
let record = sqlx::query(query)
|
||||
.fetch_one(&pool)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(record
|
||||
.try_get("count")
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?)
|
||||
}
|
||||
async fn ack(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
let query =
|
||||
"UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2";
|
||||
sqlx::query(query)
|
||||
.bind(job_id.to_string())
|
||||
.bind(worker_id.name())
|
||||
.execute(&pool)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn reschedule(&mut self, job: &JobRequest<T>, wait: Duration) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
let job_id = job.id();
|
||||
|
||||
let wait: i64 = wait
|
||||
.as_secs()
|
||||
.try_into()
|
||||
.map_err(|e| StorageError::Database(Box::new(e)))?;
|
||||
let wait = chrono::Duration::microseconds(wait * 1000 * 1000);
|
||||
// TODO: should we use a clock here?
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let run_at = Utc::now().add(wait);
|
||||
|
||||
let mut conn = pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Connection(Box::from(e)))?;
|
||||
let query =
|
||||
"UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1";
|
||||
sqlx::query(query)
|
||||
.bind(job_id.to_string())
|
||||
.bind(run_at)
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn update_by_id(
|
||||
&self,
|
||||
job_id: &JobId,
|
||||
job: &JobRequest<Self::Output>,
|
||||
) -> StorageResult<()> {
|
||||
let pool = self.pool.clone();
|
||||
let status = job.status().as_ref();
|
||||
let attempts = job.attempts();
|
||||
let done_at = *job.done_at();
|
||||
let lock_by = job.lock_by().clone();
|
||||
let lock_at = *job.lock_at();
|
||||
let last_error = job.last_error().clone();
|
||||
|
||||
let mut conn = pool
|
||||
.acquire()
|
||||
.await
|
||||
.map_err(|e| StorageError::Connection(Box::from(e)))?;
|
||||
let query =
|
||||
"UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7";
|
||||
sqlx::query(query)
|
||||
.bind(status.to_owned())
|
||||
.bind(attempts)
|
||||
.bind(done_at)
|
||||
.bind(lock_by.as_ref().map(WorkerId::name))
|
||||
.bind(lock_at)
|
||||
.bind(last_error)
|
||||
.bind(job_id.to_string())
|
||||
.execute(&mut *conn)
|
||||
.await
|
||||
.map_err(|e| StorageError::Database(Box::from(e)))?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn keep_alive<Service>(&mut self, worker_id: &WorkerId) -> StorageResult<()> {
|
||||
#[allow(clippy::disallowed_methods)]
|
||||
let now = Utc::now();
|
||||
|
||||
self.keep_alive_at::<Service>(worker_id, now).await
|
||||
}
|
||||
}
|
||||
@@ -5,10 +5,9 @@
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use anyhow::Context;
|
||||
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
|
||||
use async_trait::async_trait;
|
||||
use mas_storage::{
|
||||
compat::CompatSessionFilter,
|
||||
job::JobWithSpanContext,
|
||||
oauth2::OAuth2SessionFilter,
|
||||
queue::{DeactivateUserJob, ReactivateUserJob},
|
||||
user::{BrowserSessionFilter, UserRepository},
|
||||
@@ -16,122 +15,106 @@ use mas_storage::{
|
||||
};
|
||||
use tracing::info;
|
||||
|
||||
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
|
||||
use crate::{
|
||||
new_queue::{JobContext, RunnableJob},
|
||||
State,
|
||||
};
|
||||
|
||||
/// Job to deactivate a user, both locally and on the Matrix homeserver.
|
||||
#[tracing::instrument(
|
||||
#[async_trait]
|
||||
impl RunnableJob for DeactivateUserJob {
|
||||
#[tracing::instrument(
|
||||
name = "job.deactivate_user"
|
||||
fields(user.id = %job.user_id(), erase = %job.hs_erase()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn deactivate_user(
|
||||
job: JobWithSpanContext<DeactivateUserJob>,
|
||||
ctx: JobContext,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let state = ctx.state();
|
||||
let clock = state.clock();
|
||||
let matrix = state.matrix_connection();
|
||||
let mut repo = state.repository().await?;
|
||||
fields(user.id = %self.user_id(), erase = %self.hs_erase()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
|
||||
let clock = state.clock();
|
||||
let matrix = state.matrix_connection();
|
||||
let mut repo = state.repository().await?;
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(job.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(self.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
// Let's first lock the user
|
||||
let user = repo
|
||||
.user()
|
||||
.lock(&clock, user)
|
||||
.await
|
||||
.context("Failed to lock user")?;
|
||||
// Let's first lock the user
|
||||
let user = repo
|
||||
.user()
|
||||
.lock(&clock, user)
|
||||
.await
|
||||
.context("Failed to lock user")?;
|
||||
|
||||
// Kill all sessions for the user
|
||||
let n = repo
|
||||
.browser_session()
|
||||
.finish_bulk(
|
||||
&clock,
|
||||
BrowserSessionFilter::new().for_user(&user).active_only(),
|
||||
)
|
||||
.await?;
|
||||
info!(affected = n, "Killed all browser sessions for user");
|
||||
// Kill all sessions for the user
|
||||
let n = repo
|
||||
.browser_session()
|
||||
.finish_bulk(
|
||||
&clock,
|
||||
BrowserSessionFilter::new().for_user(&user).active_only(),
|
||||
)
|
||||
.await?;
|
||||
info!(affected = n, "Killed all browser sessions for user");
|
||||
|
||||
let n = repo
|
||||
.oauth2_session()
|
||||
.finish_bulk(
|
||||
&clock,
|
||||
OAuth2SessionFilter::new().for_user(&user).active_only(),
|
||||
)
|
||||
.await?;
|
||||
info!(affected = n, "Killed all OAuth 2.0 sessions for user");
|
||||
let n = repo
|
||||
.oauth2_session()
|
||||
.finish_bulk(
|
||||
&clock,
|
||||
OAuth2SessionFilter::new().for_user(&user).active_only(),
|
||||
)
|
||||
.await?;
|
||||
info!(affected = n, "Killed all OAuth 2.0 sessions for user");
|
||||
|
||||
let n = repo
|
||||
.compat_session()
|
||||
.finish_bulk(
|
||||
&clock,
|
||||
CompatSessionFilter::new().for_user(&user).active_only(),
|
||||
)
|
||||
.await?;
|
||||
info!(affected = n, "Killed all compatibility sessions for user");
|
||||
let n = repo
|
||||
.compat_session()
|
||||
.finish_bulk(
|
||||
&clock,
|
||||
CompatSessionFilter::new().for_user(&user).active_only(),
|
||||
)
|
||||
.await?;
|
||||
info!(affected = n, "Killed all compatibility sessions for user");
|
||||
|
||||
// Before calling back to the homeserver, commit the changes to the database, as
|
||||
// we want the user to be locked out as soon as possible
|
||||
repo.save().await?;
|
||||
// Before calling back to the homeserver, commit the changes to the database, as
|
||||
// we want the user to be locked out as soon as possible
|
||||
repo.save().await?;
|
||||
|
||||
let mxid = matrix.mxid(&user.username);
|
||||
info!("Deactivating user {} on homeserver", mxid);
|
||||
matrix.delete_user(&mxid, job.hs_erase()).await?;
|
||||
let mxid = matrix.mxid(&user.username);
|
||||
info!("Deactivating user {} on homeserver", mxid);
|
||||
matrix.delete_user(&mxid, self.hs_erase()).await?;
|
||||
|
||||
Ok(())
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
/// Job to reactivate a user, both locally and on the Matrix homeserver.
|
||||
#[tracing::instrument(
|
||||
name = "job.reactivate_user",
|
||||
fields(user.id = %job.user_id()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
pub async fn reactivate_user(
|
||||
job: JobWithSpanContext<ReactivateUserJob>,
|
||||
ctx: JobContext,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let state = ctx.state();
|
||||
let matrix = state.matrix_connection();
|
||||
let mut repo = state.repository().await?;
|
||||
#[async_trait]
|
||||
impl RunnableJob for ReactivateUserJob {
|
||||
#[tracing::instrument(
|
||||
name = "job.reactivate_user",
|
||||
fields(user.id = %self.user_id()),
|
||||
skip_all,
|
||||
err(Debug),
|
||||
)]
|
||||
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
|
||||
let matrix = state.matrix_connection();
|
||||
let mut repo = state.repository().await?;
|
||||
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(job.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
let user = repo
|
||||
.user()
|
||||
.lookup(self.user_id())
|
||||
.await?
|
||||
.context("User not found")?;
|
||||
|
||||
let mxid = matrix.mxid(&user.username);
|
||||
info!("Reactivating user {} on homeserver", mxid);
|
||||
matrix.reactivate_user(&mxid).await?;
|
||||
let mxid = matrix.mxid(&user.username);
|
||||
info!("Reactivating user {} on homeserver", mxid);
|
||||
matrix.reactivate_user(&mxid).await?;
|
||||
|
||||
// We want to unlock the user from our side only once it has been reactivated on
|
||||
// the homeserver
|
||||
let _user = repo.user().unlock(user).await?;
|
||||
repo.save().await?;
|
||||
// We want to unlock the user from our side only once it has been reactivated on
|
||||
// the homeserver
|
||||
let _user = repo.user().unlock(user).await?;
|
||||
repo.save().await?;
|
||||
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) fn register(
|
||||
suffix: &str,
|
||||
monitor: Monitor<TokioExecutor>,
|
||||
state: &State,
|
||||
storage_factory: &PostgresStorageFactory,
|
||||
) -> Monitor<TokioExecutor> {
|
||||
let deactivate_user_worker =
|
||||
crate::build!(DeactivateUserJob => deactivate_user, suffix, state, storage_factory);
|
||||
|
||||
let reactivate_user_worker =
|
||||
crate::build!(ReactivateUserJob => reactivate_user, suffix, state, storage_factory);
|
||||
|
||||
monitor
|
||||
.register(deactivate_user_worker)
|
||||
.register(reactivate_user_worker)
|
||||
Ok(())
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,91 +0,0 @@
|
||||
// Copyright 2024 New Vector Ltd.
|
||||
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// SPDX-License-Identifier: AGPL-3.0-only
|
||||
// Please see LICENSE in the repository root for full details.
|
||||
|
||||
use apalis_core::{job::Job, request::JobRequest};
|
||||
use mas_storage::job::JobWithSpanContext;
|
||||
use mas_tower::{
|
||||
make_span_fn, DurationRecorderLayer, FnWrapper, IdentityLayer, InFlightCounterLayer,
|
||||
TraceLayer, KV,
|
||||
};
|
||||
use opentelemetry::{trace::SpanContext, Key, KeyValue};
|
||||
use tracing::info_span;
|
||||
use tracing_opentelemetry::OpenTelemetrySpanExt;
|
||||
|
||||
const JOB_NAME: Key = Key::from_static_str("job.name");
|
||||
const JOB_STATUS: Key = Key::from_static_str("job.status");
|
||||
|
||||
/// Represents a job that can may have a span context attached to it.
|
||||
pub trait TracedJob: Job {
|
||||
/// Returns the span context for this job, if any.
|
||||
///
|
||||
/// The default implementation returns `None`.
|
||||
fn span_context(&self) -> Option<SpanContext> {
|
||||
None
|
||||
}
|
||||
}
|
||||
|
||||
/// Implements [`TracedJob`] for any job with the [`JobWithSpanContext`]
|
||||
/// wrapper.
|
||||
impl<J: Job> TracedJob for JobWithSpanContext<J> {
|
||||
fn span_context(&self) -> Option<SpanContext> {
|
||||
JobWithSpanContext::span_context(self)
|
||||
}
|
||||
}
|
||||
|
||||
fn make_span_for_job_request<J: TracedJob>(req: &JobRequest<J>) -> tracing::Span {
|
||||
let span = info_span!(
|
||||
"job.run",
|
||||
"otel.kind" = "consumer",
|
||||
"otel.status_code" = tracing::field::Empty,
|
||||
"job.id" = %req.id(),
|
||||
"job.attempts" = req.attempts(),
|
||||
"job.name" = J::NAME,
|
||||
);
|
||||
|
||||
if let Some(context) = req.inner().span_context() {
|
||||
span.add_link(context);
|
||||
}
|
||||
|
||||
span
|
||||
}
|
||||
|
||||
type TraceLayerForJob<J> =
|
||||
TraceLayer<FnWrapper<fn(&JobRequest<J>) -> tracing::Span>, KV<&'static str>, KV<&'static str>>;
|
||||
|
||||
pub(crate) fn trace_layer<J>() -> TraceLayerForJob<J>
|
||||
where
|
||||
J: TracedJob,
|
||||
{
|
||||
TraceLayer::new(make_span_fn(
|
||||
make_span_for_job_request::<J> as fn(&JobRequest<J>) -> tracing::Span,
|
||||
))
|
||||
.on_response(KV("otel.status_code", "OK"))
|
||||
.on_error(KV("otel.status_code", "ERROR"))
|
||||
}
|
||||
|
||||
type MetricsLayerForJob<J> = (
|
||||
IdentityLayer<JobRequest<J>>,
|
||||
DurationRecorderLayer<KeyValue, KeyValue, KeyValue>,
|
||||
InFlightCounterLayer<KeyValue>,
|
||||
);
|
||||
|
||||
pub(crate) fn metrics_layer<J>() -> MetricsLayerForJob<J>
|
||||
where
|
||||
J: Job,
|
||||
{
|
||||
let duration_recorder = DurationRecorderLayer::new("job.run.duration")
|
||||
.on_request(JOB_NAME.string(J::NAME))
|
||||
.on_response(JOB_STATUS.string("success"))
|
||||
.on_error(JOB_STATUS.string("error"));
|
||||
let in_flight_counter =
|
||||
InFlightCounterLayer::new("job.run.active").on_request(JOB_NAME.string(J::NAME));
|
||||
|
||||
(
|
||||
IdentityLayer::default(),
|
||||
duration_recorder,
|
||||
in_flight_counter,
|
||||
)
|
||||
}
|
||||
Reference in New Issue
Block a user