Actually consume jobs

This commit is contained in:
Quentin Gliech
2024-10-31 17:38:43 +01:00
parent 70b864f609
commit bd72a57719
14 changed files with 867 additions and 1044 deletions

View File

@@ -0,0 +1,43 @@
{
"db_name": "PostgreSQL",
"query": "\n -- We first grab a few jobs that are available,\n -- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently\n -- and we don't get multiple workers grabbing the same jobs\n WITH locked_jobs AS (\n SELECT queue_job_id\n FROM queue_jobs\n WHERE\n status = 'available'\n AND queue_name = ANY($1)\n ORDER BY queue_job_id ASC\n LIMIT $2\n FOR UPDATE\n SKIP LOCKED\n )\n -- then we update the status of those jobs to 'running', returning the job details\n UPDATE queue_jobs\n SET status = 'running', started_at = $3, started_by = $4\n FROM locked_jobs\n WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id\n RETURNING\n queue_jobs.queue_job_id,\n queue_jobs.queue_name,\n queue_jobs.payload,\n queue_jobs.metadata\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "queue_job_id",
"type_info": "Uuid"
},
{
"ordinal": 1,
"name": "queue_name",
"type_info": "Text"
},
{
"ordinal": 2,
"name": "payload",
"type_info": "Jsonb"
},
{
"ordinal": 3,
"name": "metadata",
"type_info": "Jsonb"
}
],
"parameters": {
"Left": [
"TextArray",
"Int8",
"Timestamptz",
"Uuid"
]
},
"nullable": [
false,
false,
false,
false
]
},
"hash": "9f2fae84d17991a179f93c4ea43b411aa9f15e7beccfd6212787c3452d35d061"
}

View File

@@ -0,0 +1,15 @@
{
"db_name": "PostgreSQL",
"query": "\n UPDATE queue_jobs\n SET status = 'completed', completed_at = $1\n WHERE queue_job_id = $2 AND status = 'running'\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Timestamptz",
"Uuid"
]
},
"nullable": []
},
"hash": "a63a217981b97448ddcc96b2489ddd9d3bc8c99b5b8b1d373939fc3ae9715c27"
}

View File

@@ -7,13 +7,16 @@
//! [`QueueJobRepository`].
use async_trait::async_trait;
use mas_storage::{queue::QueueJobRepository, Clock};
use mas_storage::{
queue::{Job, QueueJobRepository, Worker},
Clock,
};
use rand::RngCore;
use sqlx::PgConnection;
use ulid::Ulid;
use uuid::Uuid;
use crate::{DatabaseError, ExecuteExt};
use crate::{DatabaseError, DatabaseInconsistencyError, ExecuteExt};
/// An implementation of [`QueueJobRepository`] for a PostgreSQL connection.
pub struct PgQueueJobRepository<'c> {
@@ -29,6 +32,37 @@ impl<'c> PgQueueJobRepository<'c> {
}
}
/// Row shape returned by the job reservation query: one field per column
/// listed in that query's `RETURNING` clause.
struct JobReservationResult {
// Identifier of the reserved job (`queue_job_id` column)
queue_job_id: Uuid,
// Name of the queue the job was taken from
queue_name: String,
// Job payload, kept as raw JSON
payload: serde_json::Value,
// Job metadata as raw JSON; decoded when converting into a `Job`
metadata: serde_json::Value,
}
impl TryFrom<JobReservationResult> for Job {
type Error = DatabaseInconsistencyError;
fn try_from(value: JobReservationResult) -> Result<Self, Self::Error> {
let id = value.queue_job_id.into();
let queue_name = value.queue_name;
let payload = value.payload;
let metadata = serde_json::from_value(value.metadata).map_err(|e| {
DatabaseInconsistencyError::on("queue_jobs")
.column("metadata")
.row(id)
.source(e)
})?;
Ok(Self {
id,
queue_name,
payload,
metadata,
})
}
}
#[async_trait]
impl QueueJobRepository for PgQueueJobRepository<'_> {
type Error = DatabaseError;
@@ -73,4 +107,96 @@ impl QueueJobRepository for PgQueueJobRepository<'_> {
Ok(())
}
/// Reserve up to `count` available jobs from the given `queues` on behalf of
/// `worker`.
///
/// The selected jobs are switched to the `running` status, stamped with the
/// current time from `clock` and the worker ID, and returned as [`Job`]s.
///
/// # Errors
///
/// Returns an error if the query fails, or if a returned row cannot be
/// converted into a [`Job`].
#[tracing::instrument(
name = "db.queue_job.reserve",
skip_all,
fields(
db.query.text,
),
err,
)]
async fn reserve(
&mut self,
clock: &dyn Clock,
worker: &Worker,
queues: &[&str],
count: usize,
) -> Result<Vec<Job>, Self::Error> {
let now = clock.now();
// The `LIMIT` parameter is a signed 64-bit integer: saturate if `count`
// does not fit
let max_count = i64::try_from(count).unwrap_or(i64::MAX);
// Owned copies of the queue names, bound as the `$1` array parameter
let queues: Vec<String> = queues.iter().map(|&s| s.to_owned()).collect();
let results = sqlx::query_as!(
JobReservationResult,
r#"
-- We first grab a few jobs that are available,
-- using a FOR UPDATE SKIP LOCKED so that this can be run concurrently
-- and we don't get multiple workers grabbing the same jobs
WITH locked_jobs AS (
SELECT queue_job_id
FROM queue_jobs
WHERE
status = 'available'
AND queue_name = ANY($1)
ORDER BY queue_job_id ASC
LIMIT $2
FOR UPDATE
SKIP LOCKED
)
-- then we update the status of those jobs to 'running', returning the job details
UPDATE queue_jobs
SET status = 'running', started_at = $3, started_by = $4
FROM locked_jobs
WHERE queue_jobs.queue_job_id = locked_jobs.queue_job_id
RETURNING
queue_jobs.queue_job_id,
queue_jobs.queue_name,
queue_jobs.payload,
queue_jobs.metadata
"#,
&queues,
max_count,
now,
Uuid::from(worker.id),
)
.traced()
.fetch_all(&mut *self.conn)
.await?;
// Convert every row, bailing out on the first one that fails to convert
let jobs = results
.into_iter()
.map(TryFrom::try_from)
.collect::<Result<Vec<_>, _>>()?;
Ok(jobs)
}
/// Mark the job identified by `id` as completed, recording the current time
/// from `clock` as its completion time.
///
/// Only a job which is currently in the `running` status is updated.
///
/// # Errors
///
/// Returns an error if the query fails, or if the affected row count check
/// at the end of the method fails.
#[tracing::instrument(
name = "db.queue_job.mark_as_completed",
skip_all,
fields(
db.query.text,
job.id = %id,
),
err,
)]
async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error> {
let now = clock.now();
let res = sqlx::query!(
r#"
UPDATE queue_jobs
SET status = 'completed', completed_at = $1
WHERE queue_job_id = $2 AND status = 'running'
"#,
now,
Uuid::from(id),
)
.traced()
.execute(&mut *self.conn)
.await?;
// NOTE(review): presumably fails unless exactly one row was updated, i.e.
// the job exists and was still `running` — confirm against
// `ensure_affected_rows`
DatabaseError::ensure_affected_rows(&res, 1)?;
Ok(())
}
}

View File

