diff --git a/crates/handlers/src/compat/login.rs b/crates/handlers/src/compat/login.rs index ff74bf8ef..75834f9b3 100644 --- a/crates/handlers/src/compat/login.rs +++ b/crates/handlers/src/compat/login.rs @@ -386,10 +386,13 @@ async fn user_password_login( username: String, password: String, ) -> Result<(CompatSession, User), RouteError> { + // Try getting the localpart out of the MXID + let username = homeserver.localpart(&username).unwrap_or(&username); + // Find the user let user = repo .user() - .find_by_username(&username) + .find_by_username(username) .await? .filter(mas_data_model::User::is_valid) .ok_or(RouteError::UserNotFound)?; @@ -539,23 +542,25 @@ mod tests { assert_eq!(body["errcode"], "M_UNRECOGNIZED"); } - /// Test that a user can login with a password using the Matrix - /// compatibility API. - #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] - async fn test_user_password_login(pool: PgPool) { - setup(); - let state = TestState::from_pool(pool).await.unwrap(); - - // Let's provision a user and add a password to it. This part is hard to test - // with just HTTP requests, so we'll use the repository directly. + async fn user_with_password(state: &TestState, username: &str, password: &str) { + let mut rng = state.rng(); let mut repo = state.repository().await.unwrap(); let user = repo .user() - .add(&mut state.rng(), &state.clock, "alice".to_owned()) + .add(&mut rng, &state.clock, username.to_owned()) + .await + .unwrap(); + let (version, hash) = state + .password_manager + .hash(&mut rng, Zeroizing::new(password.as_bytes().to_vec())) .await .unwrap(); + repo.user_password() + .add(&mut rng, &state.clock, &user, version, hash, None) + .await + .unwrap(); let mxid = state.homeserver_connection.mxid(&user.username); state .homeserver_connection @@ -563,28 +568,17 @@ mod tests { .await .unwrap(); - let (version, hashed_password) = state - .password_manager - .hash( - &mut state.rng(), - Zeroizing::new("password".to_owned().into_bytes()), - ) - .await - .unwrap(); - - repo.user_password() - .add( - &mut state.rng(), - &state.clock, - &user, - version, - hashed_password, - None, - ) - .await - .unwrap(); - repo.save().await.unwrap(); + } + + /// Test that a user can login with a password using the Matrix + /// compatibility API. + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_user_password_login(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + + user_with_password(&state, "alice", "password").await; // Now let's try to login with the password, without asking for a refresh token. let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ @@ -662,6 +656,50 @@ mod tests { assert_eq!(body, old_body); } + /// Test that a user can login with a password using the Matrix + /// compatibility API, using a MXID as identifier + #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] + async fn test_user_password_login_mxid(pool: PgPool) { + setup(); + let state = TestState::from_pool(pool).await.unwrap(); + + user_with_password(&state, "alice", "password").await; + + // Login with a full MXID as identifier + let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "@alice:example.com", + }, + "password": "password", + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::OK); + let body: ResponseBody = response.json(); + assert!(!body.access_token.is_empty()); + assert_eq!(body.device_id.as_ref().unwrap().as_str().len(), 10); + assert_eq!(body.user_id, "@alice:example.com"); + assert_eq!(body.refresh_token, None); + assert_eq!(body.expires_in_ms, None); + + // With a MXID, but with the wrong server name + let request = Request::post("/_matrix/client/v3/login").json(serde_json::json!({ + "type": "m.login.password", + "identifier": { + "type": "m.id.user", + "user": "@alice:something.corp", + }, + "password": "password", + })); + + let response = state.request(request).await; + response.assert_status(StatusCode::FORBIDDEN); + let body: serde_json::Value = response.json(); + assert_eq!(body["errcode"], "M_FORBIDDEN"); + } + /// Test that password logins are rate limited. #[sqlx::test(migrator = "mas_storage_pg::MIGRATOR")] async fn test_password_login_rate_limit(pool: PgPool) {