Basic id_token signing
This commit is contained in:
711
Cargo.lock
generated
711
Cargo.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -7,45 +7,54 @@ license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
# Async runtime
|
||||
tokio = { version = "1.10.0", features = ["full"] }
|
||||
tokio = { version = "1.11.0", features = ["full"] }
|
||||
async-trait = "0.1.51"
|
||||
tokio-stream = "0.1.7"
|
||||
futures-util = "0.3.16"
|
||||
futures-util = "0.3.17"
|
||||
|
||||
# Logging and tracing
|
||||
tracing = "0.1.26"
|
||||
tracing-subscriber = "0.2.19"
|
||||
tracing-subscriber = "0.2.20"
|
||||
|
||||
# Error management
|
||||
thiserror = "1.0.26"
|
||||
anyhow = "1.0.42"
|
||||
thiserror = "1.0.29"
|
||||
anyhow = "1.0.43"
|
||||
|
||||
# Web server
|
||||
warp = "0.3.1"
|
||||
tower = { version = "0.4.8", features = ["full"] }
|
||||
tower-http = { version = "0.1.1", features = ["full"] }
|
||||
hyper = { version = "0.14.11", features = ["full"] }
|
||||
hyper = { version = "0.14.12", features = ["full"] }
|
||||
|
||||
# Template engine
|
||||
tera = "1.12.1"
|
||||
|
||||
# Database access
|
||||
sqlx = { version = "0.5.5", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] }
|
||||
sqlx = { version = "0.5.7", features = ["runtime-tokio-rustls", "postgres", "migrate", "chrono", "offline"] }
|
||||
|
||||
# Various structure (de)serialization
|
||||
serde = { version = "1.0.127", features = ["derive"] }
|
||||
serde_yaml = "0.8.17"
|
||||
serde_with = { version = "1.9.4", features = ["hex", "chrono"] }
|
||||
serde = { version = "1.0.130", features = ["derive"] }
|
||||
serde_yaml = "0.8.20"
|
||||
serde_with = { version = "1.10.0", features = ["hex", "chrono"] }
|
||||
serde_json = "1.0.67"
|
||||
serde_urlencoded = "0.7.0"
|
||||
|
||||
# Argument & config parsing
|
||||
clap = "3.0.0-beta.2"
|
||||
clap = "3.0.0-beta.4"
|
||||
figment = { version = "0.10.6", features = ["env", "yaml", "test"] }
|
||||
schemars = { version = "0.8.3", features = ["url", "chrono"] }
|
||||
dotenv = "0.15.0"
|
||||
|
||||
# Password hashing
|
||||
argon2 = { version = "0.2.2", features = ["password-hash"] }
|
||||
password-hash = { version = "0.2.2", features = ["std"] }
|
||||
argon2 = { version = "0.3.0", features = ["password-hash"] }
|
||||
password-hash = { version = "0.3.0", features = ["std"] }
|
||||
|
||||
# Crypto and signing stuff
|
||||
rsa = "0.5.0"
|
||||
k256 = "0.9.6"
|
||||
pkcs8 = { version = "0.7.5", features = ["pem"] }
|
||||
elliptic-curve = { version = "0.10.6", features = ["pem"] }
|
||||
chacha20poly1305 = { version = "0.9.0", features = ["std"] }
|
||||
|
||||
# Various data types and utilities
|
||||
data-encoding = "2.3.2"
|
||||
@@ -57,9 +66,15 @@ rand = "0.8.4"
|
||||
bincode = "1.3.3"
|
||||
headers = "0.3.4"
|
||||
cookie = "0.15.1"
|
||||
chacha20poly1305 = { version = "0.8.1", features = ["std"] }
|
||||
crc = "2.0.0"
|
||||
|
||||
oauth2-types = { path = "../oauth2-types", features = ["sqlx_type"] }
|
||||
serde_json = "1.0.66"
|
||||
serde_urlencoded = "0.7.0"
|
||||
crc = "2.0.0"
|
||||
|
||||
[dependencies.jwt-compact]
|
||||
# Waiting on the next release because of the bump of the `rsa` dependency
|
||||
git = "https://github.com/slowli/jwt-compact.git"
|
||||
rev = "7a6dee6824c1d4e7c7f81019c9a968e5c9e44923"
|
||||
features = ["rsa", "k256"]
|
||||
|
||||
[dev-dependencies]
|
||||
indoc = "1.0.3"
|
||||
|
||||
@@ -27,13 +27,12 @@ pub use self::{
|
||||
csrf::CsrfConfig,
|
||||
database::DatabaseConfig,
|
||||
http::HttpConfig,
|
||||
oauth2::{OAuth2ClientConfig, OAuth2Config},
|
||||
oauth2::{Algorithm, KeySet, OAuth2ClientConfig, OAuth2Config},
|
||||
util::ConfigurationSection,
|
||||
};
|
||||
|
||||
#[derive(Debug, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct RootConfig {
|
||||
#[serde(default)]
|
||||
pub oauth2: OAuth2Config,
|
||||
|
||||
#[serde(default)]
|
||||
|
||||
@@ -12,14 +12,262 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::convert::TryFrom;
|
||||
|
||||
use anyhow::Context;
|
||||
use jwt_compact::{
|
||||
alg::{self, StrongAlg, StrongKey},
|
||||
jwk::JsonWebKey,
|
||||
AlgorithmExt, Claims, Header,
|
||||
};
|
||||
use pkcs8::{FromPrivateKey, ToPrivateKey};
|
||||
use rsa::RsaPrivateKey;
|
||||
use schemars::JsonSchema;
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde::{
|
||||
de::{MapAccess, Visitor},
|
||||
ser::SerializeStruct,
|
||||
Deserialize, Serialize,
|
||||
};
|
||||
use serde_with::skip_serializing_none;
|
||||
use thiserror::Error;
|
||||
use url::Url;
|
||||
|
||||
use super::ConfigurationSection;
|
||||
|
||||
// TODO: a lot of the signing logic should go out somewhere else
|
||||
|
||||
const RS256: StrongAlg<alg::Rsa> = StrongAlg(alg::Rsa::rs256());
|
||||
|
||||
#[derive(Serialize, Deserialize, Clone, Copy)]
|
||||
#[serde(rename_all = "UPPERCASE")]
|
||||
pub enum Algorithm {
|
||||
Rs256,
|
||||
Es256k,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Jwk {
|
||||
kid: String,
|
||||
alg: Algorithm,
|
||||
|
||||
#[serde(flatten)]
|
||||
inner: serde_json::Value,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Clone)]
|
||||
pub struct Jwks {
|
||||
keys: Vec<Jwk>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Clone)]
|
||||
#[serde(transparent)]
|
||||
pub struct KeySet(Vec<Key>);
|
||||
|
||||
impl KeySet {
|
||||
pub fn to_public_jwks(&self) -> Jwks {
|
||||
let keys = self.0.iter().map(Key::to_public_jwk).collect();
|
||||
Jwks { keys }
|
||||
}
|
||||
|
||||
#[allow(dead_code)]
|
||||
pub fn token<T>(
|
||||
&self,
|
||||
alg: Algorithm,
|
||||
header: Header,
|
||||
claims: &Claims<T>,
|
||||
) -> anyhow::Result<String>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
match alg {
|
||||
Algorithm::Rs256 => {
|
||||
let (kid, key) = self
|
||||
.0
|
||||
.iter()
|
||||
.find_map(Key::rsa)
|
||||
.context("could not find RSA key")?;
|
||||
let header = header.with_key_id(kid);
|
||||
// TODO: store them as strong keys
|
||||
RS256
|
||||
.token(header, claims, &StrongKey::try_from(key.clone())?)
|
||||
.context("failed to sign token")
|
||||
}
|
||||
Algorithm::Es256k => {
|
||||
// TODO: make this const with lazy_static?
|
||||
let es256k: alg::Es256k = alg::Es256k::default();
|
||||
let (kid, key) = self
|
||||
.0
|
||||
.iter()
|
||||
.find_map(Key::ecdsa)
|
||||
.context("could not find ECDSA key")?;
|
||||
let key = k256::ecdsa::SigningKey::from(key);
|
||||
let header = header.with_key_id(kid);
|
||||
// TODO: use StrongAlg
|
||||
es256k
|
||||
.token(header, claims, &key)
|
||||
.context("failed to sign token")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, Clone)]
|
||||
#[non_exhaustive]
|
||||
pub enum Key {
|
||||
Rsa { key: RsaPrivateKey, kid: String },
|
||||
Ecdsa { key: k256::SecretKey, kid: String },
|
||||
}
|
||||
|
||||
impl Key {
|
||||
fn from_ecdsa(key: k256::SecretKey) -> Self {
|
||||
// TODO: hash the key and use as KID
|
||||
let kid = String::from("ecdsa-kid");
|
||||
Self::Ecdsa { kid, key }
|
||||
}
|
||||
|
||||
fn from_ecdsa_pem(key: &str) -> anyhow::Result<Self> {
|
||||
let key = k256::SecretKey::from_pkcs8_pem(key)?;
|
||||
Ok(Self::from_ecdsa(key))
|
||||
}
|
||||
|
||||
fn from_rsa(key: RsaPrivateKey) -> Self {
|
||||
// TODO: hash the key and use as KID
|
||||
let kid = String::from("rsa-kid");
|
||||
Self::Rsa { kid, key }
|
||||
}
|
||||
|
||||
fn from_rsa_pem(key: &str) -> anyhow::Result<Self> {
|
||||
let key = RsaPrivateKey::from_pkcs8_pem(key)?;
|
||||
Ok(Self::from_rsa(key))
|
||||
}
|
||||
|
||||
fn to_public_jwk(&self) -> Jwk {
|
||||
match self {
|
||||
Key::Rsa { key, kid } => {
|
||||
let pubkey = key.to_public_key();
|
||||
let inner = JsonWebKey::from(&pubkey);
|
||||
let inner = serde_json::to_value(&inner).unwrap();
|
||||
let kid = kid.to_string();
|
||||
let alg = Algorithm::Rs256;
|
||||
Jwk { kid, alg, inner }
|
||||
}
|
||||
Key::Ecdsa { key, kid } => {
|
||||
let pubkey = k256::ecdsa::VerifyingKey::from(key.public_key());
|
||||
let inner = JsonWebKey::from(&pubkey);
|
||||
let inner = serde_json::to_value(&inner).unwrap();
|
||||
let kid = kid.to_string();
|
||||
let alg = Algorithm::Es256k;
|
||||
Jwk { kid, alg, inner }
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn rsa(&self) -> Option<(&str, &RsaPrivateKey)> {
|
||||
match self {
|
||||
Key::Rsa { key, kid } => Some((kid, key)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
|
||||
fn ecdsa(&self) -> Option<(&str, &k256::SecretKey)> {
|
||||
match self {
|
||||
Key::Ecdsa { key, kid } => Some((kid, key)),
|
||||
_ => None,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Key {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut map = serializer.serialize_struct("Key", 2)?;
|
||||
match self {
|
||||
Key::Rsa { key, kid: _ } => {
|
||||
map.serialize_field("type", "rsa")?;
|
||||
let pem = key.to_pkcs8_pem().map_err(serde::ser::Error::custom)?;
|
||||
map.serialize_field("key", pem.as_str())?;
|
||||
}
|
||||
Key::Ecdsa { key, kid: _ } => {
|
||||
map.serialize_field("type", "ecdsa")?;
|
||||
let pem = key.to_pkcs8_pem().map_err(serde::ser::Error::custom)?;
|
||||
map.serialize_field("key", pem.as_str())?;
|
||||
}
|
||||
}
|
||||
|
||||
map.end()
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Key {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
#[derive(Deserialize, Debug)]
|
||||
#[serde(field_identifier, rename_all = "lowercase")]
|
||||
enum Field {
|
||||
Type,
|
||||
Key,
|
||||
}
|
||||
|
||||
#[derive(Deserialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
enum KeyType {
|
||||
Rsa,
|
||||
Ecdsa,
|
||||
}
|
||||
|
||||
struct KeyVisitor;
|
||||
|
||||
impl<'de> Visitor<'de> for KeyVisitor {
|
||||
type Value = Key;
|
||||
|
||||
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
|
||||
formatter.write_str("struct Key")
|
||||
}
|
||||
|
||||
fn visit_map<V>(self, mut map: V) -> Result<Key, V::Error>
|
||||
where
|
||||
V: MapAccess<'de>,
|
||||
{
|
||||
let mut key_type = None;
|
||||
let mut key_key = None;
|
||||
while let Some(key) = map.next_key()? {
|
||||
match key {
|
||||
Field::Type => {
|
||||
if key_type.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("type"));
|
||||
}
|
||||
key_type = Some(map.next_value()?);
|
||||
}
|
||||
Field::Key => {
|
||||
if key_key.is_some() {
|
||||
return Err(serde::de::Error::duplicate_field("key"));
|
||||
}
|
||||
key_key = Some(map.next_value()?);
|
||||
}
|
||||
}
|
||||
}
|
||||
let key_type: KeyType =
|
||||
key_type.ok_or_else(|| serde::de::Error::missing_field("type"))?;
|
||||
let key_key: String =
|
||||
key_key.ok_or_else(|| serde::de::Error::missing_field("key"))?;
|
||||
|
||||
match key_type {
|
||||
KeyType::Rsa => Key::from_rsa_pem(&key_key).map_err(serde::de::Error::custom),
|
||||
KeyType::Ecdsa => {
|
||||
Key::from_ecdsa_pem(&key_key).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
deserializer.deserialize_struct("Key", &["type", "key"], KeyVisitor)
|
||||
}
|
||||
}
|
||||
|
||||
#[skip_serializing_none]
|
||||
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
|
||||
pub struct OAuth2ClientConfig {
|
||||
@@ -69,15 +317,9 @@ pub struct OAuth2Config {
|
||||
|
||||
#[serde(default)]
|
||||
pub clients: Vec<OAuth2ClientConfig>,
|
||||
}
|
||||
|
||||
impl Default for OAuth2Config {
|
||||
fn default() -> Self {
|
||||
Self {
|
||||
issuer: default_oauth2_issuer(),
|
||||
clients: Vec::new(),
|
||||
}
|
||||
}
|
||||
#[schemars(with = "Vec<String>")]
|
||||
pub keys: KeySet,
|
||||
}
|
||||
|
||||
impl OAuth2Config {
|
||||
@@ -86,6 +328,37 @@ impl OAuth2Config {
|
||||
.join(".well-known/openid-configuration")
|
||||
.expect("could not build discovery url")
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
pub fn test() -> Self {
|
||||
let rsa_key = Key::from_rsa_pem(indoc::indoc! {r#"
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAymS2RkeIZo7pUeEN
|
||||
QUGCG4GLJru5jzxomO9jiNr5D/oRcerhpQVc9aCpBfAAg4l4a1SmYdBzWqX0X5pU
|
||||
scgTtQIDAQABAkEArNIMlrxUK4bSklkCcXtXdtdKE9vuWfGyOw0GyAB69fkEUBxh
|
||||
3j65u+u3ZmW+bpMWHgp1FtdobE9nGwb2VBTWAQIhAOyU1jiUEkrwKK004+6b5QRE
|
||||
vC9UI2vDWy5vioMNx5Y1AiEA2wGAJ6ETF8FF2Vd+kZlkKK7J0em9cl0gbJDsWIEw
|
||||
N4ECIEyWYkMurD1WQdTQqnk0Po+DMOihdFYOiBYgRdbnPxWBAiEAmtd0xJAd7622
|
||||
tPQniMnrBtiN2NxqFXHCev/8Gpc8gAECIBcaPcF59qVeRmYrfqzKBxFm7LmTwlAl
|
||||
Gh7BNzCeN+D6
|
||||
-----END PRIVATE KEY-----
|
||||
"#})
|
||||
.unwrap();
|
||||
let ecdsa_key = Key::from_rsa_pem(indoc::indoc! {r#"
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgqfn5mYO/5Qq/wOOiWgHA
|
||||
NaiDiepgUJ2GI5eq2V8D8nahRANCAARMK9aKUd/H28qaU+0qvS6bSJItzAge1VHn
|
||||
OhBAAUVci1RpmUA+KdCL5sw9nadAEiONeiGr+28RYHZmlB9qXnjC
|
||||
-----END PRIVATE KEY-----
|
||||
"#})
|
||||
.unwrap();
|
||||
|
||||
Self {
|
||||
issuer: default_oauth2_issuer(),
|
||||
clients: Vec::new(),
|
||||
keys: KeySet(vec![rsa_key, ecdsa_key]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl ConfigurationSection<'_> for OAuth2Config {
|
||||
@@ -94,7 +367,14 @@ impl ConfigurationSection<'_> for OAuth2Config {
|
||||
}
|
||||
|
||||
fn generate() -> Self {
|
||||
Self::default()
|
||||
let mut rng = rand::thread_rng();
|
||||
let rsa_key = RsaPrivateKey::new(&mut rng, 2048).unwrap();
|
||||
let ecdsa_key = k256::SecretKey::random(rng);
|
||||
Self {
|
||||
issuer: default_oauth2_issuer(),
|
||||
clients: Vec::new(),
|
||||
keys: KeySet(vec![Key::from_rsa(rsa_key), Key::from_ecdsa(ecdsa_key)]),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -111,6 +391,26 @@ mod tests {
|
||||
"config.yaml",
|
||||
r#"
|
||||
oauth2:
|
||||
keys:
|
||||
- type: rsa
|
||||
key: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIIBVQIBADANBgkqhkiG9w0BAQEFAASCAT8wggE7AgEAAkEAymS2RkeIZo7pUeEN
|
||||
QUGCG4GLJru5jzxomO9jiNr5D/oRcerhpQVc9aCpBfAAg4l4a1SmYdBzWqX0X5pU
|
||||
scgTtQIDAQABAkEArNIMlrxUK4bSklkCcXtXdtdKE9vuWfGyOw0GyAB69fkEUBxh
|
||||
3j65u+u3ZmW+bpMWHgp1FtdobE9nGwb2VBTWAQIhAOyU1jiUEkrwKK004+6b5QRE
|
||||
vC9UI2vDWy5vioMNx5Y1AiEA2wGAJ6ETF8FF2Vd+kZlkKK7J0em9cl0gbJDsWIEw
|
||||
N4ECIEyWYkMurD1WQdTQqnk0Po+DMOihdFYOiBYgRdbnPxWBAiEAmtd0xJAd7622
|
||||
tPQniMnrBtiN2NxqFXHCev/8Gpc8gAECIBcaPcF59qVeRmYrfqzKBxFm7LmTwlAl
|
||||
Gh7BNzCeN+D6
|
||||
-----END PRIVATE KEY-----
|
||||
- type: ecdsa
|
||||
key: |
|
||||
-----BEGIN PRIVATE KEY-----
|
||||
MIGEAgEAMBAGByqGSM49AgEGBSuBBAAKBG0wawIBAQQgqfn5mYO/5Qq/wOOiWgHA
|
||||
NaiDiepgUJ2GI5eq2V8D8nahRANCAARMK9aKUd/H28qaU+0qvS6bSJItzAge1VHn
|
||||
OhBAAUVci1RpmUA+KdCL5sw9nadAEiONeiGr+28RYHZmlB9qXnjC
|
||||
-----END PRIVATE KEY-----
|
||||
issuer: https://example.com
|
||||
clients:
|
||||
- client_id: hello
|
||||
|
||||
@@ -149,7 +149,7 @@ mod tests {
|
||||
use super::*;
|
||||
|
||||
fn oauth2_config() -> OAuth2Config {
|
||||
let mut config = OAuth2Config::default();
|
||||
let mut config = OAuth2Config::test();
|
||||
config.clients.push(OAuth2ClientConfig {
|
||||
client_id: "public".to_string(),
|
||||
client_secret: None,
|
||||
|
||||
@@ -26,7 +26,10 @@ use std::convert::Infallible;
|
||||
use warp::Filter;
|
||||
|
||||
pub use self::csrf::CsrfToken;
|
||||
use crate::templates::Templates;
|
||||
use crate::{
|
||||
config::{KeySet, OAuth2Config},
|
||||
templates::Templates,
|
||||
};
|
||||
|
||||
pub fn with_templates(
|
||||
templates: &Templates,
|
||||
@@ -34,3 +37,10 @@ pub fn with_templates(
|
||||
let templates = templates.clone();
|
||||
warp::any().map(move || templates.clone())
|
||||
}
|
||||
|
||||
pub fn with_keys(
|
||||
oauth2_config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (KeySet,), Error = Infallible> + Clone + Send + Sync + 'static {
|
||||
let keyset = oauth2_config.keys.clone();
|
||||
warp::any().map(move || keyset.clone())
|
||||
}
|
||||
|
||||
@@ -47,7 +47,7 @@ pub(super) fn filter(
|
||||
let metadata = Metadata {
|
||||
authorization_endpoint: base.join("oauth2/authorize").ok(),
|
||||
token_endpoint: base.join("oauth2/token").ok(),
|
||||
jwks_uri: base.join(".well-known/jwks.json").ok(),
|
||||
jwks_uri: base.join("oauth2/keys.json").ok(),
|
||||
introspection_endpoint: base.join("oauth2/introspect").ok(),
|
||||
issuer: base,
|
||||
registration_endpoint: None,
|
||||
|
||||
30
matrix-authentication-service/src/handlers/oauth2/keys.rs
Normal file
30
matrix-authentication-service/src/handlers/oauth2/keys.rs
Normal file
@@ -0,0 +1,30 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::config::OAuth2Config;
|
||||
|
||||
pub(super) fn filter(
|
||||
config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let jwks = config.keys.to_public_jwks();
|
||||
|
||||
let cors = warp::cors().allow_any_origin();
|
||||
|
||||
warp::path!("oauth2" / "keys.json")
|
||||
.and(warp::get())
|
||||
.map(move || warp::reply::json(&jwks))
|
||||
.with(cors)
|
||||
}
|
||||
@@ -23,11 +23,12 @@ use crate::{
|
||||
mod authorization;
|
||||
mod discovery;
|
||||
mod introspection;
|
||||
mod keys;
|
||||
mod token;
|
||||
|
||||
use self::{
|
||||
authorization::filter as authorization, discovery::filter as discovery,
|
||||
introspection::filter as introspection, token::filter as token,
|
||||
introspection::filter as introspection, keys::filter as keys, token::filter as token,
|
||||
};
|
||||
|
||||
pub fn filter(
|
||||
@@ -37,6 +38,7 @@ pub fn filter(
|
||||
cookies_config: &CookiesConfig,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
discovery(oauth2_config)
|
||||
.or(keys(oauth2_config))
|
||||
.or(authorization(
|
||||
pool,
|
||||
templates,
|
||||
|
||||
@@ -12,7 +12,9 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use anyhow::Context;
|
||||
use chrono::Duration;
|
||||
use jwt_compact::{Claims, Header, TimeOptions};
|
||||
use oauth2_types::{
|
||||
errors::{InvalidGrant, OAuth2Error},
|
||||
requests::{
|
||||
@@ -20,15 +22,18 @@ use oauth2_types::{
|
||||
},
|
||||
};
|
||||
use rand::thread_rng;
|
||||
use serde::Serialize;
|
||||
use sqlx::{pool::PoolConnection, Acquire, PgPool, Postgres};
|
||||
use url::Url;
|
||||
use warp::{Filter, Rejection, Reply};
|
||||
|
||||
use crate::{
|
||||
config::{OAuth2ClientConfig, OAuth2Config},
|
||||
config::{KeySet, OAuth2ClientConfig, OAuth2Config},
|
||||
errors::WrapError,
|
||||
filters::{
|
||||
client::{with_client_auth, ClientAuthentication},
|
||||
database::with_connection,
|
||||
with_keys,
|
||||
},
|
||||
storage::oauth2::{
|
||||
access_token::{add_access_token, revoke_access_token},
|
||||
@@ -38,13 +43,26 @@ use crate::{
|
||||
tokens,
|
||||
};
|
||||
|
||||
#[derive(Serialize)]
|
||||
struct CustomClaims {
|
||||
#[serde(rename = "iss")]
|
||||
issuer: Url,
|
||||
#[serde(rename = "sub")]
|
||||
subject: String,
|
||||
#[serde(rename = "aud")]
|
||||
audiences: Vec<String>,
|
||||
}
|
||||
|
||||
pub fn filter(
|
||||
pool: &PgPool,
|
||||
oauth2_config: &OAuth2Config,
|
||||
) -> impl Filter<Extract = (impl Reply,), Error = Rejection> + Clone + Send + Sync + 'static {
|
||||
let issuer = oauth2_config.issuer.clone();
|
||||
warp::path!("oauth2" / "token")
|
||||
.and(warp::post())
|
||||
.and(with_client_auth(oauth2_config))
|
||||
.and(with_keys(oauth2_config))
|
||||
.and(warp::any().map(move || issuer.clone()))
|
||||
.and(with_connection(pool))
|
||||
.and_then(token)
|
||||
}
|
||||
@@ -53,11 +71,13 @@ async fn token(
|
||||
_auth: ClientAuthentication,
|
||||
client: OAuth2ClientConfig,
|
||||
req: AccessTokenRequest,
|
||||
keys: KeySet,
|
||||
issuer: Url,
|
||||
mut conn: PoolConnection<Postgres>,
|
||||
) -> Result<impl Reply, Rejection> {
|
||||
let reply = match req {
|
||||
AccessTokenRequest::AuthorizationCode(grant) => {
|
||||
let reply = authorization_code_grant(&grant, &client, &mut conn).await?;
|
||||
let reply = authorization_code_grant(&grant, &client, &keys, issuer, &mut conn).await?;
|
||||
warp::reply::json(&reply)
|
||||
}
|
||||
AccessTokenRequest::RefreshToken(grant) => {
|
||||
@@ -76,6 +96,8 @@ async fn token(
|
||||
async fn authorization_code_grant(
|
||||
grant: &AuthorizationCodeGrant,
|
||||
client: &OAuth2ClientConfig,
|
||||
keys: &KeySet,
|
||||
issuer: Url,
|
||||
conn: &mut PoolConnection<Postgres>,
|
||||
) -> Result<AccessTokenResponse, Rejection> {
|
||||
let mut txn = conn.begin().await.wrap_error()?;
|
||||
@@ -108,11 +130,26 @@ async fn authorization_code_grant(
|
||||
.await
|
||||
.wrap_error()?;
|
||||
|
||||
// TODO: generate id_token if the "openid" scope was asked
|
||||
// TODO: generate id_token only if the "openid" scope was asked
|
||||
let header = Header::default();
|
||||
let options = TimeOptions::default();
|
||||
let claims = Claims::new(CustomClaims {
|
||||
issuer,
|
||||
// TODO: get that from the session
|
||||
subject: "random-subject".to_string(),
|
||||
audiences: vec![client.client_id.clone()],
|
||||
})
|
||||
.set_duration_and_issuance(&options, Duration::minutes(30));
|
||||
let id_token = keys
|
||||
.token(crate::config::Algorithm::Rs256, header, &claims)
|
||||
.context("could not sign ID token")
|
||||
.wrap_error()?;
|
||||
|
||||
// TODO: have the scopes back here
|
||||
let params = AccessTokenResponse::new(access_token.token)
|
||||
.with_expires_in(ttl)
|
||||
.with_refresh_token(refresh_token.token);
|
||||
.with_refresh_token(refresh_token.token)
|
||||
.with_id_token(id_token);
|
||||
|
||||
txn.commit().await.wrap_error()?;
|
||||
|
||||
|
||||
@@ -7,14 +7,14 @@ license = "Apache-2.0"
|
||||
|
||||
[dependencies]
|
||||
http = "0.2.4"
|
||||
serde = "1.0.127"
|
||||
serde_json = "1.0.66"
|
||||
serde = "1.0.130"
|
||||
serde_json = "1.0.67"
|
||||
language-tags = { version = "0.3.2", features = ["serde"] }
|
||||
url = { version = "2.2.2", features = ["serde"] }
|
||||
parse-display = "0.5.1"
|
||||
indoc = "1.0.3"
|
||||
serde_with = { version = "1.9.4", features = ["chrono"] }
|
||||
sqlx = { version = "0.5.5", default-features = false, optional = true }
|
||||
serde_with = { version = "1.10.0", features = ["chrono"] }
|
||||
sqlx = { version = "0.5.7", default-features = false, optional = true }
|
||||
chrono = "0.4.19"
|
||||
|
||||
[features]
|
||||
|
||||
@@ -237,10 +237,13 @@ pub enum AccessTokenRequest {
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct AccessTokenResponse {
|
||||
access_token: String,
|
||||
refresh_token: Option<String>,
|
||||
// TODO: this should be somewhere else
|
||||
id_token: Option<String>,
|
||||
|
||||
token_type: TokenType,
|
||||
|
||||
@@ -257,6 +260,7 @@ impl AccessTokenResponse {
|
||||
AccessTokenResponse {
|
||||
access_token,
|
||||
refresh_token: None,
|
||||
id_token: None,
|
||||
token_type: TokenType::Bearer,
|
||||
expires_in: None,
|
||||
scope: None,
|
||||
@@ -269,6 +273,12 @@ impl AccessTokenResponse {
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_id_token(mut self, id_token: String) -> Self {
|
||||
self.id_token = Some(id_token);
|
||||
self
|
||||
}
|
||||
|
||||
#[must_use]
|
||||
pub fn with_scopes(mut self, scope: HashSet<String>) -> Self {
|
||||
self.scope = Some(scope);
|
||||
|
||||
Reference in New Issue
Block a user