@@ -12,6 +12,7 @@ use serde::{Deserialize, Serialize};
use tracing_opentelemetry::OpenTelemetrySpanExt;
use ulid::Ulid;
use super::Worker;
use crate::{repository_impl, Clock};
/// Represents a job in the job queue
@@ -19,6 +20,9 @@ pub struct Job {
/// The ID of the job
pub id: Ulid,
/// The queue on which the job was placed
pub queue_name: String,
/// The payload of the job
pub payload: serde_json::Value,
@@ -27,7 +31,7 @@ pub struct Job {
}
/// Metadata stored alongside the job
#[derive(Serialize, Deserialize, Default)]
#[derive(Serialize, Deserialize, Default, Clone, Debug)]
pub struct JobMetadata {
#[serde(default)]
trace_id: String,
@@ -97,6 +101,38 @@ pub trait QueueJobRepository: Send + Sync {
payload: serde_json::Value,
metadata: serde_json::Value,
) -> Result<(), Self::Error>;
/// Reserve multiple jobs from multiple queues
///
/// Returns the jobs which were reserved.
///
/// # Parameters
///
/// * `clock` - The clock used to generate timestamps
/// * `worker` - The worker that is reserving the jobs
/// * `queues` - The queues to reserve jobs from
/// * `count` - The maximum number of jobs to reserve
///
/// # Errors
///
/// Returns an error if the underlying repository fails.
async fn reserve(
&mut self,
clock: &dyn Clock,
worker: &Worker,
queues: &[&str],
count: usize,
) -> Result<Vec<Job>, Self::Error>;
/// Mark a job as completed
///
/// # Parameters
///
/// * `clock` - The clock used to generate timestamps
/// * `id` - The ID of the job to mark as completed
///
/// # Errors
///
/// Returns an error if the underlying repository fails.
async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>;
}
repository_impl!(QueueJobRepository:
@@ -108,6 +144,16 @@ repository_impl!(QueueJobRepository:
payload: serde_json::Value,
metadata: serde_json::Value,
) -> Result<(), Self::Error>;
async fn reserve(
&mut self,
clock: &dyn Clock,
worker: &Worker,
queues: &[&str],
count: usize,
) -> Result<Vec<Job>, Self::Error>;
async fn mark_as_completed(&mut self, clock: &dyn Clock, id: Ulid) -> Result<(), Self::Error>;
);
/// Extension trait for [`QueueJobRepository`] to help adding a job to the queue

View File

@@ -5,97 +5,87 @@
// Please see LICENSE in the repository root for full details.
use anyhow::Context;
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
use async_trait::async_trait;
use chrono::Duration;
use mas_email::{Address, Mailbox};
use mas_i18n::locale;
use mas_storage::{job::JobWithSpanContext, queue::VerifyEmailJob};
use mas_storage::queue::VerifyEmailJob;
use mas_templates::{EmailVerificationContext, TemplateContext};
use rand::{distributions::Uniform, Rng};
use tracing::info;
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
use crate::{
new_queue::{JobContext, RunnableJob},
State,
};
#[tracing::instrument(
name = "job.verify_email",
fields(user_email.id = %job.user_email_id()),
skip_all,
err(Debug),
)]
async fn verify_email(
job: JobWithSpanContext<VerifyEmailJob>,
ctx: JobContext,
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
let mut rng = state.rng();
let mailer = state.mailer();
let clock = state.clock();
#[async_trait]
impl RunnableJob for VerifyEmailJob {
#[tracing::instrument(
name = "job.verify_email",
fields(user_email.id = %self.user_email_id()),
skip_all,
err(Debug),
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
let mut repo = state.repository().await?;
let mut rng = state.rng();
let mailer = state.mailer();
let clock = state.clock();
let language = job
.language()
.and_then(|l| l.parse().ok())
.unwrap_or(locale!("en").into());
let language = self
.language()
.and_then(|l| l.parse().ok())
.unwrap_or(locale!("en").into());
// Lookup the user email
let user_email = repo
.user_email()
.lookup(job.user_email_id())
.await?
.context("User email not found")?;
// Lookup the user email
let user_email = repo
.user_email()
.lookup(self.user_email_id())
.await?
.context("User email not found")?;
// Lookup the user associated with the email
let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("User not found")?;
// Lookup the user associated with the email
let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("User not found")?;
// Generate a verification code
let range = Uniform::<u32>::from(0..1_000_000);
let code = rng.sample(range);
let code = format!("{code:06}");
// Generate a verification code
let range = Uniform::<u32>::from(0..1_000_000);
let code = rng.sample(range);
let code = format!("{code:06}");
let address: Address = user_email.email.parse()?;
let address: Address = user_email.email.parse()?;
// Save the verification code in the database
let verification = repo
.user_email()
.add_verification_code(
&mut rng,
&clock,
&user_email,
Duration::try_hours(8).unwrap(),
code,
)
.await?;
// Save the verification code in the database
let verification = repo
.user_email()
.add_verification_code(
&mut rng,
&clock,
&user_email,
Duration::try_hours(8).unwrap(),
code,
)
.await?;
// And send the verification email
let mailbox = Mailbox::new(Some(user.username.clone()), address);
// And send the verification email
let mailbox = Mailbox::new(Some(user.username.clone()), address);
let context =
EmailVerificationContext::new(user.clone(), verification.clone()).with_language(language);
let context = EmailVerificationContext::new(user.clone(), verification.clone())
.with_language(language);
mailer.send_verification_email(mailbox, &context).await?;
mailer.send_verification_email(mailbox, &context).await?;
info!(
email.id = %user_email.id,
"Verification email sent"
);
info!(
email.id = %user_email.id,
"Verification email sent"
);
repo.save().await?;
repo.save().await?;
Ok(())
}
pub(crate) fn register(
suffix: &str,
monitor: Monitor<TokioExecutor>,
state: &State,
storage_factory: &PostgresStorageFactory,
) -> Monitor<TokioExecutor> {
let verify_email_worker =
crate::build!(VerifyEmailJob => verify_email, suffix, state, storage_factory);
monitor.register(verify_email_worker)
Ok(())
}
}

View File

