replace custom serde types with serde_with in oauth2-types

This commit is contained in:
Quentin Gliech
2021-07-22 14:30:53 +02:00
parent 05f13f94f8
commit 51539019aa
7 changed files with 55 additions and 182 deletions

1
Cargo.lock generated
View File

@@ -1677,6 +1677,7 @@ dependencies = [
"parse-display",
"serde",
"serde_json",
"serde_with",
"url",
]

View File

@@ -12,8 +12,6 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::collections::HashSet;
use oauth2_types::oidc::Metadata;
use tide::{Body, Request, Response};
@@ -27,10 +25,10 @@ pub async fn get(req: Request<State>) -> tide::Result {
token_endpoint: state.token_endpoint(),
jwks_uri: state.jwks_uri(),
registration_endpoint: None,
scopes_supported: HashSet::default(),
response_types_supported: HashSet::default(),
response_modes_supported: HashSet::default(),
grant_types_supported: HashSet::default(),
scopes_supported: None,
response_types_supported: None,
response_modes_supported: None,
grant_types_supported: None,
};
let body = Body::from_json(&m)?;

View File

@@ -13,3 +13,4 @@ language-tags = { version = "0.3.2", features = ["serde"] }
url = { version = "2.2.2", features = ["serde"] }
parse-display = "0.5.1"
indoc = "1.0.3"
serde_with = "1.9.4"

View File

@@ -19,7 +19,6 @@
pub mod errors;
pub mod oidc;
pub mod requests;
mod types;
#[cfg(test)]
mod test_utils;

View File

