diff --git a/Cargo.lock b/Cargo.lock index 39d0eb277..8afd37e07 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1677,6 +1677,7 @@ dependencies = [ "parse-display", "serde", "serde_json", + "serde_with", "url", ] diff --git a/matrix-authentication-service/src/handlers/oauth2/discovery.rs b/matrix-authentication-service/src/handlers/oauth2/discovery.rs index c6e07e170..6ade8b7f3 100644 --- a/matrix-authentication-service/src/handlers/oauth2/discovery.rs +++ b/matrix-authentication-service/src/handlers/oauth2/discovery.rs @@ -12,8 +12,6 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::collections::HashSet; - use oauth2_types::oidc::Metadata; use tide::{Body, Request, Response}; @@ -27,10 +25,10 @@ pub async fn get(req: Request) -> tide::Result { token_endpoint: state.token_endpoint(), jwks_uri: state.jwks_uri(), registration_endpoint: None, - scopes_supported: HashSet::default(), - response_types_supported: HashSet::default(), - response_modes_supported: HashSet::default(), - grant_types_supported: HashSet::default(), + scopes_supported: None, + response_types_supported: None, + response_modes_supported: None, + grant_types_supported: None, }; let body = Body::from_json(&m)?; diff --git a/oauth2-types/Cargo.toml b/oauth2-types/Cargo.toml index fc076a33f..ec45e94a5 100644 --- a/oauth2-types/Cargo.toml +++ b/oauth2-types/Cargo.toml @@ -13,3 +13,4 @@ language-tags = { version = "0.3.2", features = ["serde"] } url = { version = "2.2.2", features = ["serde"] } parse-display = "0.5.1" indoc = "1.0.3" +serde_with = "1.9.4" diff --git a/oauth2-types/src/lib.rs b/oauth2-types/src/lib.rs index b9df81007..1f9c41e34 100644 --- a/oauth2-types/src/lib.rs +++ b/oauth2-types/src/lib.rs @@ -19,7 +19,6 @@ pub mod errors; pub mod oidc; pub mod requests; -mod types; #[cfg(test)] mod test_utils; diff --git a/oauth2-types/src/oidc.rs b/oauth2-types/src/oidc.rs index 77449a86c..480b45ba0 100644 --- a/oauth2-types/src/oidc.rs +++ b/oauth2-types/src/oidc.rs @@ -15,11 +15,13 @@ use std::collections::HashSet; use serde::Serialize; +use serde_with::skip_serializing_none; use url::Url; use crate::requests::{GrantType, ResponseMode, ResponseType}; // TODO: https://datatracker.ietf.org/doc/html/rfc8414#section-2 +#[skip_serializing_none] #[derive(Serialize)] pub struct Metadata { /// The authorization server's issuer identifier, which is a URL that uses @@ -27,40 +29,32 @@ pub struct Metadata { pub issuer: Url, /// URL of the authorization server's authorization endpoint. - #[serde(skip_serializing_if = "Option::is_none")] pub authorization_endpoint: Option, /// URL of the authorization server's token endpoint. - #[serde(skip_serializing_if = "Option::is_none")] pub token_endpoint: Option, /// URL of the authorization server's JWK Set document. - #[serde(skip_serializing_if = "Option::is_none")] pub jwks_uri: Option, /// URL of the authorization server's OAuth 2.0 Dynamic Client Registration /// endpoint. - #[serde(skip_serializing_if = "Option::is_none")] pub registration_endpoint: Option, /// JSON array containing a list of the OAuth 2.0 "scope" values that this /// authorization server supports. - #[serde(skip_serializing_if = "HashSet::is_empty")] - pub scopes_supported: HashSet, + pub scopes_supported: Option>, /// JSON array containing a list of the OAuth 2.0 "response_type" values /// that this authorization server supports. - #[serde(skip_serializing_if = "HashSet::is_empty")] - pub response_types_supported: HashSet, + pub response_types_supported: Option>, /// JSON array containing a list of the OAuth 2.0 "response_mode" values /// that this authorization server supports, as specified in "OAuth 2.0 /// Multiple Response Type Encoding Practices". - #[serde(skip_serializing_if = "HashSet::is_empty")] - pub response_modes_supported: HashSet, + pub response_modes_supported: Option>, /// JSON array containing a list of the OAuth 2.0 grant type values that /// this authorization server supports. - #[serde(skip_serializing_if = "HashSet::is_empty")] - pub grant_types_supported: HashSet, + pub grant_types_supported: Option>, } diff --git a/oauth2-types/src/requests.rs b/oauth2-types/src/requests.rs index 783d29653..325fc5e2a 100644 --- a/oauth2-types/src/requests.rs +++ b/oauth2-types/src/requests.rs @@ -12,15 +12,14 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::hash::Hash; +use std::{collections::HashSet, hash::Hash, time::Duration}; use language_tags::LanguageTag; use parse_display::{Display, FromStr}; use serde::{Deserialize, Serialize}; +use serde_with::{rust::StringWithSeparator, serde_as, DurationSeconds, SpaceSeparator}; use url::Url; -use crate::types::{Seconds, StringHashSet, StringVec}; - // ref: https://www.iana.org/assignments/oauth-parameters/oauth-parameters.xhtml #[derive(Hash, PartialEq, Eq, PartialOrd, Ord, Display, FromStr, Serialize)] @@ -60,21 +59,39 @@ pub enum Prompt { SelectAccount, } +#[serde_as] #[derive(Serialize, Deserialize)] pub struct AuthorizationRequest { - response_type: StringHashSet, + #[serde_as(as = "StringWithSeparator::")] + response_type: HashSet, + client_id: String, + redirect_uri: Option, - scope: StringHashSet, + + #[serde_as(as = "StringWithSeparator::")] + scope: HashSet, + state: Option, + response_mode: Option, + nonce: Option, + display: Option, - max_age: Option, - ui_locales: Option>, + + #[serde_as(as = "Option")] + max_age: Option, + + #[serde_as(as = "Option>")] + ui_locales: Option>, + id_token_hint: Option, + login_hint: Option, - acr_values: Option>, + + #[serde_as(as = "Option>")] + acr_values: Option>, } #[derive(Serialize, Deserialize)] @@ -95,10 +112,13 @@ pub struct AuthorizationCodeGrant { redirect_uri: Option, } +#[serde_as] #[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct RefreshTokenGrant { refresh_token: String, - scope: Option>, + + #[serde_as(as = "Option>")] + scope: Option>, } #[derive(Serialize, Deserialize, Debug, Hash, PartialEq, Eq)] @@ -117,13 +137,20 @@ pub enum AccessTokenRequest { Unsupported, } +#[serde_as] #[derive(Serialize, Deserialize, Debug, PartialEq)] pub struct AccessTokenResponse { access_token: String, + token_type: TokenType, - expires_in: Option, + + #[serde_as(as = "Option")] + expires_in: Option, + refresh_token: Option, - scope: Option>, + + #[serde_as(as = "Option>")] + scope: Option>, } #[cfg(test)] @@ -140,14 +167,16 @@ mod tests { let expected = json!({ "grant_type": "refresh_token", "refresh_token": "abcd", - "scope": "openid profile", + "scope": "openid", }); let scope = { let mut s = HashSet::new(); + // TODO: insert multiple scopes and test it. It's a bit tricky to test since + // HashSet have no guarantees regarding the ordering of items, so right + // now the output is unstable. s.insert("openid".to_string()); - s.insert("profile".to_string()); - Some(s.into()) + Some(s) }; let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant { diff --git a/oauth2-types/src/types.rs b/oauth2-types/src/types.rs deleted file mode 100644 index d741e367c..000000000 --- a/oauth2-types/src/types.rs +++ /dev/null @@ -1,149 +0,0 @@ -// Copyright 2021 The Matrix.org Foundation C.I.C. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -//! Utilitary types for serde - -use std::{ - collections::{hash_map::RandomState, HashSet}, - hash::Hash, - time::Duration, -}; - -use serde::{Deserialize, Serialize}; - -/// A `HashSet` that serializes to a space-separated string in alphanumerical -/// order -#[derive(Debug, PartialEq)] -pub struct StringHashSet(HashSet); - -impl From> for StringHashSet { - fn from(set: HashSet) -> Self { - Self(set) - } -} - -impl From> for HashSet { - fn from(set: StringHashSet) -> Self { - set.0 - } -} - -impl Serialize for StringHashSet -where - T: ToString + PartialOrd + Eq + Hash, -{ - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let mut items: Vec<_> = self.0.iter().map(|i| i.to_string()).collect(); - items.sort(); - let s = items.join(" "); - serializer.serialize_str(&s) - } -} - -impl<'de, T> Deserialize<'de> for StringHashSet -where - T: std::str::FromStr + Eq + Hash, - ::Err: std::fmt::Display, -{ - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let s: String = Deserialize::deserialize(deserializer)?; - let items: Result, _> = s.split_ascii_whitespace().map(T::from_str).collect(); - items.map(Into::into).map_err(serde::de::Error::custom) - } -} - -/// A Vec that serializes to a space-separated string -pub struct StringVec(Vec); - -impl From> for StringVec { - fn from(set: Vec) -> Self { - Self(set) - } -} - -impl From> for Vec { - fn from(v: StringVec) -> Self { - v.0 - } -} - -impl Serialize for StringVec -where - T: ToString, -{ - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - let items: Vec<_> = self.0.iter().map(|i| i.to_string()).collect(); - let s = items.join(" "); - serializer.serialize_str(&s) - } -} - -impl<'de, T> Deserialize<'de> for StringVec -where - T: std::str::FromStr + std::hash::Hash + Eq, - ::Err: std::fmt::Display, -{ - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let s: String = Deserialize::deserialize(deserializer)?; - let items: Result, _> = s.split_ascii_whitespace().map(T::from_str).collect(); - items.map(Into::into).map_err(serde::de::Error::custom) - } -} - -/// A Duration that serializes to seconds -#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)] -pub struct Seconds(Duration); - -impl From for Seconds { - fn from(d: Duration) -> Self { - Self(d) - } -} - -impl From for Duration { - fn from(val: Seconds) -> Self { - val.0 - } -} - -impl Serialize for Seconds { - fn serialize(&self, serializer: S) -> Result - where - S: serde::Serializer, - { - self.0.as_secs().serialize(serializer) - } -} - -impl<'de> Deserialize<'de> for Seconds { - fn deserialize(deserializer: D) -> Result - where - D: serde::Deserializer<'de>, - { - let secs = u64::deserialize(deserializer)?; - Ok(Self(Duration::from_secs(secs))) - } -}