replace custom serde types with serde_with in oauth2-types
This commit is contained in:
1
Cargo.lock
generated
1
Cargo.lock
generated
@@ -1677,6 +1677,7 @@ dependencies = [
|
||||
"parse-display",
|
||||
"serde",
|
||||
"serde_json",
|
||||
"serde_with",
|
||||
"url",
|
||||
]
|
||||
|
||||
|
||||
@@ -12,8 +12,6 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::collections::HashSet;
|
||||
|
||||
use oauth2_types::oidc::Metadata;
|
||||
use tide::{Body, Request, Response};
|
||||
|
||||
@@ -27,10 +25,10 @@ pub async fn get(req: Request<State>) -> tide::Result {
|
||||
token_endpoint: state.token_endpoint(),
|
||||
jwks_uri: state.jwks_uri(),
|
||||
registration_endpoint: None,
|
||||
scopes_supported: HashSet::default(),
|
||||
response_types_supported: HashSet::default(),
|
||||
response_modes_supported: HashSet::default(),
|
||||
grant_types_supported: HashSet::default(),
|
||||
scopes_supported: None,
|
||||
response_types_supported: None,
|
||||
response_modes_supported: None,
|
||||
grant_types_supported: None,
|
||||
};
|
||||
|
||||
let body = Body::from_json(&m)?;
|
||||
|
||||
@@ -13,3 +13,4 @@ language-tags = { version = "0.3.2", features = ["serde"] }
|
||||
url = { version = "2.2.2", features = ["serde"] }
|
||||
parse-display = "0.5.1"
|
||||
indoc = "1.0.3"
|
||||
serde_with = "1.9.4"
|
||||
|
||||
@@ -19,7 +19,6 @@
|
||||
pub mod errors;
|
||||
pub mod oidc;
|
||||
pub mod requests;
|
||||
mod types;
|
||||
|
||||
#[cfg(test)]
|
||||
mod test_utils;
|
||||
|
||||
@@ -15,11 +15,13 @@
|
||||
use std::collections::HashSet;
|
||||
|
||||
use serde::Serialize;
|
||||
use serde_with::skip_serializing_none;
|
||||
use url::Url;
|
||||
|
||||
use crate::requests::{GrantType, ResponseMode, ResponseType};
|
||||
|
||||
// TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2
|
||||
#[skip_serializing_none]
|
||||
#[derive(Serialize)]
|
||||
pub struct Metadata {
|
||||
/// The authorization server's issuer identifier, which is a URL that uses
|
||||
@@ -27,40 +29,32 @@ pub struct Metadata {
|
||||
pub issuer: Url,
|
||||
|
||||
/// URL of the authorization server's authorization endpoint.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub authorization_endpoint: Option<Url>,
|
||||
|
||||
/// URL of the authorization server's token endpoint.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub token_endpoint: Option<Url>,
|
||||
|
||||
/// URL of the authorization server's JWK Set document.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub jwks_uri: Option<Url>,
|
||||
|
||||
/// URL of the authorization server's OAuth 2.0 Dynamic Client Registration
|
||||
/// endpoint.
|
||||
#[serde(skip_serializing_if = "Option::is_none")]
|
||||
pub registration_endpoint: Option<Url>,
|
||||
|
||||
/// JSON array containing a list of the OAuth 2.0 "scope" values that this
|
||||
/// authorization server supports.
|
||||
#[serde(skip_serializing_if = "HashSet::is_empty")]
|
||||
pub scopes_supported: HashSet<String>,
|
||||
pub scopes_supported: Option<HashSet<String>>,
|
||||
|
||||
/// JSON array containing a list of the OAuth 2.0 "response_type" values
|
||||
/// that this authorization server supports.
|
||||
#[serde(skip_serializing_if = "HashSet::is_empty")]
|
||||
pub response_types_supported: HashSet<ResponseType>,
|
||||
pub response_types_supported: Option<HashSet<ResponseType>>,
|
||||
|
||||
/// JSON array containing a list of the OAuth 2.0 "response_mode" values
|
||||
/// that this authorization server supports, as specified in "OAuth 2.0
|
||||
/// Multiple Response Type Encoding Practices".
|
||||
#[serde(skip_serializing_if = "HashSet::is_empty")]
|
||||
pub response_modes_supported: HashSet<ResponseMode>,
|
||||
pub response_modes_supported: Option<HashSet<ResponseMode>>,
|
||||
|
||||
/// JSON array containing a list of the OAuth 2.0 grant type values that
|
||||
/// this authorization server supports.
|
||||
#[serde(skip_serializing_if = "HashSet::is_empty")]
|
||||
pub grant_types_supported: HashSet<GrantType>,
|
||||
pub grant_types_supported: Option<HashSet<GrantType>>,
|
||||
}
|
||||
|
||||
@@ -12,15 +12,14 @@
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
use std::hash::Hash;
|
||||
use std::{collections::HashSet, hash::Hash, time::Duration};
|
||||
|
||||
use language_tags::LanguageTag;
|
||||
use parse_display::{Display, FromStr};
|
||||
use serde::{Deserialize, Serialize};
|
||||
use serde_with::{rust::StringWithSeparator, serde_as, DurationSeconds, SpaceSeparator};
|
||||
use url::Url;
|
||||
|
||||
use crate::types::{Seconds, StringHashSet, StringVec};
|
||||
|
||||
// ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml
|
||||
|
||||
#[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Display, FromStr, Serialize)]
|
||||
@@ -60,21 +59,39 @@ pub enum Prompt {
|
||||
SelectAccount,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize)]
|
||||
pub struct AuthorizationRequest {
|
||||
response_type: StringHashSet<ResponseType>,
|
||||
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, ResponseType>")]
|
||||
response_type: HashSet<ResponseType>,
|
||||
|
||||
client_id: String,
|
||||
|
||||
redirect_uri: Option<Url>,
|
||||
scope: StringHashSet<String>,
|
||||
|
||||
#[serde_as(as = "StringWithSeparator::<SpaceSeparator, String>")]
|
||||
scope: HashSet<String>,
|
||||
|
||||
state: Option<String>,
|
||||
|
||||
response_mode: Option<ResponseMode>,
|
||||
|
||||
nonce: Option<String>,
|
||||
|
||||
display: Option<Display>,
|
||||
max_age: Option<Seconds>,
|
||||
ui_locales: Option<StringVec<LanguageTag>>,
|
||||
|
||||
#[serde_as(as = "Option<DurationSeconds>")]
|
||||
max_age: Option<Duration>,
|
||||
|
||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, LanguageTag>>")]
|
||||
ui_locales: Option<Vec<LanguageTag>>,
|
||||
|
||||
id_token_hint: Option<String>,
|
||||
|
||||
login_hint: Option<String>,
|
||||
acr_values: Option<StringHashSet<String>>,
|
||||
|
||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
||||
acr_values: Option<HashSet<String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize)]
|
||||
@@ -95,10 +112,13 @@ pub struct AuthorizationCodeGrant {
|
||||
redirect_uri: Option<Url>,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct RefreshTokenGrant {
|
||||
refresh_token: String,
|
||||
scope: Option<StringHashSet<String>>,
|
||||
|
||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
||||
scope: Option<HashSet<String>>,
|
||||
}
|
||||
|
||||
#[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq)]
|
||||
@@ -117,13 +137,20 @@ pub enum AccessTokenRequest {
|
||||
Unsupported,
|
||||
}
|
||||
|
||||
#[serde_as]
|
||||
#[derive(Serialize, Deserialize, Debug, PartialEq)]
|
||||
pub struct AccessTokenResponse {
|
||||
access_token: String,
|
||||
|
||||
token_type: TokenType,
|
||||
expires_in: Option<Seconds>,
|
||||
|
||||
#[serde_as(as = "Option<DurationSeconds>")]
|
||||
expires_in: Option<Duration>,
|
||||
|
||||
refresh_token: Option<String>,
|
||||
scope: Option<StringHashSet<String>>,
|
||||
|
||||
#[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
|
||||
scope: Option<HashSet<String>>,
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
@@ -140,14 +167,16 @@ mod tests {
|
||||
let expected = json!({
|
||||
"grant_type": "refresh_token",
|
||||
"refresh_token": "abcd",
|
||||
"scope": "openid profile",
|
||||
"scope": "openid",
|
||||
});
|
||||
|
||||
let scope = {
|
||||
let mut s = HashSet::new();
|
||||
// TODO: insert multiple scopes and test it. It's a bit tricky to test since
|
||||
// HashSet have no guarantees regarding the ordering of items, so right
|
||||
// now the output is unstable.
|
||||
s.insert("openid".to_string());
|
||||
s.insert("profile".to_string());
|
||||
Some(s.into())
|
||||
Some(s)
|
||||
};
|
||||
|
||||
let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant {
|
||||
|
||||
@@ -1,149 +0,0 @@
|
||||
// Copyright 2021 The Matrix.org Foundation C.I.C.
|
||||
//
|
||||
// Licensed under the Apache License, Version 2.0 (the "License");
|
||||
// you may not use this file except in compliance with the License.
|
||||
// You may obtain a copy of the License at
|
||||
//
|
||||
// http://www.apache.org/licenses/LICENSE-2.0
|
||||
//
|
||||
// Unless required by applicable law or agreed to in writing, software
|
||||
// distributed under the License is distributed on an "AS IS" BASIS,
|
||||
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
// See the License for the specific language governing permissions and
|
||||
// limitations under the License.
|
||||
|
||||
//! Utilitary types for serde
|
||||
|
||||
use std::{
|
||||
collections::{hash_map::RandomState, HashSet},
|
||||
hash::Hash,
|
||||
time::Duration,
|
||||
};
|
||||
|
||||
use serde::{Deserialize, Serialize};
|
||||
|
||||
/// A `HashSet` that serializes to a space-separated string in alphanumerical
|
||||
/// order
|
||||
#[derive(Debug, PartialEq)]
|
||||
pub struct StringHashSet<T: Eq + Hash>(HashSet<T>);
|
||||
|
||||
impl<T: Eq + Hash> From<HashSet<T>> for StringHashSet<T> {
|
||||
fn from(set: HashSet<T>) -> Self {
|
||||
Self(set)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T: Eq + Hash> From<StringHashSet<T>> for HashSet<T, RandomState> {
|
||||
fn from(set: StringHashSet<T>) -> Self {
|
||||
set.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Serialize for StringHashSet<T>
|
||||
where
|
||||
T: ToString + PartialOrd + Eq + Hash,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let mut items: Vec<_> = self.0.iter().map(|i| i.to_string()).collect();
|
||||
items.sort();
|
||||
let s = items.join(" ");
|
||||
serializer.serialize_str(&s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, T> Deserialize<'de> for StringHashSet<T>
|
||||
where
|
||||
T: std::str::FromStr + Eq + Hash,
|
||||
<T as std::str::FromStr>::Err: std::fmt::Display,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s: String = Deserialize::deserialize(deserializer)?;
|
||||
let items: Result<HashSet<T>, _> = s.split_ascii_whitespace().map(T::from_str).collect();
|
||||
items.map(Into::into).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Vec that serializes to a space-separated string
|
||||
pub struct StringVec<T>(Vec<T>);
|
||||
|
||||
impl<T> From<Vec<T>> for StringVec<T> {
|
||||
fn from(set: Vec<T>) -> Self {
|
||||
Self(set)
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> From<StringVec<T>> for Vec<T> {
|
||||
fn from(v: StringVec<T>) -> Self {
|
||||
v.0
|
||||
}
|
||||
}
|
||||
|
||||
impl<T> Serialize for StringVec<T>
|
||||
where
|
||||
T: ToString,
|
||||
{
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
let items: Vec<_> = self.0.iter().map(|i| i.to_string()).collect();
|
||||
let s = items.join(" ");
|
||||
serializer.serialize_str(&s)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de, T> Deserialize<'de> for StringVec<T>
|
||||
where
|
||||
T: std::str::FromStr + std::hash::Hash + Eq,
|
||||
<T as std::str::FromStr>::Err: std::fmt::Display,
|
||||
{
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let s: String = Deserialize::deserialize(deserializer)?;
|
||||
let items: Result<Vec<T>, _> = s.split_ascii_whitespace().map(T::from_str).collect();
|
||||
items.map(Into::into).map_err(serde::de::Error::custom)
|
||||
}
|
||||
}
|
||||
|
||||
/// A Duration that serializes to seconds
|
||||
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
|
||||
pub struct Seconds(Duration);
|
||||
|
||||
impl From<Duration> for Seconds {
|
||||
fn from(d: Duration) -> Self {
|
||||
Self(d)
|
||||
}
|
||||
}
|
||||
|
||||
impl From<Seconds> for Duration {
|
||||
fn from(val: Seconds) -> Self {
|
||||
val.0
|
||||
}
|
||||
}
|
||||
|
||||
impl Serialize for Seconds {
|
||||
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
|
||||
where
|
||||
S: serde::Serializer,
|
||||
{
|
||||
self.0.as_secs().serialize(serializer)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'de> Deserialize<'de> for Seconds {
|
||||
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
|
||||
where
|
||||
D: serde::Deserializer<'de>,
|
||||
{
|
||||
let secs = u64::deserialize(deserializer)?;
|
||||
Ok(Self(Duration::from_secs(secs)))
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user