Implement refresh tokens

This commit is contained in:
Quentin Gliech
2021-08-27 15:27:19 +02:00
parent 9b841b2127
commit c5d4c0b83c
13 changed files with 693 additions and 224 deletions

View File

@@ -0,0 +1,16 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
DROP TRIGGER set_timestamp ON oauth2_refresh_tokens;
DROP TABLE oauth2_refresh_tokens;

View File

@@ -0,0 +1,30 @@
-- Copyright 2021 The Matrix.org Foundation C.I.C.
--
-- Licensed under the Apache License, Version 2.0 (the "License");
-- you may not use this file except in compliance with the License.
-- You may obtain a copy of the License at
--
-- http://www.apache.org/licenses/LICENSE-2.0
--
-- Unless required by applicable law or agreed to in writing, software
-- distributed under the License is distributed on an "AS IS" BASIS,
-- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-- See the License for the specific language governing permissions and
-- limitations under the License.
CREATE TABLE oauth2_refresh_tokens (
"id" BIGSERIAL PRIMARY KEY,
"oauth2_session_id" BIGINT NOT NULL REFERENCES oauth2_sessions (id) ON DELETE CASCADE,
"oauth2_access_token_id" BIGINT REFERENCES oauth2_access_tokens (id) ON DELETE SET NULL,
"token" TEXT UNIQUE NOT NULL,
"next_token_id" BIGINT REFERENCES oauth2_refresh_tokens (id),
"created_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now(),
"updated_at" TIMESTAMP WITH TIME ZONE NOT NULL DEFAULT now()
);
CREATE TRIGGER set_timestamp
BEFORE UPDATE ON oauth2_refresh_tokens
FOR EACH ROW
EXECUTE PROCEDURE trigger_set_timestamp();

View File

