Prevent users from using Element FOSS on homeservers that enforce the usage of Element Pro.

This commit is contained in:
Benoit Marty
2025-08-05 16:55:39 +02:00
parent 0b895f631d
commit 2dec34374e
24 changed files with 556 additions and 36 deletions

View File

@@ -36,6 +36,7 @@ dependencies {
implementation(projects.libraries.designsystem)
implementation(projects.libraries.matrixui)
implementation(projects.libraries.uiStrings)
implementation(projects.features.login.api)
implementation(libs.coil)

View File

@@ -33,8 +33,8 @@ import io.element.android.appnav.intent.ResolvedIntent
import io.element.android.appnav.root.RootNavStateFlowFactory
import io.element.android.appnav.root.RootPresenter
import io.element.android.appnav.root.RootView
import io.element.android.features.enterprise.api.EnterpriseService
import io.element.android.features.login.api.LoginParams
import io.element.android.features.login.api.accesscontrol.AccountProviderAccessControl
import io.element.android.features.rageshake.api.bugreport.BugReportEntryPoint
import io.element.android.features.signedout.api.SignedOutEntryPoint
import io.element.android.features.viewfolder.api.ViewFolderEntryPoint
@@ -64,7 +64,7 @@ class RootFlowNode @AssistedInject constructor(
@Assisted val buildContext: BuildContext,
@Assisted plugins: List<Plugin>,
private val authenticationService: MatrixAuthenticationService,
private val enterpriseService: EnterpriseService,
private val accountProviderAccessControl: AccountProviderAccessControl,
private val navStateFlowFactory: RootNavStateFlowFactory,
private val matrixSessionCache: MatrixSessionCache,
private val presenter: RootPresenter,
@@ -293,7 +293,7 @@ class RootFlowNode @AssistedInject constructor(
val latestSessionId = authenticationService.getLatestSessionId()
if (latestSessionId == null) {
// No session, open login
if (enterpriseService.isAllowedToConnectToHomeserver(params.accountProvider.ensureProtocol())) {
if (accountProviderAccessControl.isAllowedToConnectToAccountProvider(params.accountProvider.ensureProtocol())) {
switchToNotLoggedInFlow(params)
} else {
Timber.w("Login link ignored, we are not allowed to connect to the homeserver")

View File

@@ -0,0 +1,12 @@
/*
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.features.login.api.accesscontrol
interface AccountProviderAccessControl {
suspend fun isAllowedToConnectToAccountProvider(accountProviderUrl: String): Boolean
}

View File

@@ -0,0 +1,61 @@
/*
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.features.login.impl.accesscontrol
import com.squareup.anvil.annotations.ContributesBinding
import io.element.android.features.enterprise.api.EnterpriseService
import io.element.android.features.login.api.accesscontrol.AccountProviderAccessControl
import io.element.android.features.login.impl.changeserver.AccountProviderAccessException
import io.element.android.libraries.core.uri.ensureProtocol
import io.element.android.libraries.di.AppScope
import javax.inject.Inject
@ContributesBinding(AppScope::class)
class DefaultAccountProviderAccessControl @Inject constructor(
private val enterpriseService: EnterpriseService,
private val elementWellknownRetriever: ElementWellknownRetriever,
) : AccountProviderAccessControl {
override suspend fun isAllowedToConnectToAccountProvider(accountProviderUrl: String) = try {
assertIsAllowedToConnectToAccountProvider(
title = accountProviderUrl,
accountProviderUrl = accountProviderUrl,
)
true
} catch (_: AccountProviderAccessException) {
false
}
@Throws(AccountProviderAccessException::class)
suspend fun assertIsAllowedToConnectToAccountProvider(
title: String,
accountProviderUrl: String,
) {
if (enterpriseService.isEnterpriseBuild.not()) {
// Ensure that Element Pro is not required for this account provider
val wellKnown = elementWellknownRetriever.retrieve(
accountProviderUrl = accountProviderUrl.ensureProtocol(),
)
if (wellKnown?.enforceElementPro == true) {
throw AccountProviderAccessException.NeedElementProException(
unauthorisedAccountProviderTitle = title,
applicationId = ELEMENT_PRO_APPLICATION_ID,
)
}
}
if (enterpriseService.isAllowedToConnectToHomeserver(accountProviderUrl).not()) {
throw AccountProviderAccessException.UnauthorizedAccountProviderException(
unauthorisedAccountProviderTitle = title,
authorisedAccountProviderTitles = enterpriseService.defaultHomeserverList(),
)
}
}
companion object {
const val ELEMENT_PRO_APPLICATION_ID = "io.element.enterprise"
}
}

View File

@@ -0,0 +1,42 @@
/*
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.features.login.impl.accesscontrol
import com.squareup.anvil.annotations.ContributesBinding
import io.element.android.features.login.impl.resolver.network.ElementWellKnown
import io.element.android.features.login.impl.resolver.network.WellknownAPI
import io.element.android.libraries.di.AppScope
import io.element.android.libraries.network.RetrofitFactory
import timber.log.Timber
import javax.inject.Inject
interface ElementWellknownRetriever {
suspend fun retrieve(accountProviderUrl: String): ElementWellKnown?
}
@ContributesBinding(AppScope::class)
class DefaultElementWellknownRetriever @Inject constructor(
private val retrofitFactory: RetrofitFactory,
) : ElementWellknownRetriever {
override suspend fun retrieve(accountProviderUrl: String): ElementWellKnown? {
val wellknownApi = try {
retrofitFactory.create(accountProviderUrl)
.create(WellknownAPI::class.java)
} catch (e: Exception) {
// If the base URL is not valid, we cannot retrieve the well-known data
Timber.e(e, "Failed to create Retrofit instance for $accountProviderUrl")
return null
}
return try {
wellknownApi.getElementWellKnown()
} catch (e: Exception) {
Timber.e(e, "Failed to retrieve Element well-known data for $accountProviderUrl")
null
}
}
}

View File

@@ -12,7 +12,7 @@ import androidx.compose.runtime.MutableState
import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import io.element.android.features.enterprise.api.EnterpriseService
import io.element.android.features.login.impl.accesscontrol.DefaultAccountProviderAccessControl
import io.element.android.features.login.impl.accountprovider.AccountProvider
import io.element.android.features.login.impl.accountprovider.AccountProviderDataSource
import io.element.android.features.login.impl.error.ChangeServerError
@@ -27,7 +27,7 @@ import javax.inject.Inject
class ChangeServerPresenter @Inject constructor(
private val authenticationService: MatrixAuthenticationService,
private val accountProviderDataSource: AccountProviderDataSource,
private val enterpriseService: EnterpriseService,
private val defaultAccountProviderAccessControl: DefaultAccountProviderAccessControl,
) : Presenter<ChangeServerState> {
@Composable
override fun present(): ChangeServerState {
@@ -55,12 +55,10 @@ class ChangeServerPresenter @Inject constructor(
changeServerAction: MutableState<AsyncData<Unit>>,
) = launch {
suspend {
if (enterpriseService.isAllowedToConnectToHomeserver(data.url).not()) {
throw UnauthorizedAccountProviderException(
unauthorisedAccountProviderTitle = data.title,
authorisedAccountProviderTitles = enterpriseService.defaultHomeserverList(),
)
}
defaultAccountProviderAccessControl.assertIsAllowedToConnectToAccountProvider(
title = data.title,
accountProviderUrl = data.url,
)
authenticationService.setHomeserver(data.url).map {
authenticationService.getHomeserverDetails().value!!
// Valid, remember user choice

View File

@@ -26,6 +26,14 @@ open class ChangeServerStateProvider : PreviewParameterProvider<ChangeServerStat
)
)
),
aChangeServerState(
changeServerAction = AsyncData.Failure(
ChangeServerError.NeedElementPro(
unauthorisedAccountProviderTitle = "example.com",
applicationId = "applicationId",
),
)
),
)
}

View File

@@ -12,13 +12,16 @@ import androidx.compose.runtime.LaunchedEffect
import androidx.compose.runtime.getValue
import androidx.compose.runtime.rememberUpdatedState
import androidx.compose.ui.Modifier
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource
import androidx.compose.ui.tooling.preview.PreviewParameter
import io.element.android.features.login.impl.R
import io.element.android.features.login.impl.dialogs.SlidingSyncNotSupportedDialog
import io.element.android.features.login.impl.error.ChangeServerError
import io.element.android.libraries.androidutils.system.openGooglePlay
import io.element.android.libraries.architecture.AsyncData
import io.element.android.libraries.designsystem.components.ProgressDialog
import io.element.android.libraries.designsystem.components.dialogs.ConfirmationDialog
import io.element.android.libraries.designsystem.components.dialogs.ErrorDialog
import io.element.android.libraries.designsystem.preview.ElementPreview
import io.element.android.libraries.designsystem.preview.PreviewsDayNight
@@ -31,6 +34,7 @@ fun ChangeServerView(
onSuccess: () -> Unit,
modifier: Modifier = Modifier,
) {
val context = LocalContext.current
val eventSink = state.eventSink
when (state.changeServerAction) {
is AsyncData.Failure -> {
@@ -56,6 +60,24 @@ fun ChangeServerView(
}
)
}
is ChangeServerError.NeedElementPro -> {
ConfirmationDialog(
modifier = modifier,
title = stringResource(R.string.screen_change_server_error_element_pro_required_title),
content = stringResource(
R.string.screen_change_server_error_element_pro_required_message,
error.unauthorisedAccountProviderTitle,
),
submitText = stringResource(R.string.screen_change_server_error_element_pro_required_action_android),
onSubmitClick = {
context.openGooglePlay(error.applicationId)
eventSink.invoke(ChangeServerEvents.ClearError)
},
onDismiss = {
eventSink.invoke(ChangeServerEvents.ClearError)
},
)
}
is ChangeServerError.UnauthorizedAccountProvider -> {
ErrorDialog(
modifier = modifier,

View File

@@ -7,7 +7,14 @@
package io.element.android.features.login.impl.changeserver
class UnauthorizedAccountProviderException(
val unauthorisedAccountProviderTitle: String,
val authorisedAccountProviderTitles: List<String>,
) : Exception()
sealed class AccountProviderAccessException : Exception() {
data class NeedElementProException(
val unauthorisedAccountProviderTitle: String,
val applicationId: String,
) : AccountProviderAccessException()
data class UnauthorizedAccountProviderException(
val unauthorisedAccountProviderTitle: String,
val authorisedAccountProviderTitles: List<String>,
) : AccountProviderAccessException()
}

View File

@@ -12,11 +12,11 @@ import androidx.compose.runtime.Composable
import androidx.compose.runtime.ReadOnlyComposable
import androidx.compose.ui.res.stringResource
import io.element.android.features.login.impl.R
import io.element.android.features.login.impl.changeserver.UnauthorizedAccountProviderException
import io.element.android.features.login.impl.changeserver.AccountProviderAccessException
import io.element.android.libraries.matrix.api.auth.AuthenticationException
import io.element.android.libraries.ui.strings.CommonStrings
sealed class ChangeServerError : Throwable() {
sealed class ChangeServerError : Exception() {
data class Error(
@StringRes val messageId: Int? = null,
val messageStr: String? = null,
@@ -26,6 +26,11 @@ sealed class ChangeServerError : Throwable() {
fun message(): String = messageStr ?: stringResource(messageId ?: CommonStrings.error_unknown)
}
data class NeedElementPro(
val unauthorisedAccountProviderTitle: String,
val applicationId: String,
) : ChangeServerError()
data class UnauthorizedAccountProvider(
val unauthorisedAccountProviderTitle: String,
val authorisedAccountProviderTitles: List<String>,
@@ -37,7 +42,11 @@ sealed class ChangeServerError : Throwable() {
fun from(error: Throwable): ChangeServerError = when (error) {
is AuthenticationException.SlidingSyncVersion -> SlidingSyncAlert
is AuthenticationException.Oidc -> Error(messageStr = error.message)
is UnauthorizedAccountProviderException -> UnauthorizedAccountProvider(
is AccountProviderAccessException.NeedElementProException -> NeedElementPro(
unauthorisedAccountProviderTitle = error.unauthorisedAccountProviderTitle,
applicationId = error.applicationId,
)
is AccountProviderAccessException.UnauthorizedAccountProviderException -> UnauthorizedAccountProvider(
unauthorisedAccountProviderTitle = error.unauthorisedAccountProviderTitle,
authorisedAccountProviderTitles = error.authorisedAccountProviderTitles,
)

View File

@@ -8,12 +8,15 @@
package io.element.android.features.login.impl.login
import androidx.compose.runtime.Composable
import androidx.compose.ui.platform.LocalContext
import androidx.compose.ui.res.stringResource
import io.element.android.features.login.impl.R
import io.element.android.features.login.impl.dialogs.SlidingSyncNotSupportedDialog
import io.element.android.features.login.impl.error.ChangeServerError
import io.element.android.features.login.impl.screens.createaccount.AccountCreationNotSupported
import io.element.android.libraries.androidutils.system.openGooglePlay
import io.element.android.libraries.architecture.AsyncData
import io.element.android.libraries.designsystem.components.dialogs.ConfirmationDialog
import io.element.android.libraries.designsystem.components.dialogs.ErrorDialog
import io.element.android.libraries.designsystem.theme.LocalBuildMeta
import io.element.android.libraries.matrix.api.auth.OidcDetails
@@ -28,6 +31,7 @@ fun LoginModeView(
onNeedLoginPassword: () -> Unit,
onCreateAccountContinue: (url: String) -> Unit
) {
val context = LocalContext.current
when (loginMode) {
is AsyncData.Failure -> {
when (val error = loginMode.error) {
@@ -48,6 +52,21 @@ fun LoginModeView(
onDismiss = onClearError,
)
}
is ChangeServerError.NeedElementPro -> {
ConfirmationDialog(
title = stringResource(R.string.screen_change_server_error_element_pro_required_title),
content = stringResource(
R.string.screen_change_server_error_element_pro_required_message,
error.unauthorisedAccountProviderTitle,
),
submitText = stringResource(R.string.screen_change_server_error_element_pro_required_action_android),
onSubmitClick = {
context.openGooglePlay(error.applicationId)
onClearError()
},
onDismiss = onClearError,
)
}
is ChangeServerError.UnauthorizedAccountProvider -> {
ErrorDialog(
content = stringResource(

View File

@@ -23,4 +23,7 @@ import kotlinx.serialization.Serializable
data class ElementWellKnown(
@SerialName("registration_helper_url")
val registrationHelperUrl: String? = null,
@SerialName("enforce_element_pro")
val enforceElementPro: Boolean? = null,
)

View File

@@ -21,6 +21,7 @@ import dagger.assisted.AssistedInject
import io.element.android.appconfig.OnBoardingConfig
import io.element.android.features.enterprise.api.EnterpriseService
import io.element.android.features.enterprise.api.canConnectToAnyHomeserver
import io.element.android.features.login.impl.accesscontrol.DefaultAccountProviderAccessControl
import io.element.android.features.login.impl.login.LoginHelper
import io.element.android.features.rageshake.api.RageshakeFeatureAvailability
import io.element.android.libraries.architecture.Presenter
@@ -34,6 +35,7 @@ class OnBoardingPresenter @AssistedInject constructor(
private val buildMeta: BuildMeta,
private val featureFlagService: FeatureFlagService,
private val enterpriseService: EnterpriseService,
private val defaultAccountProviderAccessControl: DefaultAccountProviderAccessControl,
private val rageshakeFeatureAvailability: RageshakeFeatureAvailability,
private val loginHelper: LoginHelper,
) : Presenter<OnBoardingState> {
@@ -63,7 +65,12 @@ class OnBoardingPresenter @AssistedInject constructor(
val linkAccountProvider by produceState<String?>(initialValue = null) {
// Account provider from the link, if allowed by the enterprise service
value = params.accountProvider?.takeIf {
enterpriseService.isAllowedToConnectToHomeserver(it)
try {
defaultAccountProviderAccessControl.assertIsAllowedToConnectToAccountProvider(it, it)
true
} catch (_: Exception) {
false
}
}
}
val defaultAccountProvider = remember(linkAccountProvider) {

View File

@@ -15,8 +15,7 @@ import androidx.compose.runtime.mutableStateOf
import androidx.compose.runtime.remember
import androidx.compose.runtime.rememberCoroutineScope
import androidx.compose.runtime.setValue
import io.element.android.features.enterprise.api.EnterpriseService
import io.element.android.features.login.impl.changeserver.UnauthorizedAccountProviderException
import io.element.android.features.login.impl.accesscontrol.DefaultAccountProviderAccessControl
import io.element.android.features.login.impl.qrcode.QrCodeLoginManager
import io.element.android.libraries.architecture.AsyncAction
import io.element.android.libraries.architecture.Presenter
@@ -38,7 +37,7 @@ class QrCodeScanPresenter @Inject constructor(
private val qrCodeLoginDataFactory: MatrixQrCodeLoginDataFactory,
private val qrCodeLoginManager: QrCodeLoginManager,
private val coroutineDispatchers: CoroutineDispatchers,
private val enterpriseService: EnterpriseService,
private val defaultAccountProviderAccessControl: DefaultAccountProviderAccessControl,
) : Presenter<QrCodeScanState> {
private var isScanning by mutableStateOf(true)
@@ -97,10 +96,10 @@ class QrCodeScanPresenter @Inject constructor(
Timber.e(it, "Error parsing QR code data")
}.getOrThrow()
val serverName = data.serverName()
if (serverName != null && enterpriseService.isAllowedToConnectToHomeserver(serverName).not()) {
throw UnauthorizedAccountProviderException(
unauthorisedAccountProviderTitle = serverName,
authorisedAccountProviderTitles = enterpriseService.defaultHomeserverList(),
if (serverName != null) {
defaultAccountProviderAccessControl.assertIsAllowedToConnectToAccountProvider(
title = serverName,
accountProviderUrl = serverName,
)
}
data

View File

@@ -8,7 +8,7 @@
package io.element.android.features.login.impl.screens.qrcode.scan
import androidx.compose.ui.tooling.preview.PreviewParameterProvider
import io.element.android.features.login.impl.changeserver.UnauthorizedAccountProviderException
import io.element.android.features.login.impl.changeserver.AccountProviderAccessException
import io.element.android.libraries.architecture.AsyncAction
import io.element.android.libraries.matrix.api.auth.qrlogin.MatrixQrCodeLoginData
import io.element.android.libraries.matrix.api.auth.qrlogin.QrLoginException
@@ -23,12 +23,21 @@ open class QrCodeScanStateProvider : PreviewParameterProvider<QrCodeScanState> {
aQrCodeScanState(
isScanning = false,
authenticationAction = AsyncAction.Failure(
UnauthorizedAccountProviderException(
AccountProviderAccessException.UnauthorizedAccountProviderException(
unauthorisedAccountProviderTitle = "example.com",
authorisedAccountProviderTitles = listOf("element.io", "element.org"),
)
)
),
aQrCodeScanState(
isScanning = false,
authenticationAction = AsyncAction.Failure(
AccountProviderAccessException.NeedElementProException(
unauthorisedAccountProviderTitle = "example.com",
applicationId = "applicationId"
)
)
),
// Add other state here
)
}

View File

@@ -35,7 +35,7 @@ import androidx.compose.ui.unit.dp
import io.element.android.compound.theme.ElementTheme
import io.element.android.compound.tokens.generated.CompoundIcons
import io.element.android.features.login.impl.R
import io.element.android.features.login.impl.changeserver.UnauthorizedAccountProviderException
import io.element.android.features.login.impl.changeserver.AccountProviderAccessException
import io.element.android.libraries.architecture.AsyncAction
import io.element.android.libraries.designsystem.atomic.pages.FlowStepPage
import io.element.android.libraries.designsystem.components.BigIcon
@@ -145,7 +145,10 @@ private fun ColumnScope.Buttons(
Spacer(modifier = Modifier.width(4.dp))
Text(
text = when (error) {
is UnauthorizedAccountProviderException -> {
is AccountProviderAccessException.NeedElementProException -> {
stringResource(R.string.screen_change_server_error_element_pro_required_title)
}
is AccountProviderAccessException.UnauthorizedAccountProviderException -> {
stringResource(
id = R.string.screen_change_server_error_unauthorized_homeserver_title,
error.unauthorisedAccountProviderTitle,
@@ -163,7 +166,13 @@ private fun ColumnScope.Buttons(
}
Text(
text = when (error) {
is UnauthorizedAccountProviderException -> {
is AccountProviderAccessException.NeedElementProException -> {
stringResource(
R.string.screen_change_server_error_element_pro_required_message,
error.unauthorisedAccountProviderTitle,
)
}
is AccountProviderAccessException.UnauthorizedAccountProviderException -> {
stringResource(
id = R.string.screen_change_server_error_unauthorized_homeserver_content,
error.authorisedAccountProviderTitles.joinToString(),

View File

@@ -13,6 +13,7 @@
<string name="screen_change_account_provider_other">"Other"</string>
<string name="screen_change_account_provider_subtitle">"Use a different account provider, such as your own private server or a work account."</string>
<string name="screen_change_account_provider_title">"Change account provider"</string>
<string name="screen_change_server_error_element_pro_required_action_android">"Google Play"</string>
<string name="screen_change_server_error_element_pro_required_message">"The Element Pro app is required on %1$s. Please download it from the store."</string>
<string name="screen_change_server_error_element_pro_required_title">"Element Pro required"</string>
<string name="screen_change_server_error_invalid_homeserver">"We couldn\'t reach this homeserver. Please check that you have entered the homeserver URL correctly. If the URL is correct, contact your homeserver administrator for further help."</string>

View File

@@ -0,0 +1,214 @@
/*
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.features.login.impl.accesscontrol
import com.google.common.truth.Truth.assertThat
import io.element.android.features.enterprise.test.FakeEnterpriseService
import io.element.android.features.login.impl.changeserver.AccountProviderAccessException
import io.element.android.features.login.impl.resolver.network.ElementWellKnown
import io.element.android.libraries.matrix.test.AN_ACCOUNT_PROVIDER
import io.element.android.libraries.matrix.test.AN_ACCOUNT_PROVIDER_2
import io.element.android.libraries.matrix.test.AN_ACCOUNT_PROVIDER_URL
import kotlinx.coroutines.test.runTest
import org.junit.Assert.assertThrows
import org.junit.Test
class DefaultAccountProviderAccessControlTest {
@Test
fun `foss build should not allow using account provider that enforce enterprise build`() {
val accessControl = createDefaultAccountProviderAccessControl(
isEnterpriseBuild = false,
isAllowedToConnectToHomeserver = true,
elementWellKnown = ElementWellKnown(
enforceElementPro = true,
),
)
accessControl.expectNeedElementProException()
}
@Test
fun `foss build should not allow using account provider that enforce enterprise build taking precedence over authorization`() {
val accessControl = createDefaultAccountProviderAccessControl(
isEnterpriseBuild = false,
// false here.
isAllowedToConnectToHomeserver = false,
elementWellKnown = ElementWellKnown(
enforceElementPro = true,
),
)
accessControl.expectNeedElementProException()
}
@Test
fun `foss build should allow using account provider that does not enforce enterprise build`() = runTest {
val accessControl = createDefaultAccountProviderAccessControl(
isEnterpriseBuild = false,
isAllowedToConnectToHomeserver = true,
elementWellKnown = ElementWellKnown(
enforceElementPro = false,
),
)
accessControl.expectAllowed()
}
@Test
fun `foss build should allow using account provider twith missing key in wellknown`() = runTest {
val accessControl = createDefaultAccountProviderAccessControl(
isEnterpriseBuild = false,
isAllowedToConnectToHomeserver = true,
elementWellKnown = ElementWellKnown(
enforceElementPro = null,
),
)
accessControl.expectAllowed()
}
@Test
fun `foss build should allow using account provider twith missing wellknown`() = runTest {
val accessControl = createDefaultAccountProviderAccessControl(
isEnterpriseBuild = false,
isAllowedToConnectToHomeserver = true,
elementWellKnown = null,
)
accessControl.expectAllowed()
}
@Test
fun `foss build should not allow using account provider that do not enforce enterprise build but is not allowed`() {
val accessControl = createDefaultAccountProviderAccessControl(
isEnterpriseBuild = false,
isAllowedToConnectToHomeserver = false,
allowedAccountProviders = listOf(AN_ACCOUNT_PROVIDER_2),
elementWellKnown = ElementWellKnown(
enforceElementPro = false,
),
)
accessControl.expectUnauthorizedAccountProviderException()
}
@Test
fun `enterprise build should allow using account provider that enforce enterprise build`() = runTest {
val accessControl = createDefaultAccountProviderAccessControl(
isEnterpriseBuild = true,
isAllowedToConnectToHomeserver = true,
elementWellKnown = ElementWellKnown(
enforceElementPro = true,
),
)
accessControl.expectAllowed()
}
@Test
fun `enterprise build should allow using account provider that do not enforce enterprise build`() = runTest {
val accessControl = createDefaultAccountProviderAccessControl(
isEnterpriseBuild = true,
isAllowedToConnectToHomeserver = true,
elementWellKnown = ElementWellKnown(
enforceElementPro = false,
),
)
accessControl.expectAllowed()
}
@Test
fun `enterprise build should not allow using account provider that enforce enterprise build but is not allowed`() = runTest {
val accessControl = createDefaultAccountProviderAccessControl(
isEnterpriseBuild = true,
isAllowedToConnectToHomeserver = false,
allowedAccountProviders = listOf(AN_ACCOUNT_PROVIDER_2),
elementWellKnown = ElementWellKnown(
enforceElementPro = true,
),
)
accessControl.expectUnauthorizedAccountProviderException()
}
@Test
fun `enterprise build should not allow using account provider that do not enforce enterprise build but is not allowed`() = runTest {
val accessControl = createDefaultAccountProviderAccessControl(
isEnterpriseBuild = true,
isAllowedToConnectToHomeserver = false,
allowedAccountProviders = listOf(AN_ACCOUNT_PROVIDER_2),
elementWellKnown = ElementWellKnown(
enforceElementPro = false,
),
)
accessControl.expectUnauthorizedAccountProviderException()
}
private fun createDefaultAccountProviderAccessControl(
isEnterpriseBuild: Boolean = false,
isAllowedToConnectToHomeserver: Boolean = false,
allowedAccountProviders: List<String> = emptyList(),
elementWellKnown: ElementWellKnown? = null,
) = DefaultAccountProviderAccessControl(
enterpriseService = FakeEnterpriseService(
isEnterpriseBuild = isEnterpriseBuild,
isAllowedToConnectToHomeserverResult = { isAllowedToConnectToHomeserver },
defaultHomeserverListResult = { allowedAccountProviders },
),
elementWellknownRetriever = FakeElementWellknownRetriever(
retrieveResult = { elementWellKnown }
),
)
private fun DefaultAccountProviderAccessControl.expectNeedElementProException() {
val exception = assertThrows(AccountProviderAccessException.NeedElementProException::class.java) {
runTest {
assertIsAllowedToConnectToAccountProvider(
title = AN_ACCOUNT_PROVIDER,
accountProviderUrl = AN_ACCOUNT_PROVIDER_URL,
)
}
}
assertThat(exception.unauthorisedAccountProviderTitle).isEqualTo(AN_ACCOUNT_PROVIDER)
assertThat(exception.applicationId).isEqualTo("io.element.enterprise")
runTest {
assertThat(
isAllowedToConnectToAccountProvider(
accountProviderUrl = AN_ACCOUNT_PROVIDER_URL,
)
).isFalse()
}
}
private fun DefaultAccountProviderAccessControl.expectUnauthorizedAccountProviderException() {
val exception = assertThrows(AccountProviderAccessException.UnauthorizedAccountProviderException::class.java) {
runTest {
assertIsAllowedToConnectToAccountProvider(
title = AN_ACCOUNT_PROVIDER,
accountProviderUrl = AN_ACCOUNT_PROVIDER_URL,
)
}
}
assertThat(exception.unauthorisedAccountProviderTitle).isEqualTo(AN_ACCOUNT_PROVIDER)
assertThat(exception.authorisedAccountProviderTitles).containsExactly(AN_ACCOUNT_PROVIDER_2)
runTest {
assertThat(
isAllowedToConnectToAccountProvider(
accountProviderUrl = AN_ACCOUNT_PROVIDER_URL,
)
).isFalse()
}
}
private suspend fun DefaultAccountProviderAccessControl.expectAllowed() {
// If no exception is thrown, the test passes
assertIsAllowedToConnectToAccountProvider(
title = AN_ACCOUNT_PROVIDER,
accountProviderUrl = AN_ACCOUNT_PROVIDER_URL,
)
runTest {
assertThat(
isAllowedToConnectToAccountProvider(
accountProviderUrl = AN_ACCOUNT_PROVIDER_URL,
)
).isTrue()
}
}
}

View File

@@ -0,0 +1,19 @@
/*
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.features.login.impl.accesscontrol
import io.element.android.features.login.impl.resolver.network.ElementWellKnown
import io.element.android.tests.testutils.simulateLongTask
class FakeElementWellknownRetriever(
private val retrieveResult: (String) -> ElementWellKnown? = { null },
) : ElementWellknownRetriever {
override suspend fun retrieve(accountProviderUrl: String): ElementWellKnown? = simulateLongTask {
retrieveResult(accountProviderUrl)
}
}

View File

@@ -10,10 +10,15 @@ package io.element.android.features.login.impl.changeserver
import com.google.common.truth.Truth.assertThat
import io.element.android.features.enterprise.api.EnterpriseService
import io.element.android.features.enterprise.test.FakeEnterpriseService
import io.element.android.features.login.impl.accesscontrol.DefaultAccountProviderAccessControl
import io.element.android.features.login.impl.accesscontrol.ElementWellknownRetriever
import io.element.android.features.login.impl.accesscontrol.FakeElementWellknownRetriever
import io.element.android.features.login.impl.accountprovider.AccountProvider
import io.element.android.features.login.impl.accountprovider.AccountProviderDataSource
import io.element.android.features.login.impl.error.ChangeServerError
import io.element.android.features.login.impl.resolver.network.ElementWellKnown
import io.element.android.libraries.architecture.AsyncData
import io.element.android.libraries.core.uri.ensureProtocol
import io.element.android.libraries.matrix.test.A_HOMESERVER
import io.element.android.libraries.matrix.test.A_HOMESERVER_URL
import io.element.android.libraries.matrix.test.auth.FakeMatrixAuthenticationService
@@ -106,13 +111,48 @@ class ChangeServerPresenterTest {
}
}
@Test
fun `present - change server element pro required error`() = runTest {
val retrieveResult = lambdaRecorder<String, ElementWellKnown> {
ElementWellKnown(
enforceElementPro = true,
)
}
createPresenter(
elementWellknownRetriever = FakeElementWellknownRetriever(
retrieveResult = retrieveResult,
),
).test {
val initialState = awaitItem()
assertThat(initialState.changeServerAction).isEqualTo(AsyncData.Uninitialized)
val anAccountProvider = AccountProvider(url = A_HOMESERVER_URL)
initialState.eventSink.invoke(ChangeServerEvents.ChangeServer(anAccountProvider))
val loadingState = awaitItem()
assertThat(loadingState.changeServerAction).isInstanceOf(AsyncData.Loading::class.java)
val failureState = awaitItem()
assertThat(
(failureState.changeServerAction.errorOrNull() as ChangeServerError.NeedElementPro).unauthorisedAccountProviderTitle
).isEqualTo(anAccountProvider.title)
assertThat(
(failureState.changeServerAction.errorOrNull() as ChangeServerError.NeedElementPro).applicationId
).isEqualTo("io.element.enterprise")
retrieveResult.assertions()
.isCalledOnce()
.with(value(A_HOMESERVER_URL.ensureProtocol()))
}
}
private fun createPresenter(
authenticationService: FakeMatrixAuthenticationService = FakeMatrixAuthenticationService(),
accountProviderDataSource: AccountProviderDataSource = AccountProviderDataSource(FakeEnterpriseService()),
enterpriseService: EnterpriseService = FakeEnterpriseService(),
elementWellknownRetriever: ElementWellknownRetriever = FakeElementWellknownRetriever(),
) = ChangeServerPresenter(
authenticationService = authenticationService,
accountProviderDataSource = accountProviderDataSource,
enterpriseService = enterpriseService,
defaultAccountProviderAccessControl = DefaultAccountProviderAccessControl(
enterpriseService = enterpriseService,
elementWellknownRetriever = elementWellknownRetriever,
),
)
}

View File

@@ -12,6 +12,9 @@ import io.element.android.appconfig.OnBoardingConfig
import io.element.android.features.enterprise.api.EnterpriseService
import io.element.android.features.enterprise.test.FakeEnterpriseService
import io.element.android.features.login.impl.DefaultLoginUserStory
import io.element.android.features.login.impl.accesscontrol.DefaultAccountProviderAccessControl
import io.element.android.features.login.impl.accesscontrol.ElementWellknownRetriever
import io.element.android.features.login.impl.accesscontrol.FakeElementWellknownRetriever
import io.element.android.features.login.impl.login.LoginHelper
import io.element.android.features.login.impl.web.FakeWebClientUrlForAuthenticationRetriever
import io.element.android.features.login.impl.web.WebClientUrlForAuthenticationRetriever
@@ -235,6 +238,7 @@ private fun createPresenter(
buildMeta: BuildMeta = aBuildMeta(),
featureFlagService: FeatureFlagService = FakeFeatureFlagService(),
enterpriseService: EnterpriseService = FakeEnterpriseService(),
elementWellknownRetriever: ElementWellknownRetriever = FakeElementWellknownRetriever(),
rageshakeFeatureAvailability: () -> Boolean = { true },
loginHelper: LoginHelper = createLoginHelper(),
) = OnBoardingPresenter(
@@ -242,6 +246,10 @@ private fun createPresenter(
buildMeta = buildMeta,
featureFlagService = featureFlagService,
enterpriseService = enterpriseService,
defaultAccountProviderAccessControl = DefaultAccountProviderAccessControl(
enterpriseService = enterpriseService,
elementWellknownRetriever = elementWellknownRetriever,
),
rageshakeFeatureAvailability = rageshakeFeatureAvailability,
loginHelper = loginHelper,
)

View File

@@ -13,7 +13,10 @@ import app.cash.turbine.test
import com.google.common.truth.Truth.assertThat
import io.element.android.features.enterprise.api.EnterpriseService
import io.element.android.features.enterprise.test.FakeEnterpriseService
import io.element.android.features.login.impl.changeserver.UnauthorizedAccountProviderException
import io.element.android.features.login.impl.accesscontrol.DefaultAccountProviderAccessControl
import io.element.android.features.login.impl.accesscontrol.ElementWellknownRetriever
import io.element.android.features.login.impl.accesscontrol.FakeElementWellknownRetriever
import io.element.android.features.login.impl.changeserver.AccountProviderAccessException
import io.element.android.features.login.impl.qrcode.FakeQrCodeLoginManager
import io.element.android.libraries.core.coroutine.CoroutineDispatchers
import io.element.android.libraries.matrix.api.auth.qrlogin.QrCodeLoginStep
@@ -91,9 +94,15 @@ class QrCodeScanPresenterTest {
assertThat(awaitItem().isScanning).isFalse()
assertThat(awaitItem().authenticationAction.isLoading()).isTrue()
awaitItem().also { state ->
assertThat((state.authenticationAction.errorOrNull() as UnauthorizedAccountProviderException).unauthorisedAccountProviderTitle)
assertThat(
(state.authenticationAction
.errorOrNull() as AccountProviderAccessException.UnauthorizedAccountProviderException).unauthorisedAccountProviderTitle
)
.isEqualTo("example.com")
assertThat((state.authenticationAction.errorOrNull() as UnauthorizedAccountProviderException).authorisedAccountProviderTitles)
assertThat(
(state.authenticationAction
.errorOrNull() as AccountProviderAccessException.UnauthorizedAccountProviderException).authorisedAccountProviderTitles
)
.containsExactly("element.io")
}
}
@@ -153,10 +162,14 @@ class QrCodeScanPresenterTest {
coroutineDispatchers: CoroutineDispatchers = testCoroutineDispatchers(),
qrCodeLoginManager: FakeQrCodeLoginManager = FakeQrCodeLoginManager(),
enterpriseService: EnterpriseService = FakeEnterpriseService(),
elementWellknownRetriever: ElementWellknownRetriever = FakeElementWellknownRetriever(),
) = QrCodeScanPresenter(
qrCodeLoginDataFactory = qrCodeLoginDataFactory,
qrCodeLoginManager = qrCodeLoginManager,
coroutineDispatchers = coroutineDispatchers,
enterpriseService = enterpriseService,
defaultAccountProviderAccessControl = DefaultAccountProviderAccessControl(
enterpriseService = enterpriseService,
elementWellknownRetriever = elementWellknownRetriever,
),
)
}

View File

@@ -165,6 +165,7 @@ fun Context.startSharePlainTextIntent(
fun Context.openUrlInExternalApp(
url: String,
errorMessage: String = getString(R.string.error_no_compatible_app_found),
throwInCaseOfError: Boolean = false,
) {
val intent = Intent(Intent.ACTION_VIEW, url.toUri())
if (this !is Activity) {
@@ -173,10 +174,27 @@ fun Context.openUrlInExternalApp(
try {
startActivity(intent)
} catch (activityNotFoundException: ActivityNotFoundException) {
if (throwInCaseOfError) throw activityNotFoundException
toast(errorMessage)
}
}
/**
* Open Google Play on the provided application Id.
*/
fun Context.openGooglePlay(
appId: String,
) {
try {
openUrlInExternalApp(
url = "market://details?id=$appId",
throwInCaseOfError = true,
)
} catch (_: ActivityNotFoundException) {
openUrlInExternalApp("https://play.google.com/store/apps/details?id=$appId")
}
}
// Not in KTX anymore
fun Context.toast(resId: Int) {
Toast.makeText(this, resId, Toast.LENGTH_SHORT).show()

View File

@@ -69,6 +69,7 @@ const val A_REDACTION_REASON = "A redaction reason"
const val A_HOMESERVER_URL = "matrix.org"
const val A_HOMESERVER_URL_2 = "matrix-client.org"
const val AN_ACCOUNT_PROVIDER_URL = "https://account.provider.org"
const val AN_ACCOUNT_PROVIDER = "matrix.org"
const val AN_ACCOUNT_PROVIDER_2 = "element.io"
const val AN_ACCOUNT_PROVIDER_3 = "other.io"