@@ -15,11 +15,13 @@
use std::collections::HashSet;
use serde::Serialize;
use serde_with::skip_serializing_none;
use url::Url;
use crate::requests::{GrantType, ResponseMode, ResponseType};
// TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2
#[skip_serializing_none]
#[derive(Serialize)]
pub struct Metadata {
/// The authorization server's issuer identifier, which is a URL that uses
@@ -27,40 +29,32 @@ pub struct Metadata {
pub issuer: Url,
/// URL of the authorization server's authorization endpoint.
#[serde(skip_serializing_if = "Option::is_none")]
pub authorization_endpoint: Option<Url>,
/// URL of the authorization server's token endpoint.
#[serde(skip_serializing_if = "Option::is_none")]
pub token_endpoint: Option<Url>,
/// URL of the authorization server's JWK Set document.
#[serde(skip_serializing_if = "Option::is_none")]
pub jwks_uri: Option<Url>,
/// URL of the authorization server's OAuth 2.0 Dynamic Client Registration
/// endpoint.
#[serde(skip_serializing_if = "Option::is_none")]
pub registration_endpoint: Option<Url>,
/// JSON array containing a list of the OAuth 2.0 "scope" values that this
/// authorization server supports.
#[serde(skip_serializing_if = "HashSet::is_empty")]
pub scopes_supported: HashSet<String>,
pub scopes_supported: Option<HashSet<String>>,
/// JSON array containing a list of the OAuth 2.0 "response_type" values
/// that this authorization server supports.
#[serde(skip_serializing_if = "HashSet::is_empty")]
pub response_types_supported: HashSet<ResponseType>,
pub response_types_supported: Option<HashSet<ResponseType>>,
/// JSON array containing a list of the OAuth 2.0 "response_mode" values
/// that this authorization server supports, as specified in "OAuth 2.0
/// Multiple Response Type Encoding Practices".
#[serde(skip_serializing_if = "HashSet::is_empty")]
pub response_modes_supported: HashSet<ResponseMode>,
pub response_modes_supported: Option<HashSet<ResponseMode>>,
/// JSON array containing a list of the OAuth 2.0 grant type values that
/// this authorization server supports.
#[serde(skip_serializing_if = "HashSet::is_empty")]
pub grant_types_supported: HashSet<GrantType>,
pub grant_types_supported: Option<HashSet<GrantType>>,
}

View File

@@ -12,15 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
use std::hash::Hash;
use std::{collections::HashSet, hash::Hash, time::Duration};
use language_tags::LanguageTag;
use parse_display::{Display, FromStr};
use serde::{Deserialize, Serialize};
use serde_with::{rust::StringWithSeparator, serde_as, DurationSeconds, SpaceSeparator};
use url::Url;
use crate::types::{Seconds, StringHashSet, StringVec};
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Display, FromStr, Serialize)]
@@ -60,21 +59,39 @@ pub enum Prompt {
SelectAccount,
}
#[serde_as]
#[derive(Serialize, Deserialize)]
pub struct AuthorizationRequest {
response_type: StringHashSet<ResponseType>,
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, ResponseType>")]
response_type: HashSet<ResponseType>,
client_id: String,
redirect_uri: Option<Url>,
scope: StringHashSet<String>,
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
scope: HashSet<String>,
state: Option<String>,
response_mode: Option<ResponseMode>,
nonce: Option<String>,
display: Option<Display>,
max_age: Option<Seconds>,
ui_locales: Option<StringVec<LanguageTag>>,
#[serde_as(as = "Option<DurationSeconds>")]
max_age: Option<Duration>,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, LanguageTag>>")]
ui_locales: Option<Vec<LanguageTag>>,
id_token_hint: Option<String>,
login_hint: Option<String>,
acr_values: Option<StringHashSet<String>>,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
acr_values: Option<HashSet<String>>,
}
#[derive(Serialize, Deserialize)]
@@ -95,10 +112,13 @@ pub struct AuthorizationCodeGrant {
redirect_uri: Option<Url>,
}
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct RefreshTokenGrant {
refresh_token: String,
scope: Option<StringHashSet<String>>,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
scope: Option<HashSet<String>>,
}
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq)]
@@ -117,13 +137,20 @@ pub enum AccessTokenRequest {
Unsupported,
}
#[serde_as]
#[derive(Serialize, Deserialize, Debug, PartialEq)]
pub struct AccessTokenResponse {
access_token: String,
token_type: TokenType,
expires_in: Option<Seconds>,
#[serde_as(as = "Option<DurationSeconds>")]
expires_in: Option<Duration>,
refresh_token: Option<String>,
scope: Option<StringHashSet<String>>,
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
scope: Option<HashSet<String>>,
}
#[cfg(test)]
@@ -140,14 +167,16 @@ mod tests {
let expected = json!({
"grant_type": "refresh_token",
"refresh_token": "abcd",
"scope": "openid profile",
"scope": "openid",
});
let scope = {
let mut s = HashSet::new();
// TODO: insert multiple scopes and test it. It's a bit tricky to test since
// HashSet have no guarantees regarding the ordering of items, so right
// now the output is unstable.
s.insert("openid".to_string());
s.insert("profile".to_string());
Some(s.into())
Some(s)
};
let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant {

View File

@@ -1,149 +0,0 @@
// Copyright 2021 The Matrix.org Foundation C.I.C.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//! Utilitary types for serde
use std::{
collections::{hash_map::RandomState, HashSet},
hash::Hash,
time::Duration,
};
use serde::{Deserialize, Serialize};
/// A `HashSet` that serializes to a space-separated string in alphanumerical
/// order
#[derive(Debug, PartialEq)]
pub struct StringHashSet<T: Eq + Hash>(HashSet<T>);
impl<T: Eq + Hash> From<HashSet<T>> for StringHashSet<T> {
fn from(set: HashSet<T>) -> Self {
Self(set)
}
}
impl<T: Eq + Hash> From<StringHashSet<T>> for HashSet<T, RandomState> {
fn from(set: StringHashSet<T>) -> Self {
set.0
}
}
impl<T> Serialize for StringHashSet<T>
where
T: ToString + PartialOrd + Eq + Hash,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let mut items: Vec<_> = self.0.iter().map(|i| i.to_string()).collect();
items.sort();
let s = items.join(" ");
serializer.serialize_str(&s)
}
}
impl<'de, T> Deserialize<'de> for StringHashSet<T>
where
T: std::str::FromStr + Eq + Hash,
<T as std::str::FromStr>::Err: std::fmt::Display,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
let items: Result<HashSet<T>, _> = s.split_ascii_whitespace().map(T::from_str).collect();
items.map(Into::into).map_err(serde::de::Error::custom)
}
}
/// A Vec that serializes to a space-separated string
pub struct StringVec<T>(Vec<T>);
impl<T> From<Vec<T>> for StringVec<T> {
fn from(set: Vec<T>) -> Self {
Self(set)
}
}
impl<T> From<StringVec<T>> for Vec<T> {
fn from(v: StringVec<T>) -> Self {
v.0
}
}
impl<T> Serialize for StringVec<T>
where
T: ToString,
{
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
let items: Vec<_> = self.0.iter().map(|i| i.to_string()).collect();
let s = items.join(" ");
serializer.serialize_str(&s)
}
}
impl<'de, T> Deserialize<'de> for StringVec<T>
where
T: std::str::FromStr + std::hash::Hash + Eq,
<T as std::str::FromStr>::Err: std::fmt::Display,
{
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let s: String = Deserialize::deserialize(deserializer)?;
let items: Result<Vec<T>, _> = s.split_ascii_whitespace().map(T::from_str).collect();
items.map(Into::into).map_err(serde::de::Error::custom)
}
}
/// A Duration that serializes to seconds
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Seconds(Duration);
impl From<Duration> for Seconds {
fn from(d: Duration) -> Self {
Self(d)
}
}
impl From<Seconds> for Duration {
fn from(val: Seconds) -> Self {
val.0
}
}
impl Serialize for Seconds {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
self.0.as_secs().serialize(serializer)
}
}
impl<'de> Deserialize<'de> for Seconds {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
let secs = u64::deserialize(deserializer)?;
Ok(Self(Duration::from_secs(secs)))
}
}