Better error handling in cookies, session and csrf filters
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::Infallible;
|
||||
use std::{convert::Infallible, marker::PhantomData};
|
||||
|
||||
use chacha20poly1305::{
|
||||
aead::{generic_array::GenericArray, Aead, NewAead},
|
||||
@@ -22,11 +22,40 @@ use cookie::Cookie;
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use headers::{Header, HeaderValue, SetCookie};
|
||||
use serde::{de::DeserializeOwned, Deserialize, Serialize};
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
use thiserror::Error;
|
||||
use warp::{reject::Reject, Filter, Rejection, Reply};
|
||||
|
||||
use super::headers::{typed_header, WithTypedHeader};
|
||||
use crate::{config::CookiesConfig, errors::WrapError};
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
struct CookieDecryptionError<T: EncryptableCookieValue>(#[source] anyhow::Error, PhantomData<T>);
|
||||
|
||||
impl<T> Reject for CookieDecryptionError<T> where
|
||||
T: EncryptableCookieValue + Send + Sync + std::fmt::Debug + 'static
|
||||
{
|
||||
}
|
||||
|
||||
impl<T: EncryptableCookieValue> From<anyhow::Error> for CookieDecryptionError<T> {
|
||||
fn from(e: anyhow::Error) -> Self {
|
||||
Self(e, PhantomData)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: EncryptableCookieValue> std::fmt::Display for CookieDecryptionError<T> {
|
||||
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
write!(f, "failed to decrypt cookie {}", T::cookie_key())
|
||||
}
|
||||
}
|
||||
|
||||
fn decryption_error<T>(e: anyhow::Error) -> Rejection
|
||||
where
|
||||
T: EncryptableCookieValue + Send + Sync + std::fmt::Debug + 'static,
|
||||
{
|
||||
let e: CookieDecryptionError<T> = e.into();
|
||||
warp::reject::custom(e)
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
struct EncryptedCookie {
|
||||
nonce: [u8; 12],
|
||||
@@ -69,6 +98,7 @@ impl EncryptedCookie {
|
||||
}
|
||||
}
|
||||
|
||||
/// Extract an optional encrypted cookie
|
||||
#[must_use]
|
||||
pub fn maybe_encrypted<T>(
|
||||
options: &CookiesConfig,
|
||||
@@ -76,14 +106,21 @@ pub fn maybe_encrypted<T>(
|
||||
where
|
||||
T: DeserializeOwned + EncryptableCookieValue + Send + 'static,
|
||||
{
|
||||
let secret = options.secret;
|
||||
warp::cookie::optional(T::cookie_key()).map(move |maybe_value: Option<String>| {
|
||||
maybe_value
|
||||
.and_then(|value| EncryptedCookie::from_cookie_value(&value).ok())
|
||||
.and_then(|encrypted| encrypted.decrypt(&secret).ok())
|
||||
})
|
||||
encrypted(options).map(Some).recover(recover::<T>).unify()
|
||||
}
|
||||
|
||||
async fn recover<T>(_rejection: Rejection) -> Result<Option<T>, Infallible> {
|
||||
// We could actually look for MissingCookie and CookieDecryptionError
|
||||
// rejections, but nothing else should happen here anyway
|
||||
Ok(None)
|
||||
}
|
||||
|
||||
/// Extract an encrypted cookie
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with either a [`warp::reject::MissingCookie`] or a
|
||||
/// [`CookieDecryptionError`]
|
||||
#[must_use]
|
||||
pub fn encrypted<T>(
|
||||
options: &CookiesConfig,
|
||||
@@ -93,8 +130,9 @@ where
|
||||
{
|
||||
let secret = options.secret;
|
||||
warp::cookie::cookie(T::cookie_key()).and_then(move |value: String| async move {
|
||||
let encrypted = EncryptedCookie::from_cookie_value(&value).wrap_error()?;
|
||||
let decrypted = encrypted.decrypt(&secret).wrap_error()?;
|
||||
let encrypted =
|
||||
EncryptedCookie::from_cookie_value(&value).map_err(decryption_error::<T>)?;
|
||||
let decrypted = encrypted.decrypt(&secret).map_err(decryption_error::<T>)?;
|
||||
Ok::<_, Rejection>(decrypted)
|
||||
})
|
||||
}
|
||||
@@ -109,7 +147,7 @@ pub fn with_cookie_saver(
|
||||
}
|
||||
|
||||
/// A cookie that can be encrypted with a well-known cookie key
|
||||
pub trait EncryptableCookieValue {
|
||||
pub trait EncryptableCookieValue: Send + Sync + std::fmt::Debug {
|
||||
fn cookie_key() -> &'static str;
|
||||
}
|
||||
|
||||
|
||||
@@ -40,7 +40,7 @@ pub enum CsrfError {
|
||||
impl Reject for CsrfError {}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct CsrfToken {
|
||||
#[serde_as(as = "TimestampSeconds<i64>")]
|
||||
expiration: DateTime<Utc>,
|
||||
@@ -113,8 +113,7 @@ impl<T> CsrfForm<T> {
|
||||
}
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn csrf_token(
|
||||
fn csrf_token(
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (CsrfToken,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
super::cookies::encrypted(cookies_config).and_then(move |token: CsrfToken| async move {
|
||||
@@ -123,6 +122,11 @@ pub fn csrf_token(
|
||||
})
|
||||
}
|
||||
|
||||
/// Extract an up-to-date CSRF token to include in forms
|
||||
///
|
||||
/// Routes using this should not forget to reply the updated CSRF cookie using
|
||||
/// an [`super::cookies::EncryptedCookieSaver`] obtained with
|
||||
/// [`super::cookies::with_cookie_saver`]
|
||||
#[must_use]
|
||||
pub fn updated_csrf_token(
|
||||
cookies_config: &CookiesConfig,
|
||||
@@ -147,6 +151,19 @@ pub fn updated_csrf_token(
|
||||
)
|
||||
}
|
||||
|
||||
/// Extract values from a CSRF-protected form
|
||||
///
|
||||
/// # Rejections
|
||||
///
|
||||
/// This can reject with:
|
||||
///
|
||||
/// - [`warp::filters::body::BodyDeserializeError`] if the overall form failed
|
||||
/// to decode
|
||||
/// - [`CsrfError`] if the CSRF token was invalid or expired
|
||||
/// - [`warp::reject::MissingCookie`] if the CSRF cookie was missing
|
||||
/// - [`super::cookies::CookieDecryptionError`] if the cookie failed to decrypt
|
||||
///
|
||||
/// TODO: we might want to unify the last three rejections in one
|
||||
#[must_use]
|
||||
pub fn protected_form<T>(
|
||||
cookies_config: &CookiesConfig,
|
||||
|
||||
@@ -26,7 +26,7 @@ use crate::{
|
||||
storage::{lookup_active_session, SessionInfo},
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
#[derive(Serialize, Deserialize, Debug)]
|
||||
pub struct SessionCookie {
|
||||
current: i64,
|
||||
}
|
||||
@@ -43,7 +43,8 @@ impl SessionCookie {
|
||||
&self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> anyhow::Result<SessionInfo> {
|
||||
lookup_active_session(executor, self.current).await
|
||||
let res = lookup_active_session(executor, self.current).await?;
|
||||
Ok(res)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -138,11 +138,24 @@ pub async fn login(
|
||||
Ok(session)
|
||||
}
|
||||
|
||||
#[derive(Debug, Error)]
|
||||
#[error("could not fetch session")]
|
||||
pub struct ActiveSessionLookupError(#[from] sqlx::Error);
|
||||
|
||||
/*
|
||||
impl ActiveSessionLookupError {
|
||||
#[must_use]
|
||||
pub fn not_found(&self) -> bool {
|
||||
matches!(self.0, sqlx::Error::RowNotFound)
|
||||
}
|
||||
}
|
||||
*/
|
||||
|
||||
pub async fn lookup_active_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
id: i64,
|
||||
) -> anyhow::Result<SessionInfo> {
|
||||
sqlx::query_as!(
|
||||
) -> Result<SessionInfo, ActiveSessionLookupError> {
|
||||
let res = sqlx::query_as!(
|
||||
SessionInfo,
|
||||
r#"
|
||||
SELECT
|
||||
@@ -164,8 +177,9 @@ pub async fn lookup_active_session(
|
||||
id,
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not fetch session")
|
||||
.await?;
|
||||
|
||||
Ok(res)
|
||||
}
|
||||
|
||||
pub async fn lookup_session(
|
||||
|
||||
Reference in New Issue
Block a user