@@ -219,6 +219,50 @@
]
}
},
"562b0d4dcf857e99c20e9288e9c8bd46232290715c0d2459b0398a1c746cf65d": {
"query": "\n SELECT\n rt.id,\n rt.oauth2_session_id,\n rt.oauth2_access_token_id,\n os.client_id AS \"client_id!\",\n os.scope AS \"scope!\"\n FROM oauth2_refresh_tokens rt\n INNER JOIN oauth2_sessions os\n ON os.id = rt.oauth2_session_id\n WHERE rt.token = $1 AND rt.next_token_id IS NULL\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "oauth2_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "oauth2_access_token_id",
"type_info": "Int8"
},
{
"ordinal": 3,
"name": "client_id!",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "scope!",
"type_info": "Text"
}
],
"parameters": {
"Left": [
"Text"
]
},
"nullable": [
false,
false,
true,
false,
false
]
}
},
"5d1a17b2ad6153217551ae31549ad9d62cc39d2f9a4e62a7ccb60fd91e0ac685": {
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()\n ",
"describe": {
@@ -279,6 +323,76 @@
]
}
},
"73f2d928f7bf88af79a3685bd6346652b4e4454b0ce75e38343840c9765e3f27": {
"query": "\n INSERT INTO oauth2_refresh_tokens\n (oauth2_session_id, oauth2_access_token_id, token)\n VALUES\n ($1, $2, $3)\n RETURNING\n id, oauth2_session_id, oauth2_access_token_id, token, next_token_id, \n created_at, updated_at\n ",
"describe": {
"columns": [
{
"ordinal": 0,
"name": "id",
"type_info": "Int8"
},
{
"ordinal": 1,
"name": "oauth2_session_id",
"type_info": "Int8"
},
{
"ordinal": 2,
"name": "oauth2_access_token_id",
"type_info": "Int8"
},
{
"ordinal": 3,
"name": "token",
"type_info": "Text"
},
{
"ordinal": 4,
"name": "next_token_id",
"type_info": "Int8"
},
{
"ordinal": 5,
"name": "created_at",
"type_info": "Timestamptz"
},
{
"ordinal": 6,
"name": "updated_at",
"type_info": "Timestamptz"
}
],
"parameters": {
"Left": [
"Int8",
"Int8",
"Text"
]
},
"nullable": [
false,
false,
true,
false,
true,
false,
false
]
}
},
"88ac8783bd5881c42eafd9cf87a16fe6031f3153fd6a8618e689694584aeb2de": {
"query": "\n DELETE FROM oauth2_access_tokens\n WHERE id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8"
]
},
"nullable": []
}
},
"8c21b0b46e74ae5667b82bfe57706a64396683ac4dc29311424008c3f3e94136": {
"query": "\n SELECT\n oc.id,\n os.id AS \"oauth2_session_id!\",\n os.client_id AS \"client_id!\",\n os.redirect_uri,\n os.scope AS \"scope!\"\n FROM oauth2_codes oc\n INNER JOIN oauth2_sessions os\n ON os.id = oc.oauth2_session_id\n WHERE oc.code = $1\n ",
"describe": {
@@ -434,6 +548,19 @@
]
}
},
"c2c402cfe0adcafa615f14a499caba4c96ca71d9ffb163e1feb05e5d85f3462c": {
"query": "\n UPDATE oauth2_refresh_tokens\n SET next_token_id = $2\n WHERE id = $1\n ",
"describe": {
"columns": [],
"parameters": {
"Left": [
"Int8",
"Int8"
]
},
"nullable": []
}
},
"cacec823f5d4ed886854fbd62b5f5bb2def792582df58c8a047c769d34d9b190": {
"query": "\n INSERT INTO oauth2_sessions\n (user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,\n response_type, response_mode)\n VALUES\n ($1, $2, $3, $4, $5, $6, $7, $8, $9)\n RETURNING\n id, user_session_id, client_id, redirect_uri, scope, state, nonce, max_age,\n response_type, response_mode, created_at, updated_at\n ",
"describe": {

View File

@@ -52,7 +52,11 @@ use crate::{
},
handlers::views::LoginRequest,
storage::{
oauth2::{add_access_token, get_session_by_id, start_session},
oauth2::{
access_token::add_access_token,
refresh_token::add_refresh_token,
session::{get_session_by_id, start_session},
},
SessionInfo,
},
templates::{FormPostContext, Templates},
@@ -300,51 +304,58 @@ async fn step(
let redirect_uri = oauth2_session.redirect_uri().wrap_error()?;
// Check if the active session is valid
let reply =
if user_session.active && user_session.last_authd_at >= oauth2_session.max_auth_time() {
// Yep! Let's complete the auth now
let mut params = AuthorizationResponse {
state: oauth2_session.state.clone(),
..AuthorizationResponse::default()
let reply = if user_session.active
&& user_session.last_authd_at >= oauth2_session.max_auth_time()
{
// Yep! Let's complete the auth now
let mut params = AuthorizationResponse {
state: oauth2_session.state.clone(),
..AuthorizationResponse::default()
};
// Did they request an auth code?
if response_type.contains(&ResponseType::Code) {
params.code = Some(oauth2_session.fetch_code(&mut txn).await.wrap_error()?);
}
// Did they request an access token?
if response_type.contains(&ResponseType::Token) {
let ttl = Duration::minutes(5);
let (access_token, refresh_token) = {
let mut rng = thread_rng();
(
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
)
};
// Did they request an auth code?
if response_type.contains(&ResponseType::Code) {
params.code = Some(oauth2_session.fetch_code(&mut txn).await.wrap_error()?);
}
let access_token = add_access_token(&mut txn, oauth2_session_id, &access_token, ttl)
.await
.wrap_error()?;
// Did they request an access token?
if response_type.contains(&ResponseType::Token) {
let ttl = Duration::minutes(5);
let (access_token, refresh_token) = {
let mut rng = thread_rng();
(
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
)
};
add_access_token(&mut txn, oauth2_session_id, &access_token, ttl)
let refresh_token =
add_refresh_token(&mut txn, oauth2_session_id, access_token.id, &refresh_token)
.await
.wrap_error()?;
params.response = Some(
AccessTokenResponse::new(access_token)
.with_expires_in(ttl)
.with_refresh_token(refresh_token),
);
}
// Did they request an ID token?
if response_type.contains(&ResponseType::IdToken) {
todo!("id tokens are not implemented yet");
}
params.response = Some(
AccessTokenResponse::new(access_token.token)
.with_expires_in(ttl)
.with_refresh_token(refresh_token.token),
);
}
back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()?
} else {
// Ask for a reauth
// TODO: have the OAuth2 session ID in there
Box::new(see_other(Uri::from_static("/reauth")))
};
// Did they request an ID token?
if response_type.contains(&ResponseType::IdToken) {
todo!("id tokens are not implemented yet");
}
back_to_client(redirect_uri, response_mode, params, &templates).wrap_error()?
} else {
// Ask for a reauth
// TODO: have the OAuth2 session ID in there
Box::new(see_other(Uri::from_static("/reauth")))
};
txn.commit().await.wrap_error()?;
Ok(reply)

View File

@@ -15,6 +15,7 @@
use chrono::Utc;
use oauth2_types::requests::{IntrospectionRequest, IntrospectionResponse, TokenTypeHint};
use sqlx::{pool::PoolConnection, PgPool, Postgres};
use tracing::{info, warn};
use warp::{Filter, Rejection, Reply};
use crate::{
@@ -24,7 +25,7 @@ use crate::{
client::{with_client_auth, ClientAuthentication},
database::with_connection,
},
storage::oauth2::lookup_access_token,
storage::oauth2::{access_token::lookup_access_token, refresh_token::lookup_refresh_token},
tokens,
};
@@ -58,11 +59,12 @@ const INACTIVE: IntrospectionResponse = IntrospectionResponse {
async fn introspect(
mut conn: PoolConnection<Postgres>,
auth: ClientAuthentication,
_client: OAuth2ClientConfig,
client: OAuth2ClientConfig,
params: IntrospectionRequest,
) -> Result<impl Reply, Rejection> {
// Token introspection is only allowed by confidential clients
if auth.public() {
warn!(?client, "Client tried to introspect");
// TODO: have a nice error here
return Ok(warp::reply::json(&INACTIVE));
}
@@ -71,6 +73,7 @@ async fn introspect(
let token_type = tokens::check(token).wrap_error()?;
if let Some(hint) = params.token_type_hint {
if token_type != hint {
info!("Token type hint did not match");
return Ok(warp::reply::json(&INACTIVE));
}
}
@@ -82,6 +85,7 @@ async fn introspect(
// Check it is active and did not expire
if !token.active || exp < Utc::now() {
info!(?token, "Access token expired");
return Ok(warp::reply::json(&INACTIVE));
}
@@ -100,7 +104,24 @@ async fn introspect(
jti: None,
}
}
tokens::TokenType::RefreshToken => INACTIVE,
tokens::TokenType::RefreshToken => {
let token = lookup_refresh_token(&mut conn, token).await.wrap_error()?;
IntrospectionResponse {
active: true,
scope: None, // TODO: parse back scopes
client_id: Some(token.client_id),
username: None,
token_type: Some(TokenTypeHint::RefreshToken),
exp: None,
iat: None,
nbf: None,
sub: None,
aud: None,
iss: None,
jti: None,
}
}
};
Ok(warp::reply::json(&reply))

View File

@@ -30,7 +30,11 @@ use crate::{
client::{with_client_auth, ClientAuthentication},
database::with_connection,
},
storage::oauth2::{add_access_token, lookup_code},
storage::oauth2::{
access_token::{add_access_token, revoke_access_token},
authorization_code::lookup_code,
refresh_token::{add_refresh_token, lookup_refresh_token, replace_refresh_token},
},
tokens,
};
@@ -91,16 +95,24 @@ async fn authorization_code_grant(
)
};
add_access_token(&mut txn, code.oauth2_session_id, &access_token, ttl)
let access_token = add_access_token(&mut txn, code.oauth2_session_id, &access_token, ttl)
.await
.wrap_error()?;
// TODO: save the refresh token
let refresh_token = add_refresh_token(
&mut txn,
code.oauth2_session_id,
access_token.id,
&refresh_token,
)
.await
.wrap_error()?;
// TODO: generate id_token if the "openid" scope was asked
// TODO: have the scopes back here
let params = AccessTokenResponse::new(access_token)
let params = AccessTokenResponse::new(access_token.token)
.with_expires_in(ttl)
.with_refresh_token(refresh_token);
.with_refresh_token(refresh_token.token);
txn.commit().await.wrap_error()?;
@@ -108,9 +120,62 @@ async fn authorization_code_grant(
}
async fn refresh_token_grant(
_grant: &RefreshTokenGrant,
_client: &OAuth2ClientConfig,
_conn: &mut PoolConnection<Postgres>,
grant: &RefreshTokenGrant,
client: &OAuth2ClientConfig,
conn: &mut PoolConnection<Postgres>,
) -> Result<AccessTokenResponse, Rejection> {
todo!()
let mut txn = conn.begin().await.wrap_error()?;
// TODO: scope handling
let refresh_token_lookup = lookup_refresh_token(&mut txn, &grant.refresh_token)
.await
.wrap_error()?;
if client.client_id != refresh_token_lookup.client_id {
return Err(anyhow::anyhow!("invalid client")).wrap_error();
}
let ttl = Duration::minutes(5);
let (access_token, refresh_token) = {
let mut rng = thread_rng();
(
tokens::generate(&mut rng, tokens::TokenType::AccessToken),
tokens::generate(&mut rng, tokens::TokenType::RefreshToken),
)
};
let access_token = add_access_token(
&mut txn,
refresh_token_lookup.oauth2_session_id,
&access_token,
ttl,
)
.await
.wrap_error()?;
let refresh_token = add_refresh_token(
&mut txn,
refresh_token_lookup.oauth2_session_id,
access_token.id,
&refresh_token,
)
.await
.wrap_error()?;
replace_refresh_token(&mut txn, refresh_token_lookup.id, refresh_token.id)
.await
.wrap_error()?;
if let Some(access_token_id) = refresh_token_lookup.oauth2_access_token_id {
revoke_access_token(&mut txn, access_token_id)
.await
.wrap_error()?;
}
let params = AccessTokenResponse::new(access_token.token)
.with_expires_in(ttl)
.with_refresh_token(refresh_token.token);
txn.commit().await.wrap_error()?;
Ok(params)
}

View File

@@ -0,0 +1,141 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use std::convert::TryFrom;
use anyhow::Context;
use chrono::{DateTime, Duration, Utc};
use serde::Serialize;
use sqlx::{Executor, FromRow, Postgres};
#[derive(FromRow, Serialize)]
pub struct OAuth2AccessToken {
pub id: i64,
pub oauth2_session_id: i64,
pub token: String,
expires_after: i32,
created_at: DateTime<Utc>,
}
pub async fn add_access_token(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
token: &str,
expires_after: Duration,
) -> anyhow::Result<OAuth2AccessToken> {
// Checked convertion of duration to i32, maxing at i32::MAX
let expires_after = i32::try_from(expires_after.num_seconds()).unwrap_or(i32::MAX);
sqlx::query_as!(
OAuth2AccessToken,
r#"
INSERT INTO oauth2_access_tokens
(oauth2_session_id, token, expires_after)
VALUES
($1, $2, $3)
RETURNING
id, oauth2_session_id, token, expires_after, created_at
"#,
oauth2_session_id,
token,
expires_after,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 access token")
}
#[derive(Debug)]
pub struct OAuth2AccessTokenLookup {
pub active: bool,
pub username: String,
pub client_id: String,
pub scope: String,
pub created_at: DateTime<Utc>,
expires_after: i32,
}
impl OAuth2AccessTokenLookup {
pub fn exp(&self) -> DateTime<Utc> {
self.created_at + Duration::seconds(i64::from(self.expires_after))
}
}
pub async fn lookup_access_token(
executor: impl Executor<'_, Database = Postgres>,
token: &str,
) -> anyhow::Result<OAuth2AccessTokenLookup> {
sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT
u.username AS "username!",
us.active AS "active!",
os.client_id AS "client_id!",
os.scope AS "scope!",
at.created_at AS "created_at!",
at.expires_after AS "expires_after!"
FROM oauth2_access_tokens at
INNER JOIN oauth2_sessions os
ON os.id = at.oauth2_session_id
INNER JOIN user_sessions us
ON us.id = os.user_session_id
INNER JOIN users u
ON u.id = us.user_id
WHERE at.token = $1
"#,
token,
)
.fetch_one(executor)
.await
.context("could not introspect oauth2 access token")
}
pub async fn revoke_access_token(
executor: impl Executor<'_, Database = Postgres>,
id: i64,
) -> anyhow::Result<()> {
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE id = $1
"#,
id,
)
.execute(executor)
.await
.context("could not revoke access tokens")?;
if res.rows_affected() == 1 {
Ok(())
} else {
Err(anyhow::anyhow!("no row were affected when revoking token"))
}
}
pub async fn cleanup_expired(
executor: impl Executor<'_, Database = Postgres>,
) -> anyhow::Result<u64> {
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()
"#,
)
.execute(executor)
.await
.context("could not cleanup expired access tokens")?;
Ok(res.rows_affected())
}

View File

@@ -0,0 +1,90 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use oauth2_types::pkce;
use serde::Serialize;
use sqlx::{Executor, FromRow, Postgres};
#[derive(FromRow, Serialize)]
pub struct OAuth2Code {
id: i64,
oauth2_session_id: i64,
pub code: String,
code_challenge: Option<String>,
code_challenge_method: Option<i16>,
}
pub async fn add_code(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
code: &str,
code_challenge: &Option<pkce::Request>,
) -> anyhow::Result<OAuth2Code> {
let code_challenge_method = code_challenge
.as_ref()
.map(|c| c.code_challenge_method as i16);
let code_challenge = code_challenge.as_ref().map(|c| &c.code_challenge);
sqlx::query_as!(
OAuth2Code,
r#"
INSERT INTO oauth2_codes
(oauth2_session_id, code, code_challenge_method, code_challenge)
VALUES
($1, $2, $3, $4)
RETURNING
id, oauth2_session_id, code, code_challenge_method, code_challenge
"#,
oauth2_session_id,
code,
code_challenge_method,
code_challenge,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 authorization code")
}
pub struct OAuth2CodeLookup {
pub id: i64,
pub oauth2_session_id: i64,
pub client_id: String,
pub redirect_uri: String,
pub scope: String,
}
pub async fn lookup_code(
executor: impl Executor<'_, Database = Postgres>,
code: &str,
) -> anyhow::Result<OAuth2CodeLookup> {
sqlx::query_as!(
OAuth2CodeLookup,
r#"
SELECT
oc.id,
os.id AS "oauth2_session_id!",
os.client_id AS "client_id!",
os.redirect_uri,
os.scope AS "scope!"
FROM oauth2_codes oc
INNER JOIN oauth2_sessions os
ON os.id = oc.oauth2_session_id
WHERE oc.code = $1
"#,
code,
)
.fetch_one(executor)
.await
.context("could not lookup oauth2 code")
}

View File

@@ -0,0 +1,18 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
pub mod access_token;
pub mod authorization_code;
pub mod refresh_token;
pub mod session;

View File

@@ -0,0 +1,114 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
use anyhow::Context;
use chrono::{DateTime, Utc};
use sqlx::{Executor, Postgres};
#[derive(Debug)]
pub struct OAuth2RefreshToken {
pub id: i64,
oauth2_session_id: i64,
oauth2_access_token_id: Option<i64>,
pub token: String,
next_token_id: Option<i64>,
created_at: DateTime<Utc>,
updated_at: DateTime<Utc>,
}
pub async fn add_refresh_token(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
oauth2_access_token_id: i64,
token: &str,
) -> anyhow::Result<OAuth2RefreshToken> {
sqlx::query_as!(
OAuth2RefreshToken,
r#"
INSERT INTO oauth2_refresh_tokens
(oauth2_session_id, oauth2_access_token_id, token)
VALUES
($1, $2, $3)
RETURNING
id, oauth2_session_id, oauth2_access_token_id, token, next_token_id,
created_at, updated_at
"#,
oauth2_session_id,
oauth2_access_token_id,
token,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 refresh token")
}
pub struct OAuth2RefreshTokenLookup {
pub id: i64,
pub oauth2_session_id: i64,
pub oauth2_access_token_id: Option<i64>,
pub client_id: String,
pub scope: String,
}
pub async fn lookup_refresh_token(
executor: impl Executor<'_, Database = Postgres>,
token: &str,
) -> anyhow::Result<OAuth2RefreshTokenLookup> {
sqlx::query_as!(
OAuth2RefreshTokenLookup,
r#"
SELECT
rt.id,
rt.oauth2_session_id,
rt.oauth2_access_token_id,
os.client_id AS "client_id!",
os.scope AS "scope!"
FROM oauth2_refresh_tokens rt
INNER JOIN oauth2_sessions os
ON os.id = rt.oauth2_session_id
WHERE rt.token = $1 AND rt.next_token_id IS NULL
"#,
token,
)
.fetch_one(executor)
.await
.context("failed to fetch oauth2 refresh token")
}
pub async fn replace_refresh_token(
executor: impl Executor<'_, Database = Postgres>,
refresh_token_id: i64,
next_refresh_token_id: i64,
) -> anyhow::Result<()> {
let res = sqlx::query!(
r#"
UPDATE oauth2_refresh_tokens
SET next_token_id = $2
WHERE id = $1
"#,
refresh_token_id,
next_refresh_token_id
)
.execute(executor)
.await
.context("failed to update oauth2 refresh token")?;
if res.rows_affected() == 1 {
Ok(())
} else {
Err(anyhow::anyhow!(
"no row were affected when updating refresh token"
))
}
}

View File

@@ -25,7 +25,10 @@ use serde::Serialize;
use sqlx::{Executor, FromRow, Postgres};
use url::Url;
use super::{user::lookup_session, SessionInfo};
use super::{
super::{user::lookup_session, SessionInfo},
authorization_code::{add_code, OAuth2Code},
};
#[derive(FromRow, Serialize)]
pub struct OAuth2Session {
@@ -208,173 +211,3 @@ pub async fn get_code_for_session(
.await
.context("could not fetch oauth2 code")
}
#[derive(FromRow, Serialize)]
pub struct OAuth2Code {
id: i64,
oauth2_session_id: i64,
pub code: String,
code_challenge: Option<String>,
code_challenge_method: Option<i16>,
}
pub async fn add_code(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
code: &str,
code_challenge: &Option<pkce::Request>,
) -> anyhow::Result<OAuth2Code> {
let code_challenge_method = code_challenge
.as_ref()
.map(|c| c.code_challenge_method as i16);
let code_challenge = code_challenge.as_ref().map(|c| &c.code_challenge);
sqlx::query_as!(
OAuth2Code,
r#"
INSERT INTO oauth2_codes
(oauth2_session_id, code, code_challenge_method, code_challenge)
VALUES
($1, $2, $3, $4)
RETURNING
id, oauth2_session_id, code, code_challenge_method, code_challenge
"#,
oauth2_session_id,
code,
code_challenge_method,
code_challenge,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 authorization code")
}
#[derive(FromRow, Serialize)]
pub struct OAuth2AccessToken {
id: i64,
oauth2_session_id: i64,
token: String,
expires_after: i32,
created_at: DateTime<Utc>,
}
pub async fn add_access_token(
executor: impl Executor<'_, Database = Postgres>,
oauth2_session_id: i64,
token: &str,
expires_after: Duration,
) -> anyhow::Result<OAuth2AccessToken> {
// Checked convertion of duration to i32, maxing at i32::MAX
let expires_after = i32::try_from(expires_after.num_seconds()).unwrap_or(i32::MAX);
sqlx::query_as!(
OAuth2AccessToken,
r#"
INSERT INTO oauth2_access_tokens
(oauth2_session_id, token, expires_after)
VALUES
($1, $2, $3)
RETURNING
id, oauth2_session_id, token, expires_after, created_at
"#,
oauth2_session_id,
token,
expires_after,
)
.fetch_one(executor)
.await
.context("could not insert oauth2 access token")
}
pub struct OAuth2AccessTokenLookup {
pub active: bool,
pub username: String,
pub client_id: String,
pub scope: String,
pub created_at: DateTime<Utc>,
expires_after: i32,
}
impl OAuth2AccessTokenLookup {
pub fn exp(&self) -> DateTime<Utc> {
self.created_at + Duration::seconds(i64::from(self.expires_after))
}
}
pub async fn lookup_access_token(
executor: impl Executor<'_, Database = Postgres>,
token: &str,
) -> anyhow::Result<OAuth2AccessTokenLookup> {
sqlx::query_as!(
OAuth2AccessTokenLookup,
r#"
SELECT
u.username AS "username!",
us.active AS "active!",
os.client_id AS "client_id!",
os.scope AS "scope!",
at.created_at AS "created_at!",
at.expires_after AS "expires_after!"
FROM oauth2_access_tokens at
INNER JOIN oauth2_sessions os
ON os.id = at.oauth2_session_id
INNER JOIN user_sessions us
ON us.id = os.user_session_id
INNER JOIN users u
ON u.id = us.user_id
WHERE at.token = $1
"#,
token,
)
.fetch_one(executor)
.await
.context("could not introspect oauth2 access token")
}
pub struct OAuth2CodeLookup {
pub id: i64,
pub oauth2_session_id: i64,
pub client_id: String,
pub redirect_uri: String,
pub scope: String,
}
pub async fn lookup_code(
executor: impl Executor<'_, Database = Postgres>,
code: &str,
) -> anyhow::Result<OAuth2CodeLookup> {
sqlx::query_as!(
OAuth2CodeLookup,
r#"
SELECT
oc.id,
os.id AS "oauth2_session_id!",
os.client_id AS "client_id!",
os.redirect_uri,
os.scope AS "scope!"
FROM oauth2_codes oc
INNER JOIN oauth2_sessions os
ON os.id = oc.oauth2_session_id
WHERE oc.code = $1
"#,
code,
)
.fetch_one(executor)
.await
.context("could not lookup oauth2 code")
}
pub async fn cleanup_expired(
executor: impl Executor<'_, Database = Postgres>,
) -> anyhow::Result<u64> {
let res = sqlx::query!(
r#"
DELETE FROM oauth2_access_tokens
WHERE created_at + (expires_after * INTERVAL '1 second') + INTERVAL '15 minutes' < now()
"#,
)
.execute(executor)
.await
.context("could not cleanup expired access tokens")?;
Ok(res.rows_affected())
}

View File

@@ -23,7 +23,7 @@ struct CleanupExpired(Pool<Postgres>);
#[async_trait::async_trait]
impl Task for CleanupExpired {
async fn run(&self) {
let res = crate::storage::oauth2::cleanup_expired(&self.0).await;
let res = crate::storage::oauth2::access_token::cleanup_expired(&self.0).await;
match res {
Ok(0) => {
debug!("no token to clean up");

View File

@@ -183,14 +183,16 @@ pub enum TokenType {
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AuthorizationCodeGrant {
pub code: String,
#[serde(default)]
pub redirect_uri: Option<Url>,
}
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct RefreshTokenGrant {
refresh_token: String,
pub refresh_token: String,
#[serde(default)]
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
scope: Option<HashSet<String>>,
}
@@ -198,6 +200,7 @@ pub struct RefreshTokenGrant {
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct ClientCredentialsGrant {
#[serde(default)]
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
scope: Option<HashSet<String>>,
}