allow completing an oauth2 session after login
This commit is contained in:
@@ -16,6 +16,7 @@ CREATE TABLE oauth2_sessions (
|
||||
"id" BIGSERIAL PRIMARY KEY,
|
||||
"user_session_id" BIGINT REFERENCES user_sessions (id) ON DELETE CASCADE,
|
||||
"client_id" TEXT NOT NULL,
|
||||
"redirect_uri" TEXT NOT NULL,
|
||||
"scope" TEXT NOT NULL,
|
||||
"state" TEXT,
|
||||
"nonce" TEXT,
|
||||
|
||||
@@ -73,6 +73,26 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"17729fd0354a84e04bfcd525db6575ed2ba75dd730bea3f2be964f4b347dd484": {
|
||||
"query": "\n SELECT code\n FROM oauth2_codes\n WHERE oauth2_session_id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "code",
|
||||
"type_info": "Text"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"35bedaa6fdf7ac91d54b458b4637f2182c2f82be3e2f80cd2db934ee279a7f2a": {
|
||||
"query": "\n SELECT id, username\n FROM users\n WHERE id = $1\n ",
|
||||
"describe": {
|
||||
@@ -99,93 +119,6 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"3f8aaca9f29ded15acf4f0056a789af643ab62816a2fb598dcc9af4fff2841f0": {
|
||||
"query": "\n INSERT INTO oauth2_sessions \n (user_session_id, client_id, scope, state, nonce, max_age, response_type, response_mode)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8)\n RETURNING\n id, user_session_id, client_id, scope, state, nonce, max_age, \n response_type, response_mode, created_at, updated_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "client_id",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "scope",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "state",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "nonce",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "max_age",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "response_type",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "response_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "updated_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Int4",
|
||||
"Text",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"4f925a277d73df779360f81e0cf5d7983b50ebe744f461559dd561b7e36c20d4": {
|
||||
"query": "\n SELECT\n s.id,\n u.id as user_id,\n u.username,\n s.active,\n s.created_at,\n a.created_at as \"last_authd_at?\"\n FROM user_sessions s\n INNER JOIN users u \n ON s.user_id = u.id\n LEFT JOIN user_session_authentications a\n ON a.session_id = s.id\n WHERE s.id = $1 AND s.active\n ORDER BY a.created_at DESC\n LIMIT 1\n ",
|
||||
"describe": {
|
||||
@@ -306,6 +239,100 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"a051f542df7d3f80f5dc6dd6f04d49a462c64e4ce9146d90069d16ec9b61084b": {
|
||||
"query": "\n INSERT INTO oauth2_sessions \n (user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,\n response_type, response_mode)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n RETURNING\n id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age, \n response_type, response_mode, created_at, updated_at\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "client_id",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "redirect_uri",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "scope",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "state",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "nonce",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "max_age",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "response_type",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "response_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"name": "updated_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Text",
|
||||
"Int4",
|
||||
"Text",
|
||||
"Text"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"a09dfe1019110f2ec6eba0d35bafa467ab4b7980dd8b556826f03863f8edb0ab": {
|
||||
"query": "UPDATE user_sessions SET active = FALSE WHERE id = $1",
|
||||
"describe": {
|
||||
@@ -338,6 +365,19 @@
|
||||
]
|
||||
}
|
||||
},
|
||||
"a6eb935107d060dd01bf9824ceff87b9ff5492b58cefef002a49f444d3a3daa1": {
|
||||
"query": "UPDATE oauth2_sessions SET user_session_id = $1 WHERE id = $2",
|
||||
"describe": {
|
||||
"columns": [],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8",
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": []
|
||||
}
|
||||
},
|
||||
"f9a09ff53b6f221649f4f050e3d5ade114f852ddf50a78610a6c0ef0689af681": {
|
||||
"query": "\n INSERT INTO users (username, hashed_password)\n VALUES ($1, $2)\n RETURNING id\n ",
|
||||
"describe": {
|
||||
@@ -358,5 +398,91 @@
|
||||
false
|
||||
]
|
||||
}
|
||||
},
|
||||
"ff515ebb80ba4af1948472f5c7120a03e25b1ebe42151b8a2036bfbb042f17f6": {
|
||||
"query": "\n SELECT\n id, user_session_id, client_id, redirect_uri, scope, state, nonce,\n max_age, response_type, response_mode, created_at, updated_at\n FROM oauth2_sessions\n WHERE id = $1\n ",
|
||||
"describe": {
|
||||
"columns": [
|
||||
{
|
||||
"ordinal": 0,
|
||||
"name": "id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 1,
|
||||
"name": "user_session_id",
|
||||
"type_info": "Int8"
|
||||
},
|
||||
{
|
||||
"ordinal": 2,
|
||||
"name": "client_id",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 3,
|
||||
"name": "redirect_uri",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 4,
|
||||
"name": "scope",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 5,
|
||||
"name": "state",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 6,
|
||||
"name": "nonce",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 7,
|
||||
"name": "max_age",
|
||||
"type_info": "Int4"
|
||||
},
|
||||
{
|
||||
"ordinal": 8,
|
||||
"name": "response_type",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 9,
|
||||
"name": "response_mode",
|
||||
"type_info": "Text"
|
||||
},
|
||||
{
|
||||
"ordinal": 10,
|
||||
"name": "created_at",
|
||||
"type_info": "Timestamptz"
|
||||
},
|
||||
{
|
||||
"ordinal": 11,
|
||||
"name": "updated_at",
|
||||
"type_info": "Timestamptz"
|
||||
}
|
||||
],
|
||||
"parameters": {
|
||||
"Left": [
|
||||
"Int8"
|
||||
]
|
||||
},
|
||||
"nullable": [
|
||||
false,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
true,
|
||||
true,
|
||||
true,
|
||||
false,
|
||||
false,
|
||||
false,
|
||||
false
|
||||
]
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,10 +12,17 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::{HashMap, HashSet};
|
||||
use std::{
|
||||
collections::{HashMap, HashSet},
|
||||
convert::TryFrom,
|
||||
};
|
||||
|
||||
use data_encoding::BASE64URL_NOPAD;
|
||||
use hyper::{header::LOCATION, StatusCode};
|
||||
use hyper::{
|
||||
header::LOCATION,
|
||||
http::uri::{Parts, PathAndQuery, Uri},
|
||||
StatusCode,
|
||||
};
|
||||
use itertools::Itertools;
|
||||
use oauth2_types::{
|
||||
pkce,
|
||||
@@ -28,6 +35,7 @@ use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use url::Url;
|
||||
use warp::{
|
||||
redirect::see_other,
|
||||
reply::{html, with_header},
|
||||
Filter, Rejection, Reply,
|
||||
};
|
||||
@@ -35,8 +43,15 @@ use warp::{
|
||||
use crate::{
|
||||
config::{CookiesConfig, OAuth2ClientConfig, OAuth2Config},
|
||||
errors::WrapError,
|
||||
filters::{session::with_optional_session, with_pool, with_templates},
|
||||
storage::{oauth2::start_session, SessionInfo},
|
||||
filters::{
|
||||
session::{with_optional_session, with_session},
|
||||
with_pool, with_templates,
|
||||
},
|
||||
handlers::views::LoginRequest,
|
||||
storage::{
|
||||
oauth2::{get_session_by_id, start_session},
|
||||
SessionInfo,
|
||||
},
|
||||
templates::{FormPostContext, Templates},
|
||||
};
|
||||
|
||||
@@ -144,14 +159,24 @@ pub fn filter(
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let clients = oauth2_config.clients.clone();
|
||||
warp::get()
|
||||
let authorize = warp::get()
|
||||
.and(warp::path!("oauth2" / "authorize"))
|
||||
.map(move || clients.clone())
|
||||
.and(warp::query())
|
||||
.and(with_optional_session(pool, cookies_config))
|
||||
.and(with_pool(pool))
|
||||
.and(with_templates(templates))
|
||||
.and_then(get)
|
||||
.and_then(get);
|
||||
|
||||
let step = warp::get()
|
||||
.and(warp::path!("oauth2" / "authorize" / "step"))
|
||||
.and(warp::query().map(|s: StepRequest| s.id))
|
||||
.and(with_session(pool, cookies_config))
|
||||
.and(with_pool(pool))
|
||||
.and(with_templates(templates))
|
||||
.and_then(step);
|
||||
|
||||
authorize.or(step)
|
||||
}
|
||||
|
||||
async fn get(
|
||||
@@ -190,6 +215,7 @@ async fn get(
|
||||
&mut txn,
|
||||
maybe_session_id,
|
||||
&client.client_id,
|
||||
redirect_uri,
|
||||
&scope,
|
||||
params.auth.state.as_deref(),
|
||||
params.auth.nonce.as_deref(),
|
||||
@@ -200,24 +226,81 @@ async fn get(
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let code = if response_type.contains(&ResponseType::Code) {
|
||||
// Generate the code at this stage, since we have the PKCE params ready
|
||||
if response_type.contains(&ResponseType::Code) {
|
||||
// 192bit random bytes encoded in base64, which gives a 32 character code
|
||||
let code: [u8; 24] = rand::random();
|
||||
let code = BASE64URL_NOPAD.encode(&code);
|
||||
Some(
|
||||
oauth2_session
|
||||
.add_code(&mut txn, &code, ¶ms.pkce)
|
||||
.await
|
||||
.wrap_error()?,
|
||||
)
|
||||
} else {
|
||||
None
|
||||
oauth2_session
|
||||
.add_code(&mut txn, &code, ¶ms.pkce)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
};
|
||||
|
||||
// Do we have a user in this session, with a last authentication time that
|
||||
// matches the requirement?
|
||||
let user_session = oauth2_session.fetch_session(&mut txn).await.wrap_error()?;
|
||||
txn.commit().await.wrap_error()?;
|
||||
|
||||
if let Some(user_session) = user_session {
|
||||
step(oauth2_session.id, user_session, pool, templates).await
|
||||
} else {
|
||||
let next = StepRequest::new(oauth2_session.id)
|
||||
.build_uri()
|
||||
.wrap_error()?
|
||||
.to_string();
|
||||
|
||||
let destination = LoginRequest::new(Some(next)).build_uri().wrap_error()?;
|
||||
Ok(Box::new(see_other(destination)))
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize, Serialize)]
|
||||
struct StepRequest {
|
||||
id: i64,
|
||||
}
|
||||
|
||||
impl StepRequest {
|
||||
fn new(id: i64) -> Self {
|
||||
Self { id }
|
||||
}
|
||||
|
||||
fn build_uri(&self) -> anyhow::Result<Uri> {
|
||||
let qs = serde_urlencoded::to_string(self)?;
|
||||
let path_and_query = PathAndQuery::try_from(format!("/oauth2/authorize/step?{}", qs))?;
|
||||
let uri = Uri::from_parts({
|
||||
let mut parts = Parts::default();
|
||||
parts.path_and_query = Some(path_and_query);
|
||||
parts
|
||||
})?;
|
||||
Ok(uri)
|
||||
}
|
||||
}
|
||||
|
||||
async fn step(
|
||||
oauth2_session_id: i64,
|
||||
user_session: SessionInfo,
|
||||
pool: PgPool,
|
||||
templates: Templates,
|
||||
) -> Result<Box<dyn Reply>, Rejection> {
|
||||
// Start a DB transaction
|
||||
let mut txn = pool.begin().await.wrap_error()?;
|
||||
|
||||
let mut oauth2_session = get_session_by_id(&mut txn, oauth2_session_id)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let user_session = oauth2_session
|
||||
.match_or_set_session(&mut txn, user_session)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
let response_mode = oauth2_session.response_mode().wrap_error()?;
|
||||
let response_type = oauth2_session.response_type().wrap_error()?;
|
||||
let redirect_uri = oauth2_session.redirect_uri().wrap_error()?;
|
||||
|
||||
let reply =
|
||||
// Check if the active session is valid
|
||||
if user_session.active && user_session.last_authd_at >= oauth2_session.max_auth_time() {
|
||||
// Yep! Let's complete the auth now
|
||||
let mut params = AuthorizationResponse {
|
||||
@@ -226,8 +309,8 @@ async fn get(
|
||||
};
|
||||
|
||||
// Did they request an auth code?
|
||||
if let Some(ref code) = code {
|
||||
params.code = Some(code.code.clone());
|
||||
if response_type.contains(&ResponseType::Code) {
|
||||
params.code = Some(oauth2_session.fetch_code(&mut txn).await.wrap_error()?);
|
||||
}
|
||||
|
||||
// Did they request an access token?
|
||||
@@ -243,20 +326,13 @@ async fn get(
|
||||
todo!("id tokens are not implemented yet");
|
||||
}
|
||||
|
||||
txn.commit().await.wrap_error()?;
|
||||
let reply = back_to_client(redirect_uri.clone(), response_mode, params, &templates)
|
||||
.wrap_error()?;
|
||||
return Ok(reply);
|
||||
}
|
||||
// TODO: show reauth form
|
||||
}
|
||||
|
||||
// TODO: show login form
|
||||
back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()?
|
||||
} else {
|
||||
// Ask for a reauth
|
||||
// TODO: have the OAuth2 session ID in there
|
||||
Box::new(see_other(Uri::from_static("/reauth")))
|
||||
};
|
||||
|
||||
txn.commit().await.wrap_error()?;
|
||||
Ok(Box::new(warp::reply::json(&serde_json::json!({
|
||||
"session": oauth2_session,
|
||||
"code": code,
|
||||
"redirect_uri": redirect_uri,
|
||||
}))))
|
||||
Ok(reply)
|
||||
}
|
||||
|
||||
@@ -12,9 +12,12 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use serde::Deserialize;
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use hyper::http::uri::{Parts, PathAndQuery, Uri};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use sqlx::PgPool;
|
||||
use warp::{hyper::Uri, reply::html, wrap_fn, Filter, Rejection, Reply};
|
||||
use warp::{reply::html, wrap_fn, Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::{CookiesConfig, CsrfConfig},
|
||||
@@ -28,6 +31,28 @@ use crate::{
|
||||
templates::{TemplateContext, Templates},
|
||||
};
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct LoginRequest {
|
||||
next: Option<String>,
|
||||
}
|
||||
|
||||
impl LoginRequest {
|
||||
pub fn new(next: Option<String>) -> Self {
|
||||
Self { next }
|
||||
}
|
||||
|
||||
pub fn build_uri(&self) -> anyhow::Result<Uri> {
|
||||
let qs = serde_urlencoded::to_string(self)?;
|
||||
let path_and_query = PathAndQuery::try_from(format!("/login?{}", qs))?;
|
||||
let uri = Uri::from_parts({
|
||||
let mut parts = Parts::default();
|
||||
parts.path_and_query = Some(path_and_query);
|
||||
parts
|
||||
})?;
|
||||
Ok(uri)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
struct LoginForm {
|
||||
username: String,
|
||||
@@ -50,6 +75,7 @@ pub(super) fn filter(
|
||||
let post = warp::post()
|
||||
.and(with_pool(pool))
|
||||
.and(protected_form(cookies_config))
|
||||
.and(warp::query())
|
||||
.and_then(post)
|
||||
.untuple_one()
|
||||
.with(wrap_fn(save_session(cookies_config)));
|
||||
@@ -68,10 +94,28 @@ async fn get(
|
||||
Ok((csrf_token, html(content)))
|
||||
}
|
||||
|
||||
async fn post(db: PgPool, form: LoginForm) -> Result<(SessionInfo, impl Reply), Rejection> {
|
||||
async fn post(
|
||||
db: PgPool,
|
||||
form: LoginForm,
|
||||
query: LoginRequest,
|
||||
) -> Result<(SessionInfo, impl Reply), Rejection> {
|
||||
let session_info = login(&db, &form.username, &form.password)
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
Ok((session_info, warp::redirect(Uri::from_static("/"))))
|
||||
let uri: Uri = Uri::from_parts({
|
||||
let mut parts = Parts::default();
|
||||
parts.path_and_query = Some(
|
||||
query
|
||||
.next
|
||||
.map(warp::http::uri::PathAndQuery::try_from)
|
||||
.transpose()
|
||||
.wrap_error()?
|
||||
.unwrap_or_else(|| PathAndQuery::from_static("/")),
|
||||
);
|
||||
parts
|
||||
})
|
||||
.wrap_error()?;
|
||||
|
||||
Ok((session_info, warp::redirect(uri)))
|
||||
}
|
||||
|
||||
@@ -25,6 +25,7 @@ mod login;
|
||||
mod logout;
|
||||
mod reauth;
|
||||
|
||||
pub use self::login::LoginRequest;
|
||||
use self::{
|
||||
index::filter as index, login::filter as login, logout::filter as logout,
|
||||
reauth::filter as reauth,
|
||||
|
||||
@@ -12,7 +12,7 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::{collections::HashSet, convert::TryFrom};
|
||||
use std::{collections::HashSet, convert::TryFrom, str::FromStr};
|
||||
|
||||
use anyhow::Context;
|
||||
use chrono::{DateTime, Duration, Utc};
|
||||
@@ -23,14 +23,16 @@ use oauth2_types::{
|
||||
};
|
||||
use serde::Serialize;
|
||||
use sqlx::{Executor, FromRow, Postgres};
|
||||
use url::Url;
|
||||
|
||||
use super::{user::lookup_session, SessionInfo};
|
||||
|
||||
#[derive(FromRow, Serialize)]
|
||||
pub struct OAuth2Session {
|
||||
id: i64,
|
||||
pub id: i64,
|
||||
user_session_id: Option<i64>,
|
||||
client_id: String,
|
||||
redirect_uri: String,
|
||||
scope: String,
|
||||
pub state: Option<String>,
|
||||
nonce: Option<String>,
|
||||
@@ -52,9 +54,9 @@ impl OAuth2Session {
|
||||
add_code(executor, self.id, code, code_challenge).await
|
||||
}
|
||||
|
||||
pub async fn fetch_session<'e>(
|
||||
pub async fn fetch_session(
|
||||
&self,
|
||||
executor: impl Executor<'e, Database = Postgres>,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> anyhow::Result<Option<SessionInfo>> {
|
||||
match self.user_session_id {
|
||||
Some(id) => {
|
||||
@@ -65,11 +67,61 @@ impl OAuth2Session {
|
||||
}
|
||||
}
|
||||
|
||||
pub async fn fetch_code(
|
||||
&self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
) -> anyhow::Result<String> {
|
||||
get_code_for_session(executor, self.id).await
|
||||
}
|
||||
|
||||
pub async fn match_or_set_session(
|
||||
&mut self,
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
session: SessionInfo,
|
||||
) -> anyhow::Result<SessionInfo> {
|
||||
match self.user_session_id {
|
||||
Some(id) if id == session.key() => Ok(session),
|
||||
Some(id) => Err(anyhow::anyhow!(
|
||||
"session mismatch, expected {}, got {}",
|
||||
id,
|
||||
session.key()
|
||||
)),
|
||||
None => {
|
||||
sqlx::query!(
|
||||
"UPDATE oauth2_sessions SET user_session_id = $1 WHERE id = $2",
|
||||
session.key(),
|
||||
self.id,
|
||||
)
|
||||
.execute(executor)
|
||||
.await
|
||||
.context("could not update oauth2 session")?;
|
||||
Ok(session)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub fn max_auth_time(&self) -> Option<DateTime<Utc>> {
|
||||
self.max_age
|
||||
.map(|d| Duration::seconds(i64::from(d)))
|
||||
.map(|d| self.created_at - d)
|
||||
}
|
||||
|
||||
pub fn response_type(&self) -> anyhow::Result<HashSet<ResponseType>> {
|
||||
self.response_type
|
||||
.split(' ')
|
||||
.map(|s| {
|
||||
ResponseType::from_str(s).with_context(|| format!("invalid response type {}", s))
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub fn response_mode(&self) -> anyhow::Result<ResponseMode> {
|
||||
self.response_mode.parse().context("invalid response mode")
|
||||
}
|
||||
|
||||
pub fn redirect_uri(&self) -> anyhow::Result<Url> {
|
||||
self.redirect_uri.parse().context("invalid redirect uri")
|
||||
}
|
||||
}
|
||||
|
||||
#[allow(clippy::too_many_arguments)]
|
||||
@@ -77,6 +129,7 @@ pub async fn start_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
optional_session_id: Option<i64>,
|
||||
client_id: &str,
|
||||
redirect_uri: &Url,
|
||||
scope: &str,
|
||||
state: Option<&str>,
|
||||
nonce: Option<&str>,
|
||||
@@ -96,15 +149,17 @@ pub async fn start_session(
|
||||
OAuth2Session,
|
||||
r#"
|
||||
INSERT INTO oauth2_sessions
|
||||
(user_session_id, client_id, scope, state, nonce, max_age, response_type, response_mode)
|
||||
(user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,
|
||||
response_type, response_mode)
|
||||
VALUES
|
||||
($1, $2, $3, $4, $5, $6, $7, $8)
|
||||
($1, $2, $3, $4, $5, $6, $7, $8, $9)
|
||||
RETURNING
|
||||
id, user_session_id, client_id, scope, state, nonce, max_age,
|
||||
id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,
|
||||
response_type, response_mode, created_at, updated_at
|
||||
"#,
|
||||
optional_session_id,
|
||||
client_id,
|
||||
redirect_uri.as_str(),
|
||||
scope,
|
||||
state,
|
||||
nonce,
|
||||
@@ -117,6 +172,43 @@ pub async fn start_session(
|
||||
.context("could not insert oauth2 session")
|
||||
}
|
||||
|
||||
pub async fn get_session_by_id(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
oauth2_session_id: i64,
|
||||
) -> anyhow::Result<OAuth2Session> {
|
||||
sqlx::query_as!(
|
||||
OAuth2Session,
|
||||
r#"
|
||||
SELECT
|
||||
id, user_session_id, client_id, redirect_uri, scope, state, nonce,
|
||||
max_age, response_type, response_mode, created_at, updated_at
|
||||
FROM oauth2_sessions
|
||||
WHERE id = $1
|
||||
"#,
|
||||
oauth2_session_id
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not fetch oauth2 session")
|
||||
}
|
||||
|
||||
pub async fn get_code_for_session(
|
||||
executor: impl Executor<'_, Database = Postgres>,
|
||||
oauth2_session_id: i64,
|
||||
) -> anyhow::Result<String> {
|
||||
sqlx::query_scalar!(
|
||||
r#"
|
||||
SELECT code
|
||||
FROM oauth2_codes
|
||||
WHERE oauth2_session_id = $1
|
||||
"#,
|
||||
oauth2_session_id
|
||||
)
|
||||
.fetch_one(executor)
|
||||
.await
|
||||
.context("could not fetch oauth2 code")
|
||||
}
|
||||
|
||||
#[derive(FromRow, Serialize)]
|
||||
pub struct OAuth2Code {
|
||||
id: i64,
|
||||
|
||||
Reference in New Issue
Block a user