@@ -18,14 +18,13 @@ use rand::SeedableRng;
use sqlx::{Pool, Postgres};
use tokio_util::{sync::CancellationToken, task::TaskTracker};
// TODO: we need to have a way to schedule recurring tasks
// mod database;
// mod email;
// mod matrix;
mod email;
mod matrix;
mod new_queue;
// mod recovery;
// mod storage;
// mod user;
// mod utils;
mod recovery;
mod user;
#[derive(Clone)]
struct State {
@@ -111,6 +110,15 @@ pub async fn init(
);
let mut worker = self::new_queue::QueueWorker::new(state, cancellation_token).await?;
worker.register_handler::<mas_storage::queue::DeactivateUserJob>();
worker.register_handler::<mas_storage::queue::DeleteDeviceJob>();
worker.register_handler::<mas_storage::queue::ProvisionDeviceJob>();
worker.register_handler::<mas_storage::queue::ProvisionUserJob>();
worker.register_handler::<mas_storage::queue::ReactivateUserJob>();
worker.register_handler::<mas_storage::queue::SendAccountRecoveryEmailsJob>();
worker.register_handler::<mas_storage::queue::SyncDevicesJob>();
worker.register_handler::<mas_storage::queue::VerifyEmailJob>();
task_tracker.spawn(async move {
if let Err(e) = worker.run().await {
tracing::error!(

View File

@@ -7,239 +7,239 @@
use std::collections::HashSet;
use anyhow::Context;
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
use async_trait::async_trait;
use mas_data_model::Device;
use mas_matrix::ProvisionRequest;
use mas_storage::{
compat::CompatSessionFilter,
job::{JobRepositoryExt as _, JobWithSpanContext},
oauth2::OAuth2SessionFilter,
queue::{DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, SyncDevicesJob},
queue::{
DeleteDeviceJob, ProvisionDeviceJob, ProvisionUserJob, QueueJobRepositoryExt as _,
SyncDevicesJob,
},
user::{UserEmailRepository, UserRepository},
Pagination, RepositoryAccess,
};
use tracing::info;
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
use crate::{
new_queue::{JobContext, RunnableJob},
State,
};
/// Job to provision a user on the Matrix homeserver.
/// This works by doing a PUT request to the /_synapse/admin/v2/users/{user_id}
/// endpoint.
#[tracing::instrument(
name = "job.provision_user"
fields(user.id = %job.user_id()),
skip_all,
err(Debug),
)]
async fn provision_user(
job: JobWithSpanContext<ProvisionUserJob>,
ctx: JobContext,
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let matrix = state.matrix_connection();
let mut repo = state.repository().await?;
/// This works by doing a PUT request to the
/// /_synapse/admin/v2/users/{user_id} endpoint.
#[async_trait]
impl RunnableJob for ProvisionUserJob {
#[tracing::instrument(
name = "job.provision_user"
fields(user.id = %self.user_id()),
skip_all,
err(Debug),
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
let matrix = state.matrix_connection();
let mut repo = state.repository().await?;
let mut rng = state.rng();
let clock = state.clock();
let user = repo
.user()
.lookup(job.user_id())
.await?
.context("User not found")?;
let user = repo
.user()
.lookup(self.user_id())
.await?
.context("User not found")?;
let mxid = matrix.mxid(&user.username);
let emails = repo
.user_email()
.all(&user)
.await?
.into_iter()
.filter(|email| email.confirmed_at.is_some())
.map(|email| email.email)
.collect();
let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails);
let mxid = matrix.mxid(&user.username);
let emails = repo
.user_email()
.all(&user)
.await?
.into_iter()
.filter(|email| email.confirmed_at.is_some())
.map(|email| email.email)
.collect();
let mut request = ProvisionRequest::new(mxid.clone(), user.sub.clone()).set_emails(emails);
if let Some(display_name) = job.display_name_to_set() {
request = request.set_displayname(display_name.to_owned());
if let Some(display_name) = self.display_name_to_set() {
request = request.set_displayname(display_name.to_owned());
}
let created = matrix.provision_user(&request).await?;
if created {
info!(%user.id, %mxid, "User created");
} else {
info!(%user.id, %mxid, "User updated");
}
// Schedule a device sync job
let sync_device_job = SyncDevicesJob::new(&user);
repo.queue_job()
.schedule_job(&mut rng, &clock, sync_device_job)
.await?;
repo.save().await?;
Ok(())
}
let created = matrix.provision_user(&request).await?;
if created {
info!(%user.id, %mxid, "User created");
} else {
info!(%user.id, %mxid, "User updated");
}
// Schedule a device sync job
let sync_device_job = SyncDevicesJob::new(&user);
repo.job().schedule_job(sync_device_job).await?;
repo.save().await?;
Ok(())
}
/// Job to provision a device on the Matrix homeserver.
///
/// This job is deprecated and therefore just schedules a [`SyncDevicesJob`]
#[tracing::instrument(
name = "job.provision_device"
fields(
user.id = %job.user_id(),
device.id = %job.device_id(),
),
skip_all,
err(Debug),
)]
async fn provision_device(
job: JobWithSpanContext<ProvisionDeviceJob>,
ctx: JobContext,
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
#[async_trait]
impl RunnableJob for ProvisionDeviceJob {
#[tracing::instrument(
name = "job.provision_device"
fields(
user.id = %self.user_id(),
device.id = %self.device_id(),
),
skip_all,
err(Debug),
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
let mut repo = state.repository().await?;
let mut rng = state.rng();
let clock = state.clock();
let user = repo
.user()
.lookup(job.user_id())
.await?
.context("User not found")?;
let user = repo
.user()
.lookup(self.user_id())
.await?
.context("User not found")?;
// Schedule a device sync job
repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;
// Schedule a device sync job
repo.queue_job()
.schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
.await?;
Ok(())
Ok(())
}
}
/// Job to delete a device from a user's account.
///
/// This job is deprecated and therefore just schedules a [`SyncDevicesJob`]
#[tracing::instrument(
name = "job.delete_device"
fields(
user.id = %job.user_id(),
device.id = %job.device_id(),
),
skip_all,
err(Debug),
)]
async fn delete_device(
job: JobWithSpanContext<DeleteDeviceJob>,
ctx: JobContext,
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let mut repo = state.repository().await?;
#[async_trait]
impl RunnableJob for DeleteDeviceJob {
#[tracing::instrument(
    name = "job.delete_device"
    fields(
        user.id = %self.user_id(),
        device.id = %self.device_id(),
    ),
    skip_all,
    err(Debug),
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
let mut rng = state.rng();
let clock = state.clock();
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(job.user_id())
.await?
.context("User not found")?;
let user = repo
.user()
.lookup(self.user_id())
.await?
.context("User not found")?;
// Schedule a device sync job
repo.job().schedule_job(SyncDevicesJob::new(&user)).await?;
// Schedule a device sync job
repo.queue_job()
.schedule_job(&mut rng, &clock, SyncDevicesJob::new(&user))
.await?;
Ok(())
Ok(())
}
}
/// Job to sync the list of devices of a user with the homeserver.
#[tracing::instrument(
name = "job.sync_devices",
fields(user.id = %job.user_id()),
skip_all,
err(Debug),
)]
async fn sync_devices(
job: JobWithSpanContext<SyncDevicesJob>,
ctx: JobContext,
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let matrix = state.matrix_connection();
let mut repo = state.repository().await?;
#[async_trait]
impl RunnableJob for SyncDevicesJob {
#[tracing::instrument(
name = "job.sync_devices",
fields(user.id = %self.user_id()),
skip_all,
err(Debug),
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
let matrix = state.matrix_connection();
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(job.user_id())
.await?
.context("User not found")?;
let user = repo
.user()
.lookup(self.user_id())
.await?
.context("User not found")?;
// Lock the user sync to make sure we don't get into a race condition
repo.user().acquire_lock_for_sync(&user).await?;
// Lock the user sync to make sure we don't get into a race condition
repo.user().acquire_lock_for_sync(&user).await?;
let mut devices = HashSet::new();
let mut devices = HashSet::new();
// Cycle through all the compat sessions of the user, and grab the devices
let mut cursor = Pagination::first(100);
loop {
let page = repo
.compat_session()
.list(
CompatSessionFilter::new().for_user(&user).active_only(),
cursor,
)
.await?;
// Cycle through all the compat sessions of the user, and grab the devices
let mut cursor = Pagination::first(100);
loop {
let page = repo
.compat_session()
.list(
CompatSessionFilter::new().for_user(&user).active_only(),
cursor,
)
.await?;
for (compat_session, _) in page.edges {
devices.insert(compat_session.device.as_str().to_owned());
cursor = cursor.after(compat_session.id);
}
if !page.has_next_page {
break;
}
}
// Cycle through all the oauth2 sessions of the user, and grab the devices
let mut cursor = Pagination::first(100);
loop {
let page = repo
.oauth2_session()
.list(
OAuth2SessionFilter::new().for_user(&user).active_only(),
cursor,
)
.await?;
for oauth2_session in page.edges {
for scope in &*oauth2_session.scope {
if let Some(device) = Device::from_scope_token(scope) {
devices.insert(device.as_str().to_owned());
}
for (compat_session, _) in page.edges {
devices.insert(compat_session.device.as_str().to_owned());
cursor = cursor.after(compat_session.id);
}
cursor = cursor.after(oauth2_session.id);
if !page.has_next_page {
break;
}
}
if !page.has_next_page {
break;
// Cycle through all the oauth2 sessions of the user, and grab the devices
let mut cursor = Pagination::first(100);
loop {
let page = repo
.oauth2_session()
.list(
OAuth2SessionFilter::new().for_user(&user).active_only(),
cursor,
)
.await?;
for oauth2_session in page.edges {
for scope in &*oauth2_session.scope {
if let Some(device) = Device::from_scope_token(scope) {
devices.insert(device.as_str().to_owned());
}
}
cursor = cursor.after(oauth2_session.id);
}
if !page.has_next_page {
break;
}
}
let mxid = matrix.mxid(&user.username);
matrix.sync_devices(&mxid, devices).await?;
// We kept the connection until now, so that we still hold the lock on the user
// throughout the sync
repo.save().await?;
Ok(())
}
let mxid = matrix.mxid(&user.username);
matrix.sync_devices(&mxid, devices).await?;
// We kept the connection until now, so that we still hold the lock on the user
// throughout the sync
repo.save().await?;
Ok(())
}
pub(crate) fn register(
suffix: &str,
monitor: Monitor<TokioExecutor>,
state: &State,
storage_factory: &PostgresStorageFactory,
) -> Monitor<TokioExecutor> {
let provision_user_worker =
crate::build!(ProvisionUserJob => provision_user, suffix, state, storage_factory);
let provision_device_worker =
crate::build!(ProvisionDeviceJob => provision_device, suffix, state, storage_factory);
let delete_device_worker =
crate::build!(DeleteDeviceJob => delete_device, suffix, state, storage_factory);
let sync_devices_worker =
crate::build!(SyncDevicesJob => sync_devices, suffix, state, storage_factory);
monitor
.register(provision_user_worker)
.register(provision_device_worker)
.register(delete_device_worker)
.register(sync_devices_worker)
}

View File

@@ -3,12 +3,12 @@
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::collections::HashMap;
use std::{collections::HashMap, sync::Arc};
use async_trait::async_trait;
use chrono::{DateTime, Duration, Utc};
use mas_storage::{
queue::{InsertableJob, Job, Worker},
queue::{InsertableJob, Job, JobMetadata, Worker},
Clock, RepositoryAccess, RepositoryError,
};
use mas_storage_pg::{DatabaseError, PgRepository};
@@ -20,12 +20,42 @@ use sqlx::{
Acquire, Either,
};
use thiserror::Error;
use tokio::task::JoinSet;
use tokio_util::sync::CancellationToken;
use tracing::{Instrument as _, Span};
use tracing_opentelemetry::OpenTelemetrySpanExt as _;
use ulid::Ulid;
use crate::State;
type JobPayload = serde_json::Value;
#[derive(Clone)]
pub struct JobContext {
pub id: Ulid,
pub metadata: JobMetadata,
pub queue_name: String,
pub cancellation_token: CancellationToken,
}
impl JobContext {
    /// Build the tracing span under which the job is run.
    ///
    /// The span is created without a parent, and is linked to the span
    /// context carried by the job metadata. It records the job ID and the
    /// name of the queue the job came from.
    pub fn span(&self) -> Span {
        // Only fields that exist on `JobContext` are recorded here
        let span = tracing::info_span!(
            parent: Span::none(),
            "job.run",
            job.id = %self.id,
            job.queue_name = self.queue_name,
        );

        span.add_link(self.metadata.span_context());

        span
    }
}
pub trait FromJob {
fn from_job(job: &Job) -> Result<Self, anyhow::Error>
fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error>
where
Self: Sized;
}
@@ -34,14 +64,14 @@ impl<T> FromJob for T
where
T: DeserializeOwned,
{
fn from_job(job: &Job) -> Result<Self, anyhow::Error> {
serde_json::from_value(job.payload.clone()).map_err(Into::into)
fn from_job(payload: JobPayload) -> Result<Self, anyhow::Error> {
serde_json::from_value(payload).map_err(Into::into)
}
}
#[async_trait]
pub trait RunnableJob: FromJob + Send + 'static {
async fn run(&self, state: &State) -> Result<(), anyhow::Error>;
async fn run(&self, state: &State, context: JobContext) -> Result<(), anyhow::Error>;
}
fn box_runnable_job<T: RunnableJob + 'static>(job: T) -> Box<dyn RunnableJob> {
@@ -79,7 +109,13 @@ pub enum QueueRunnerError {
const MIN_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(900);
const MAX_SLEEP_DURATION: std::time::Duration = std::time::Duration::from_millis(1100);
type JobFactory = Box<dyn FnMut(&Job) -> Box<dyn RunnableJob> + Send>;
// How many jobs can we run concurrently
const MAX_CONCURRENT_JOBS: usize = 10;
// How many jobs can we fetch at once
const MAX_JOBS_TO_FETCH: usize = 5;
type JobFactory = Arc<dyn Fn(JobPayload) -> Box<dyn RunnableJob> + Send + Sync>;
pub struct QueueWorker {
rng: ChaChaRng,
@@ -89,7 +125,14 @@ pub struct QueueWorker {
am_i_leader: bool,
last_heartbeat: DateTime<Utc>,
cancellation_token: CancellationToken,
state: State,
running_jobs: JoinSet<Result<(), anyhow::Error>>,
job_contexts: HashMap<tokio::task::Id, JobContext>,
factories: HashMap<&'static str, JobFactory>,
#[allow(clippy::type_complexity)]
last_join_result:
Option<Result<(tokio::task::Id, Result<(), anyhow::Error>), tokio::task::JoinError>>,
}
impl QueueWorker {
@@ -115,6 +158,12 @@ impl QueueWorker {
.await
.map_err(QueueRunnerError::SetupListener)?;
// We get notifications when a job is available on this channel
listener
.listen("queue_available")
.await
.map_err(QueueRunnerError::SetupListener)?;
let txn = listener
.begin()
.await
@@ -139,14 +188,22 @@ impl QueueWorker {
am_i_leader: false,
last_heartbeat: now,
cancellation_token,
state,
job_contexts: HashMap::new(),
running_jobs: JoinSet::new(),
factories: HashMap::new(),
last_join_result: None,
})
}
pub fn register_handler<T: RunnableJob + InsertableJob>(&mut self) -> &mut Self {
// TODO: error handling
let factory = |job: &Job| box_runnable_job(T::from_job(job).unwrap());
self.factories.insert(T::QUEUE_NAME, Box::new(factory));
// There is a potential panic here, which is fine as it's going to be caught
// within the job task
let factory = |payload: JobPayload| {
box_runnable_job(T::from_job(payload).expect("Failed to deserialize job"))
};
self.factories.insert(T::QUEUE_NAME, Arc::new(factory));
self
}
@@ -164,6 +221,7 @@ impl QueueWorker {
async fn run_loop(&mut self) -> Result<(), QueueRunnerError> {
self.wait_until_wakeup().await?;
// TODO: join all the jobs handles when shutting down
if self.cancellation_token.is_cancelled() {
return Ok(());
}
@@ -214,6 +272,8 @@ impl QueueWorker {
.sample(Uniform::new(MIN_SLEEP_DURATION, MAX_SLEEP_DURATION));
let wakeup_sleep = tokio::time::sleep(sleep_duration);
// TODO: add metrics to track the wake up reasons
tokio::select! {
() = self.cancellation_token.cancelled() => {
tracing::debug!("Woke up from cancellation");
@@ -223,6 +283,11 @@ impl QueueWorker {
tracing::debug!("Woke up from sleep");
},
Some(result) = self.running_jobs.join_next_with_id() => {
tracing::debug!("Joined job task");
self.last_join_result = Some(result);
},
notification = self.listener.recv() => {
match notification {
Ok(notification) => {
@@ -281,6 +346,127 @@ impl QueueWorker {
.try_get_leader_lease(&self.clock, &self.registration)
.await?;
// Find any job task which finished
// If we got woken up by a join on the joinset, it will be stored in the
// last_join_result so that we don't lose it
if self.last_join_result.is_none() {
self.last_join_result = self.running_jobs.try_join_next_with_id();
}
while let Some(result) = self.last_join_result.take() {
// TODO: add metrics to track the job status and the time it took
let context = match result {
Ok((id, Ok(()))) => {
// The job succeeded
let context = self
.job_contexts
.remove(&id)
.expect("Job context not found");
tracing::info!(
job.id = %context.id,
job.queue_name = %context.queue_name,
"Job completed"
);
context
}
Ok((id, Err(e))) => {
// The job failed
let context = self
.job_contexts
.remove(&id)
.expect("Job context not found");
tracing::error!(
error = ?e,
job.id = %context.id,
job.queue_name = %context.queue_name,
"Job failed"
);
// TODO: reschedule the job
context
}
Err(e) => {
// The job crashed (or was cancelled)
let id = e.id();
let context = self
.job_contexts
.remove(&id)
.expect("Job context not found");
tracing::error!(
error = &e as &dyn std::error::Error,
job.id = %context.id,
job.queue_name = %context.queue_name,
"Job crashed"
);
// TODO: reschedule the job
context
}
};
repo.queue_job()
.mark_as_completed(&self.clock, context.id)
.await?;
self.last_join_result = self.running_jobs.try_join_next_with_id();
}
// Compute how many jobs we should fetch at most: the remaining capacity
// (concurrency limit minus the jobs currently running), capped at
// `MAX_JOBS_TO_FETCH` per wakeup. This is zero once the internal queue is full.
let max_jobs_to_fetch = MAX_CONCURRENT_JOBS
    .saturating_sub(self.running_jobs.len())
    .min(MAX_JOBS_TO_FETCH);
if max_jobs_to_fetch == 0 {
tracing::warn!("Internal job queue is full, not fetching any new jobs");
} else {
// Grab a few jobs in the queue
let queues = self.factories.keys().copied().collect::<Vec<_>>();
let jobs = repo
.queue_job()
.reserve(&self.clock, &self.registration, &queues, max_jobs_to_fetch)
.await?;
for Job {
id,
queue_name,
payload,
metadata,
} in jobs
{
let cancellation_token = self.cancellation_token.child_token();
let factory = self.factories.get(queue_name.as_str()).cloned();
let context = JobContext {
id,
metadata,
queue_name,
cancellation_token,
};
let task = {
let context = context.clone();
let span = context.span();
let state = self.state.clone();
async move {
// We should never crash, but in case we do, we do that in the task and
// don't crash the worker
let job = factory.expect("unknown job factory")(payload);
job.run(&state, context).await
}
.instrument(span)
};
let handle = self.running_jobs.spawn(task);
self.job_contexts.insert(handle.id(), context);
}
}
// After this point, we are locking the leader table, so it's important that we
// commit as soon as possible to not block the other workers for too long
repo.into_inner()
@@ -353,6 +539,8 @@ impl QueueWorker {
.shutdown_dead_workers(&self.clock, Duration::minutes(2))
.await?;
// TODO: mark the jobs those workers were running as lost
// Release the leader lock
let txn = repo
.into_inner()

View File

@@ -5,11 +5,10 @@
// Please see LICENSE in the repository root for full details.
use anyhow::Context;
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
use async_trait::async_trait;
use mas_email::{Address, Mailbox};
use mas_i18n::DataLocale;
use mas_storage::{
job::JobWithSpanContext,
queue::SendAccountRecoveryEmailsJob,
user::{UserEmailFilter, UserRecoveryRepository},
Pagination, RepositoryAccess,
@@ -18,117 +17,108 @@ use mas_templates::{EmailRecoveryContext, TemplateContext};
use rand::distributions::{Alphanumeric, DistString};
use tracing::{error, info};
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
use crate::{
new_queue::{JobContext, RunnableJob},
State,
};
/// Job to send account recovery emails for a given recovery session.
#[tracing::instrument(
name = "job.send_account_recovery_email",
fields(
user_recovery_session.id = %job.user_recovery_session_id(),
user_recovery_session.email,
),
skip_all,
err(Debug),
)]
async fn send_account_recovery_email_job(
job: JobWithSpanContext<SendAccountRecoveryEmailsJob>,
ctx: JobContext,
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let clock = state.clock();
let mailer = state.mailer();
let url_builder = state.url_builder();
let mut rng = state.rng();
let mut repo = state.repository().await?;
#[async_trait]
impl RunnableJob for SendAccountRecoveryEmailsJob {
#[tracing::instrument(
name = "job.send_account_recovery_email",
fields(
user_recovery_session.id = %self.user_recovery_session_id(),
user_recovery_session.email,
),
skip_all,
err(Debug),
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
let clock = state.clock();
let mailer = state.mailer();
let url_builder = state.url_builder();
let mut rng = state.rng();
let mut repo = state.repository().await?;
let session = repo
.user_recovery()
.lookup_session(job.user_recovery_session_id())
.await?
.context("User recovery session not found")?;
let session = repo
.user_recovery()
.lookup_session(self.user_recovery_session_id())
.await?
.context("User recovery session not found")?;
tracing::Span::current().record("user_recovery_session.email", &session.email);
tracing::Span::current().record("user_recovery_session.email", &session.email);
if session.consumed_at.is_some() {
info!("Recovery session already consumed, not sending email");
return Ok(());
}
if session.consumed_at.is_some() {
info!("Recovery session already consumed, not sending email");
return Ok(());
}
let mut cursor = Pagination::first(50);
let mut cursor = Pagination::first(50);
let lang: DataLocale = session
.locale
.parse()
.context("Invalid locale in database on recovery session")?;
let lang: DataLocale = session
.locale
.parse()
.context("Invalid locale in database on recovery session")?;
loop {
let page = repo
.user_email()
.list(
UserEmailFilter::new()
.for_email(&session.email)
.verified_only(),
cursor,
)
.await?;
for email in page.edges {
let ticket = Alphanumeric.sample_string(&mut rng, 32);
let ticket = repo
.user_recovery()
.add_ticket(&mut rng, &clock, &session, &email, ticket)
loop {
let page = repo
.user_email()
.list(
UserEmailFilter::new()
.for_email(&session.email)
.verified_only(),
cursor,
)
.await?;
let user_email = repo
.user_email()
.lookup(email.id)
.await?
.context("User email not found")?;
for email in page.edges {
let ticket = Alphanumeric.sample_string(&mut rng, 32);
let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("User not found")?;
let ticket = repo
.user_recovery()
.add_ticket(&mut rng, &clock, &session, &email, ticket)
.await?;
let url = url_builder.account_recovery_link(ticket.ticket);
let user_email = repo
.user_email()
.lookup(email.id)
.await?
.context("User email not found")?;
let address: Address = user_email.email.parse()?;
let mailbox = Mailbox::new(Some(user.username.clone()), address);
let user = repo
.user()
.lookup(user_email.user_id)
.await?
.context("User not found")?;
info!("Sending recovery email to {}", mailbox);
let context =
EmailRecoveryContext::new(user, session.clone(), url).with_language(lang.clone());
let url = url_builder.account_recovery_link(ticket.ticket);
// XXX: we only log if the email fails to send, to avoid stopping the loop
if let Err(e) = mailer.send_recovery_email(mailbox, &context).await {
error!(
error = &e as &dyn std::error::Error,
"Failed to send recovery email"
);
let address: Address = user_email.email.parse()?;
let mailbox = Mailbox::new(Some(user.username.clone()), address);
info!("Sending recovery email to {}", mailbox);
let context = EmailRecoveryContext::new(user, session.clone(), url)
.with_language(lang.clone());
// XXX: we only log if the email fails to send, to avoid stopping the loop
if let Err(e) = mailer.send_recovery_email(mailbox, &context).await {
error!(
error = &e as &dyn std::error::Error,
"Failed to send recovery email"
);
}
cursor = cursor.after(email.id);
}
cursor = cursor.after(email.id);
if !page.has_next_page {
break;
}
}
if !page.has_next_page {
break;
}
repo.save().await?;
Ok(())
}
repo.save().await?;
Ok(())
}
pub(crate) fn register(
suffix: &str,
monitor: Monitor<TokioExecutor>,
state: &State,
storage_factory: &PostgresStorageFactory,
) -> Monitor<TokioExecutor> {
let send_user_recovery_email_worker = crate::build!(SendAccountRecoveryEmailsJob => send_account_recovery_email_job, suffix, state, storage_factory);
monitor.register(send_user_recovery_email_worker)
}

View File

@@ -1,70 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::str::FromStr;
use apalis_core::{context::JobContext, job::JobId, request::JobRequest, worker::WorkerId};
use chrono::{DateTime, Utc};
use serde_json::Value;
use sqlx::Row;
/// Wrapper for [`JobRequest`]
///
/// This newtype exists so that [`sqlx::FromRow`] can be implemented for job
/// rows; it converts back into the wrapped [`JobRequest`] through `From`.
pub(crate) struct SqlJobRequest<T>(JobRequest<T>);
impl<T> From<SqlJobRequest<T>> for JobRequest<T> {
fn from(val: SqlJobRequest<T>) -> Self {
val.0
}
}
impl<'r, T: serde::de::DeserializeOwned> sqlx::FromRow<'r, sqlx::postgres::PgRow>
for SqlJobRequest<T>
{
fn from_row(row: &'r sqlx::postgres::PgRow) -> Result<Self, sqlx::Error> {
let job: Value = row.try_get("job")?;
let id: JobId =
JobId::from_str(row.try_get("id")?).map_err(|e| sqlx::Error::ColumnDecode {
index: "id".to_owned(),
source: Box::new(e),
})?;
let mut context = JobContext::new(id);
let run_at = row.try_get("run_at")?;
context.set_run_at(run_at);
let attempts = row.try_get("attempts").unwrap_or(0);
context.set_attempts(attempts);
let max_attempts = row.try_get("max_attempts").unwrap_or(25);
context.set_max_attempts(max_attempts);
let done_at: Option<DateTime<Utc>> = row.try_get("done_at").unwrap_or_default();
context.set_done_at(done_at);
let lock_at: Option<DateTime<Utc>> = row.try_get("lock_at").unwrap_or_default();
context.set_lock_at(lock_at);
let last_error = row.try_get("last_error").unwrap_or_default();
context.set_last_error(last_error);
let status: String = row.try_get("status")?;
context.set_status(status.parse().map_err(|e| sqlx::Error::ColumnDecode {
index: "job".to_owned(),
source: Box::new(e),
})?);
let lock_by: Option<String> = row.try_get("lock_by").unwrap_or_default();
context.set_lock_by(lock_by.map(WorkerId::new));
Ok(SqlJobRequest(JobRequest::new_with_context(
serde_json::from_value(job).map_err(|e| sqlx::Error::ColumnDecode {
index: "job".to_owned(),
source: Box::new(e),
})?,
context,
)))
}
}

View File

@@ -1,14 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
//! Reimplementation of the [`apalis_sql::storage::PostgresStorage`] using a
//! shared connection for the [`PgListener`]
mod from_row;
mod postgres;
use self::from_row::SqlJobRequest;
pub(crate) use self::postgres::StorageFactory as PostgresStorageFactory;

View File

@@ -1,391 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use std::{convert::TryInto, marker::PhantomData, ops::Add, sync::Arc, time::Duration};
use apalis_core::{
error::JobStreamError,
job::{Job, JobId, JobStreamResult},
request::JobRequest,
storage::{StorageError, StorageResult, StorageWorkerPulse},
utils::Timer,
worker::WorkerId,
};
use async_stream::try_stream;
use chrono::{DateTime, Utc};
use event_listener::Event;
use futures_lite::{Stream, StreamExt};
use serde::{de::DeserializeOwned, Serialize};
use sqlx::{postgres::PgListener, PgPool, Pool, Postgres, Row};
use tokio::task::JoinHandle;
use super::SqlJobRequest;
/// Builds [`Storage`] instances that all share one connection pool and one
/// notification [`Event`].
pub struct StorageFactory {
// Connection pool handed (cloned) to every storage built by this factory
pool: PgPool,
// Event notified by the `listen` task whenever a Postgres notification arrives
event: Arc<Event>,
}
impl StorageFactory {
    /// Creates a factory backed by `pool`, with a fresh notification event.
    pub fn new(pool: Pool<Postgres>) -> Self {
        let event = Arc::new(Event::new());
        Self { pool, event }
    }

    /// Subscribes to the `apalis::job` Postgres channel and spawns a task
    /// which wakes every listener of the shared event on each notification.
    ///
    /// # Errors
    ///
    /// Returns an error if the listener cannot connect or subscribe.
    ///
    /// # Panics
    ///
    /// The spawned task panics if receiving a notification fails.
    pub async fn listen(self) -> Result<JoinHandle<()>, sqlx::Error> {
        let mut pg_listener = PgListener::connect_with(&self.pool).await?;
        pg_listener.listen("apalis::job").await?;

        Ok(tokio::spawn(async move {
            loop {
                let notification = pg_listener
                    .recv()
                    .await
                    .expect("Failed to poll notification");
                // Wake up every storage currently waiting on the event
                self.event.notify(usize::MAX);
                tracing::debug!(?notification, "Broadcast notification");
            }
        }))
    }

    /// Builds a [`Storage`] for the job type `T`, sharing this factory's pool
    /// and notification event.
    pub fn build<T>(&self) -> Storage<T> {
        let Self { pool, event } = self;
        Storage {
            pool: pool.clone(),
            event: Arc::clone(event),
            job_type: PhantomData,
        }
    }
}
/// Represents a [`apalis_core::storage::Storage`] that persists to Postgres
///
/// The type parameter `T` is the job type handled by this storage; it is only
/// tracked at the type level through [`PhantomData`].
#[derive(Debug)]
pub struct Storage<T> {
// Connection pool used for every query
pool: PgPool,
// Shared event that `stream_jobs` waits on between polls
event: Arc<Event>,
// Marker tying the storage to its job type without storing a value
job_type: PhantomData<T>,
}
impl<T> Clone for Storage<T> {
fn clone(&self) -> Self {
Storage {
pool: self.pool.clone(),
event: self.event.clone(),
job_type: PhantomData,
}
}
}
impl<T: DeserializeOwned + Send + Unpin + Job> Storage<T> {
    /// Returns an endless stream of jobs fetched for `worker_id`.
    ///
    /// Each round waits for either a notification on the shared event or for
    /// `interval` to elapse, then asks `apalis.get_jobs` for up to
    /// `buffer_size` jobs of type `T::NAME` and yields them one by one.
    ///
    /// The stream ends with a [`JobStreamError::BrokenPipe`] if `buffer_size`
    /// does not fit in an `i32` or if the query fails.
    fn stream_jobs(
        &self,
        worker_id: &WorkerId,
        interval: Duration,
        buffer_size: usize,
    ) -> impl Stream<Item = Result<JobRequest<T>, JobStreamError>> {
        let pool = self.pool.clone();
        let event = self.event.clone();
        let worker_id = worker_id.clone();
        let timer = apalis_core::utils::timer::TokioTimer;

        try_stream! {
            loop {
                // Wait for a notification or a timeout, whichever comes first
                let notified = event.listen();
                let timeout = timer.sleep(interval);
                futures_lite::future::race(timeout, notified).await;

                // https://docs.rs/sqlx/latest/sqlx/postgres/types/index.html
                let limit = i32::try_from(buffer_size)
                    .map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?;

                let fetched: Vec<SqlJobRequest<T>> =
                    sqlx::query_as("SELECT * FROM apalis.get_jobs($1, $2, $3);")
                        .bind(worker_id.name())
                        .bind(T::NAME)
                        .bind(limit)
                        .fetch_all(&pool)
                        .await
                        .map_err(|e| JobStreamError::BrokenPipe(Box::from(e)))?;

                for job in fetched {
                    yield job.into();
                }
            }
        }
    }

    /// Records that `worker_id` was alive at `last_seen`.
    ///
    /// Inserts the worker in `apalis.workers`, or only refreshes its
    /// `last_seen` column if a row with that ID already exists. The type name
    /// of `Service` is stored in the `layers` column.
    async fn keep_alive_at<Service>(
        &mut self,
        worker_id: &WorkerId,
        last_seen: DateTime<Utc>,
    ) -> StorageResult<()> {
        let upsert = "INSERT INTO apalis.workers (id, worker_type, storage_name, layers, last_seen)
VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (id) DO
UPDATE SET last_seen = EXCLUDED.last_seen";

        sqlx::query(upsert)
            .bind(worker_id.name())
            .bind(T::NAME)
            .bind(std::any::type_name::<Self>())
            .bind(std::any::type_name::<Service>())
            .bind(last_seen)
            .execute(&self.pool)
            .await
            .map_err(|e| StorageError::Database(Box::from(e)))?;

        Ok(())
    }
}
#[async_trait::async_trait]
impl<T> apalis_core::storage::Storage for Storage<T>
where
T: Job + Serialize + DeserializeOwned + Send + 'static + Unpin + Sync,
{
type Output = T;
/// Pushes a job to the Postgres [Storage].
///
/// The job is serialized to JSON and inserted into `apalis.jobs` as a
/// `Pending` job runnable immediately (`run_at` is `NOW()`), with 0 attempts
/// out of a maximum of 25. Returns the freshly generated job ID.
async fn push(&mut self, job: Self::Output) -> StorageResult<JobId> {
    let id = JobId::new();
    let payload = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;

    sqlx::query("INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, NOW() , NULL, NULL, NULL, NULL)")
        .bind(payload)
        .bind(id.to_string())
        .bind(T::NAME)
        .execute(&self.pool)
        .await
        .map_err(|e| StorageError::Database(Box::from(e)))?;

    Ok(id)
}
/// Inserts `job` as a `Pending` job which should only run from `on` onwards.
///
/// Same as `push`, except that `run_at` is set to `on` instead of `NOW()`.
/// Returns the freshly generated job ID.
async fn schedule(
    &mut self,
    job: Self::Output,
    on: chrono::DateTime<Utc>,
) -> StorageResult<JobId> {
    let mut conn = self
        .pool
        .acquire()
        .await
        .map_err(|e| StorageError::Connection(Box::from(e)))?;

    let id = JobId::new();
    let payload = serde_json::to_value(&job).map_err(|e| StorageError::Parse(Box::from(e)))?;

    sqlx::query(
        "INSERT INTO apalis.jobs VALUES ($1, $2, $3, 'Pending', 0, 25, $4, NULL, NULL, NULL, NULL)",
    )
    .bind(payload)
    .bind(id.to_string())
    .bind(T::NAME)
    .bind(on)
    .execute(&mut *conn)
    .await
    .map_err(|e| StorageError::Database(Box::from(e)))?;

    Ok(id)
}
async fn fetch_by_id(&self, job_id: &JobId) -> StorageResult<Option<JobRequest<Self::Output>>> {
let mut conn = self
.pool
.acquire()
.await
.map_err(|e| StorageError::Connection(Box::from(e)))?;
let fetch_query = "SELECT * FROM apalis.jobs WHERE id = $1";
let res: Option<SqlJobRequest<T>> = sqlx::query_as(fetch_query)
.bind(job_id.to_string())
.fetch_optional(&mut *conn)
.await
.map_err(|e| StorageError::Database(Box::from(e)))?;
Ok(res.map(Into::into))
}
/// Handles a maintenance pulse sent to the storage.
///
/// * `EnqueueScheduled` is a no-op.
/// * `ReenqueueOrphaned` puts back to `Pending` up to `count` jobs of this
///   job type which are still `Running` but locked by a worker that has not
///   been seen for more than 5 minutes.
///
/// # Errors
///
/// Returns [`StorageError::Connection`] if no connection can be acquired from
/// the pool, and [`StorageError::Database`] if the update fails.
///
/// # Panics
///
/// Panics on any other kind of pulse, as they are not implemented.
async fn heartbeat(&mut self, pulse: StorageWorkerPulse) -> StorageResult<bool> {
    match pulse {
        StorageWorkerPulse::EnqueueScheduled { count: _ } => {
            // Ideally jobs are queued via run_at, so this is not necessary
            Ok(true)
        }

        // Worker not seen in 5 minutes yet has running jobs
        StorageWorkerPulse::ReenqueueOrphaned { count, .. } => {
            let job_type = T::NAME;
            // Failing to get a connection is reported as a connection error,
            // like in every other method of this storage
            let mut conn = self
                .pool
                .acquire()
                .await
                .map_err(|e| StorageError::Connection(Box::from(e)))?;
            let query = "UPDATE apalis.jobs
SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, last_error ='Job was abandoned'
WHERE id in
(SELECT jobs.id from apalis.jobs INNER join apalis.workers ON lock_by = workers.id
WHERE status = 'Running' AND workers.last_seen < NOW() - INTERVAL '5 minutes'
AND workers.worker_type = $1 ORDER BY lock_at ASC LIMIT $2);";
            sqlx::query(query)
                .bind(job_type)
                .bind(count)
                .execute(&mut *conn)
                .await
                .map_err(|e| StorageError::Database(Box::from(e)))?;
            Ok(true)
        }

        _ => unimplemented!(),
    }
}
/// Marks a job as `Killed`, stamping `done_at` with the current database
/// time. Only applies if the job is currently locked by `worker_id`.
async fn kill(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
    let mut conn = self
        .pool
        .acquire()
        .await
        .map_err(|e| StorageError::Connection(Box::from(e)))?;

    sqlx::query(
        "UPDATE apalis.jobs SET status = 'Killed', done_at = now() WHERE id = $1 AND lock_by = $2",
    )
    .bind(job_id.to_string())
    .bind(worker_id.name())
    .execute(&mut *conn)
    .await
    .map_err(|e| StorageError::Database(Box::from(e)))?;

    Ok(())
}
/// Immediately hands the job back to the queue, where another [Worker] may
/// pick it up.
///
/// The job goes back to `Pending` and its lock owner is cleared; this only
/// applies if the job is currently locked by `worker_id`.
async fn retry(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
    let mut conn = self
        .pool
        .acquire()
        .await
        .map_err(|e| StorageError::Connection(Box::from(e)))?;

    sqlx::query(
        "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL WHERE id = $1 AND lock_by = $2",
    )
    .bind(job_id.to_string())
    .bind(worker_id.name())
    .execute(&mut *conn)
    .await
    .map_err(|e| StorageError::Database(Box::from(e)))?;

    Ok(())
}
/// Returns the boxed stream of jobs for `worker_id`.
///
/// This is `stream_jobs` with every successfully fetched job wrapped in
/// `Some`, as expected by the stream item type.
fn consume(
    &mut self,
    worker_id: &WorkerId,
    interval: Duration,
    buffer_size: usize,
) -> JobStreamResult<T> {
    let jobs = self.stream_jobs(worker_id, interval, buffer_size);
    Box::pin(jobs.map(|item| item.map(Some)))
}
/// Counts the jobs currently in the `Pending` status.
async fn len(&self) -> StorageResult<i64> {
    let row = sqlx::query("SELECT COUNT(*) AS count FROM apalis.jobs WHERE status = 'Pending'")
        .fetch_one(&self.pool)
        .await
        .map_err(|e| StorageError::Database(Box::from(e)))?;

    row.try_get("count")
        .map_err(|e| StorageError::Database(Box::from(e)))
}
/// Acknowledges a job: marks it as `Done` and stamps `done_at` with the
/// current database time. Only applies if the job is locked by `worker_id`.
async fn ack(&mut self, worker_id: &WorkerId, job_id: &JobId) -> StorageResult<()> {
    sqlx::query(
        "UPDATE apalis.jobs SET status = 'Done', done_at = now() WHERE id = $1 AND lock_by = $2",
    )
    .bind(job_id.to_string())
    .bind(worker_id.name())
    .execute(&self.pool)
    .await
    .map_err(|e| StorageError::Database(Box::from(e)))?;

    Ok(())
}
/// Puts `job` back in the queue so that it runs again after `wait`.
///
/// The job goes back to `Pending`, its lock and completion information are
/// cleared, and `run_at` is set to the current time plus `wait`.
///
/// # Errors
///
/// Returns [`StorageError::Database`] if `wait` is too large to be expressed
/// as a signed 64-bit number of microseconds or if the update fails, and
/// [`StorageError::Connection`] if no connection can be acquired.
async fn reschedule(&mut self, job: &JobRequest<T>, wait: Duration) -> StorageResult<()> {
    let pool = self.pool.clone();
    let job_id = job.id();

    // Convert with microsecond precision and a checked conversion: going
    // through whole seconds would drop the sub-second part of the delay, and
    // multiplying them back into microseconds could overflow.
    let wait: i64 = wait
        .as_micros()
        .try_into()
        .map_err(|e| StorageError::Database(Box::new(e)))?;
    let wait = chrono::Duration::microseconds(wait);

    // TODO: should we use a clock here?
    #[allow(clippy::disallowed_methods)]
    let run_at = Utc::now().add(wait);

    let mut conn = pool
        .acquire()
        .await
        .map_err(|e| StorageError::Connection(Box::from(e)))?;
    let query =
        "UPDATE apalis.jobs SET status = 'Pending', done_at = NULL, lock_by = NULL, lock_at = NULL, run_at = $2 WHERE id = $1";
    sqlx::query(query)
        .bind(job_id.to_string())
        .bind(run_at)
        .execute(&mut *conn)
        .await
        .map_err(|e| StorageError::Database(Box::from(e)))?;
    Ok(())
}
/// Overwrites the bookkeeping columns of the job identified by `job_id`
/// (status, attempts, completion time, lock owner, lock time and last error)
/// with the values carried by `job`.
async fn update_by_id(
    &self,
    job_id: &JobId,
    job: &JobRequest<Self::Output>,
) -> StorageResult<()> {
    // Snapshot everything that gets written before touching the database
    let new_status = job.status().as_ref();
    let new_attempts = job.attempts();
    let finished_at = *job.done_at();
    let locked_by = job.lock_by().clone();
    let locked_at = *job.lock_at();
    let error_message = job.last_error().clone();

    let mut conn = self
        .pool
        .acquire()
        .await
        .map_err(|e| StorageError::Connection(Box::from(e)))?;

    sqlx::query(
        "UPDATE apalis.jobs SET status = $1, attempts = $2, done_at = $3, lock_by = $4, lock_at = $5, last_error = $6 WHERE id = $7",
    )
    .bind(new_status.to_owned())
    .bind(new_attempts)
    .bind(finished_at)
    .bind(locked_by.as_ref().map(WorkerId::name))
    .bind(locked_at)
    .bind(error_message)
    .bind(job_id.to_string())
    .execute(&mut *conn)
    .await
    .map_err(|e| StorageError::Database(Box::from(e)))?;

    Ok(())
}
/// Records that `worker_id` is alive right now, using the system clock.
async fn keep_alive<Service>(&mut self, worker_id: &WorkerId) -> StorageResult<()> {
    #[allow(clippy::disallowed_methods)]
    let last_seen = Utc::now();
    self.keep_alive_at::<Service>(worker_id, last_seen).await
}
}

View File

@@ -5,10 +5,9 @@
// Please see LICENSE in the repository root for full details.
use anyhow::Context;
use apalis_core::{context::JobContext, executor::TokioExecutor, monitor::Monitor};
use async_trait::async_trait;
use mas_storage::{
compat::CompatSessionFilter,
job::JobWithSpanContext,
oauth2::OAuth2SessionFilter,
queue::{DeactivateUserJob, ReactivateUserJob},
user::{BrowserSessionFilter, UserRepository},
@@ -16,122 +15,106 @@ use mas_storage::{
};
use tracing::info;
use crate::{storage::PostgresStorageFactory, JobContextExt, State};
use crate::{
new_queue::{JobContext, RunnableJob},
State,
};
/// Job to deactivate a user, both locally and on the Matrix homeserver.
#[tracing::instrument(
#[async_trait]
impl RunnableJob for DeactivateUserJob {
#[tracing::instrument(
name = "job.deactivate_user"
fields(user.id = %job.user_id(), erase = %job.hs_erase()),
skip_all,
err(Debug),
)]
async fn deactivate_user(
job: JobWithSpanContext<DeactivateUserJob>,
ctx: JobContext,
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let clock = state.clock();
let matrix = state.matrix_connection();
let mut repo = state.repository().await?;
fields(user.id = %self.user_id(), erase = %self.hs_erase()),
skip_all,
err(Debug),
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
let clock = state.clock();
let matrix = state.matrix_connection();
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(job.user_id())
.await?
.context("User not found")?;
let user = repo
.user()
.lookup(self.user_id())
.await?
.context("User not found")?;
// Let's first lock the user
let user = repo
.user()
.lock(&clock, user)
.await
.context("Failed to lock user")?;
// Let's first lock the user
let user = repo
.user()
.lock(&clock, user)
.await
.context("Failed to lock user")?;
// Kill all sessions for the user
let n = repo
.browser_session()
.finish_bulk(
&clock,
BrowserSessionFilter::new().for_user(&user).active_only(),
)
.await?;
info!(affected = n, "Killed all browser sessions for user");
// Kill all sessions for the user
let n = repo
.browser_session()
.finish_bulk(
&clock,
BrowserSessionFilter::new().for_user(&user).active_only(),
)
.await?;
info!(affected = n, "Killed all browser sessions for user");
let n = repo
.oauth2_session()
.finish_bulk(
&clock,
OAuth2SessionFilter::new().for_user(&user).active_only(),
)
.await?;
info!(affected = n, "Killed all OAuth 2.0 sessions for user");
let n = repo
.oauth2_session()
.finish_bulk(
&clock,
OAuth2SessionFilter::new().for_user(&user).active_only(),
)
.await?;
info!(affected = n, "Killed all OAuth 2.0 sessions for user");
let n = repo
.compat_session()
.finish_bulk(
&clock,
CompatSessionFilter::new().for_user(&user).active_only(),
)
.await?;
info!(affected = n, "Killed all compatibility sessions for user");
let n = repo
.compat_session()
.finish_bulk(
&clock,
CompatSessionFilter::new().for_user(&user).active_only(),
)
.await?;
info!(affected = n, "Killed all compatibility sessions for user");
// Before calling back to the homeserver, commit the changes to the database, as
// we want the user to be locked out as soon as possible
repo.save().await?;
// Before calling back to the homeserver, commit the changes to the database, as
// we want the user to be locked out as soon as possible
repo.save().await?;
let mxid = matrix.mxid(&user.username);
info!("Deactivating user {} on homeserver", mxid);
matrix.delete_user(&mxid, job.hs_erase()).await?;
let mxid = matrix.mxid(&user.username);
info!("Deactivating user {} on homeserver", mxid);
matrix.delete_user(&mxid, self.hs_erase()).await?;
Ok(())
Ok(())
}
}
/// Job to reactivate a user, both locally and on the Matrix homeserver.
#[tracing::instrument(
name = "job.reactivate_user",
fields(user.id = %job.user_id()),
skip_all,
err(Debug),
)]
pub async fn reactivate_user(
job: JobWithSpanContext<ReactivateUserJob>,
ctx: JobContext,
) -> Result<(), anyhow::Error> {
let state = ctx.state();
let matrix = state.matrix_connection();
let mut repo = state.repository().await?;
#[async_trait]
impl RunnableJob for ReactivateUserJob {
#[tracing::instrument(
name = "job.reactivate_user",
fields(user.id = %self.user_id()),
skip_all,
err(Debug),
)]
async fn run(&self, state: &State, _context: JobContext) -> Result<(), anyhow::Error> {
let matrix = state.matrix_connection();
let mut repo = state.repository().await?;
let user = repo
.user()
.lookup(job.user_id())
.await?
.context("User not found")?;
let user = repo
.user()
.lookup(self.user_id())
.await?
.context("User not found")?;
let mxid = matrix.mxid(&user.username);
info!("Reactivating user {} on homeserver", mxid);
matrix.reactivate_user(&mxid).await?;
let mxid = matrix.mxid(&user.username);
info!("Reactivating user {} on homeserver", mxid);
matrix.reactivate_user(&mxid).await?;
// We want to unlock the user from our side only once it has been reactivated on
// the homeserver
let _user = repo.user().unlock(user).await?;
repo.save().await?;
// We want to unlock the user from our side only once it has been reactivated on
// the homeserver
let _user = repo.user().unlock(user).await?;
repo.save().await?;
Ok(())
}
pub(crate) fn register(
suffix: &str,
monitor: Monitor<TokioExecutor>,
state: &State,
storage_factory: &PostgresStorageFactory,
) -> Monitor<TokioExecutor> {
let deactivate_user_worker =
crate::build!(DeactivateUserJob => deactivate_user, suffix, state, storage_factory);
let reactivate_user_worker =
crate::build!(ReactivateUserJob => reactivate_user, suffix, state, storage_factory);
monitor
.register(deactivate_user_worker)
.register(reactivate_user_worker)
Ok(())
}
}

View File

@@ -1,91 +0,0 @@
// Copyright 2024 New Vector Ltd.
// Copyright 2023, 2024 The Matrix.org Foundation C.I.C.
//
// SPDX-License-Identifier: AGPL-3.0-only
// Please see LICENSE in the repository root for full details.
use apalis_core::{job::Job, request::JobRequest};
use mas_storage::job::JobWithSpanContext;
use mas_tower::{
make_span_fn, DurationRecorderLayer, FnWrapper, IdentityLayer, InFlightCounterLayer,
TraceLayer, KV,
};
use opentelemetry::{trace::SpanContext, Key, KeyValue};
use tracing::info_span;
use tracing_opentelemetry::OpenTelemetrySpanExt;
/// Attribute key carrying the job name (`J::NAME`) on the job metrics.
const JOB_NAME: Key = Key::from_static_str("job.name");
/// Attribute key carrying the job outcome (`"success"` or `"error"`) on the
/// job duration metric.
const JOB_STATUS: Key = Key::from_static_str("job.status");
/// Represents a job that may have a span context attached to it.
pub trait TracedJob: Job {
/// Returns the span context for this job, if any.
///
/// The default implementation returns `None`.
fn span_context(&self) -> Option<SpanContext> {
None
}
}
/// Implements [`TracedJob`] for any job with the [`JobWithSpanContext`]
/// wrapper.
impl<J: Job> TracedJob for JobWithSpanContext<J> {
fn span_context(&self) -> Option<SpanContext> {
// NOTE(review): this path-qualified call presumably resolves to the
// wrapper's own inherent `span_context` rather than to this trait method
// (which would recurse) — confirm against `JobWithSpanContext`.
JobWithSpanContext::span_context(self)
}
}
/// Creates the `job.run` span for a job request.
///
/// The span records the job ID, its number of attempts and the job name,
/// and leaves `otel.status_code` empty so that it can be filled in later.
/// If the job carries a span context, it is attached to the span as a link.
fn make_span_for_job_request<J: TracedJob>(req: &JobRequest<J>) -> tracing::Span {
    let job_span = info_span!(
        "job.run",
        "otel.kind" = "consumer",
        "otel.status_code" = tracing::field::Empty,
        "job.id" = %req.id(),
        "job.attempts" = req.attempts(),
        "job.name" = J::NAME,
    );

    // Nothing to link when the job has no span context
    let Some(origin) = req.inner().span_context() else {
        return job_span;
    };

    job_span.add_link(origin);
    job_span
}
/// Concrete type of the tracing layer returned by [`trace_layer`].
type TraceLayerForJob<J> =
    TraceLayer<FnWrapper<fn(&JobRequest<J>) -> tracing::Span>, KV<&'static str>, KV<&'static str>>;

/// Builds the tracing layer for jobs of type `J`.
///
/// Each request gets a span made by [`make_span_for_job_request`], on which
/// `otel.status_code` is set to `OK` on response and to `ERROR` on error.
pub(crate) fn trace_layer<J: TracedJob>() -> TraceLayerForJob<J> {
    // Coerce the generic function item to a plain function pointer, so that
    // the layer type can be named
    let make_span: fn(&JobRequest<J>) -> tracing::Span = make_span_for_job_request::<J>;

    TraceLayer::new(make_span_fn(make_span))
        .on_response(KV("otel.status_code", "OK"))
        .on_error(KV("otel.status_code", "ERROR"))
}
/// Concrete type of the stack of metrics layers returned by
/// [`metrics_layer`].
type MetricsLayerForJob<J> = (
    IdentityLayer<JobRequest<J>>,
    DurationRecorderLayer<KeyValue, KeyValue, KeyValue>,
    InFlightCounterLayer<KeyValue>,
);

/// Builds the metrics layers for jobs of type `J`.
///
/// Two metrics are set up, both tagged with the job name:
/// `job.run.duration`, which is also tagged with a `success` or `error`
/// status, and `job.run.active`, which counts the in-flight jobs.
pub(crate) fn metrics_layer<J: Job>() -> MetricsLayerForJob<J> {
    let durations = DurationRecorderLayer::new("job.run.duration")
        .on_request(JOB_NAME.string(J::NAME))
        .on_response(JOB_STATUS.string("success"))
        .on_error(JOB_STATUS.string("error"));

    let in_flight =
        InFlightCounterLayer::new("job.run.active").on_request(JOB_NAME.string(J::NAME));

    (IdentityLayer::default(), durations, in_flight)
}