diff --git a/libraries/pushproviders/unifiedpush/src/main/kotlin/io/element/android/libraries/pushproviders/unifiedpush/UnifiedPushGatewayResolver.kt b/libraries/pushproviders/unifiedpush/src/main/kotlin/io/element/android/libraries/pushproviders/unifiedpush/UnifiedPushGatewayResolver.kt index 26b900eb87..51f9703fe0 100644 --- a/libraries/pushproviders/unifiedpush/src/main/kotlin/io/element/android/libraries/pushproviders/unifiedpush/UnifiedPushGatewayResolver.kt +++ b/libraries/pushproviders/unifiedpush/src/main/kotlin/io/element/android/libraries/pushproviders/unifiedpush/UnifiedPushGatewayResolver.kt @@ -12,12 +12,21 @@ import io.element.android.libraries.core.coroutine.CoroutineDispatchers import io.element.android.libraries.core.data.tryOrNull import io.element.android.libraries.di.AppScope import kotlinx.coroutines.withContext +import retrofit2.HttpException import timber.log.Timber +import java.net.HttpURLConnection import java.net.URL import javax.inject.Inject +sealed interface UnifiedPushGatewayResolverResult { + data class Success(val gateway: String) : UnifiedPushGatewayResolverResult + data class Error(val gateway: String) : UnifiedPushGatewayResolverResult + data object NoMatrixGateway : UnifiedPushGatewayResolverResult + data object ErrorInvalidUrl : UnifiedPushGatewayResolverResult +} + interface UnifiedPushGatewayResolver { - suspend fun getGateway(endpoint: String): String + suspend fun getGateway(endpoint: String): UnifiedPushGatewayResolverResult } @ContributesBinding(AppScope::class) @@ -27,7 +36,7 @@ class DefaultUnifiedPushGatewayResolver @Inject constructor( ) : UnifiedPushGatewayResolver { private val logger = Timber.tag("DefaultUnifiedPushGatewayResolver") - override suspend fun getGateway(endpoint: String): String { + override suspend fun getGateway(endpoint: String): UnifiedPushGatewayResolverResult { val url = tryOrNull( onError = { logger.d(it, "Cannot parse endpoint as an URL") } ) { @@ -35,7 +44,7 @@ class DefaultUnifiedPushGatewayResolver @Inject constructor( } return if (url == null) { logger.d("Using default gateway") - UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL + UnifiedPushGatewayResolverResult.ErrorInvalidUrl } else { val port = if (url.port != -1) ":${url.port}" else "" val customBase = "${url.protocol}://${url.host}$port" @@ -47,14 +56,21 @@ class DefaultUnifiedPushGatewayResolver @Inject constructor( val discoveryResponse = api.discover() if (discoveryResponse.unifiedpush.gateway == "matrix") { logger.d("The endpoint seems to be a valid UnifiedPush gateway") + UnifiedPushGatewayResolverResult.Success(customUrl) } else { - logger.e("The endpoint does not seem to be a valid UnifiedPush gateway") + // The endpoint returned a 200 OK but didn't promote an actual matrix gateway, which means it doesn't have any + logger.w("The endpoint does not seem to be a valid UnifiedPush gateway, using fallback") + UnifiedPushGatewayResolverResult.NoMatrixGateway } } catch (throwable: Throwable) { - logger.e(throwable, "Error checking for UnifiedPush endpoint") + if ((throwable as? HttpException)?.code() == HttpURLConnection.HTTP_NOT_FOUND) { + logger.i("Checking for UnifiedPush endpoint yielded 404, using fallback") + UnifiedPushGatewayResolverResult.NoMatrixGateway + } else { + logger.e(throwable, "Error checking for UnifiedPush endpoint") + UnifiedPushGatewayResolverResult.Error(customUrl) + } } - // Always return the custom url. - customUrl } } } diff --git a/libraries/pushproviders/unifiedpush/src/main/kotlin/io/element/android/libraries/pushproviders/unifiedpush/VectorUnifiedPushMessagingReceiver.kt b/libraries/pushproviders/unifiedpush/src/main/kotlin/io/element/android/libraries/pushproviders/unifiedpush/VectorUnifiedPushMessagingReceiver.kt index 51e6729fda..eb40a7f1e0 100644 --- a/libraries/pushproviders/unifiedpush/src/main/kotlin/io/element/android/libraries/pushproviders/unifiedpush/VectorUnifiedPushMessagingReceiver.kt +++ b/libraries/pushproviders/unifiedpush/src/main/kotlin/io/element/android/libraries/pushproviders/unifiedpush/VectorUnifiedPushMessagingReceiver.kt @@ -64,6 +64,21 @@ class VectorUnifiedPushMessagingReceiver : MessagingReceiver() { Timber.tag(loggerTag.value).i("onNewEndpoint: $endpoint") coroutineScope.launch { val gateway = unifiedPushGatewayResolver.getGateway(endpoint) + .let { gatewayResult -> + when (gatewayResult) { + is UnifiedPushGatewayResolverResult.Error -> { + // Use previous gateway if any, or the provided one + unifiedPushStore.getPushGateway(instance) ?: gatewayResult.gateway + } + UnifiedPushGatewayResolverResult.ErrorInvalidUrl, + UnifiedPushGatewayResolverResult.NoMatrixGateway -> { + UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL + } + is UnifiedPushGatewayResolverResult.Success -> { + gatewayResult.gateway + } + } + } unifiedPushStore.storePushGateway(instance, gateway) val result = newGatewayHandler.handle(endpoint, gateway, instance) .onFailure { diff --git a/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/DefaultUnifiedPushGatewayResolverTest.kt b/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/DefaultUnifiedPushGatewayResolverTest.kt index b854f999f5..afe0bd2489 100644 --- a/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/DefaultUnifiedPushGatewayResolverTest.kt +++ b/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/DefaultUnifiedPushGatewayResolverTest.kt @@ -43,7 +43,7 @@ class DefaultUnifiedPushGatewayResolverTest { ) val result = sut.getGateway("https://custom.url") assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url") - assertThat(result).isEqualTo("https://custom.url/_matrix/push/v1/notify") + assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("https://custom.url/_matrix/push/v1/notify")) } @Test @@ -56,7 +56,7 @@ class DefaultUnifiedPushGatewayResolverTest { ) val result = sut.getGateway("https://custom.url:123") assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url:123") - assertThat(result).isEqualTo("https://custom.url:123/_matrix/push/v1/notify") + assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("https://custom.url:123/_matrix/push/v1/notify")) } @Test @@ -69,7 +69,7 @@ class DefaultUnifiedPushGatewayResolverTest { ) val result = sut.getGateway("https://custom.url:123/some/path") assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url:123") - assertThat(result).isEqualTo("https://custom.url:123/_matrix/push/v1/notify") + assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("https://custom.url:123/_matrix/push/v1/notify")) } @Test @@ -82,7 +82,7 @@ class DefaultUnifiedPushGatewayResolverTest { ) val result = sut.getGateway("http://custom.url:123/some/path") assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("http://custom.url:123") - assertThat(result).isEqualTo("http://custom.url:123/_matrix/push/v1/notify") + assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("http://custom.url:123/_matrix/push/v1/notify")) } @Test @@ -95,11 +95,11 @@ class DefaultUnifiedPushGatewayResolverTest { ) val result = sut.getGateway("http://custom.url") assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("http://custom.url") - assertThat(result).isEqualTo("http://custom.url/_matrix/push/v1/notify") + assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Error("http://custom.url/_matrix/push/v1/notify")) } @Test - fun `when a custom url is invalid, the default url is returned`() = runTest { + fun `when a custom url is invalid, ErrorInvalidUrl is returned`() = runTest { val unifiedPushApiFactory = FakeUnifiedPushApiFactory( discoveryResponse = matrixDiscoveryResponse ) @@ -108,11 +108,11 @@ class DefaultUnifiedPushGatewayResolverTest { ) val result = sut.getGateway("invalid") assertThat(unifiedPushApiFactory.baseUrlParameter).isNull() - assertThat(result).isEqualTo(UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL) + assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.ErrorInvalidUrl) } @Test - fun `when a custom url provides a invalid matrix gateway, the custom url is still returned`() = runTest { + fun `when a custom url provides a invalid matrix gateway, NoMatrixGateway is returned`() = runTest { val unifiedPushApiFactory = FakeUnifiedPushApiFactory( discoveryResponse = invalidDiscoveryResponse ) @@ -121,7 +121,7 @@ class DefaultUnifiedPushGatewayResolverTest { ) val result = sut.getGateway("https://custom.url") assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url") - assertThat(result).isEqualTo("https://custom.url/_matrix/push/v1/notify") + assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.NoMatrixGateway) } private fun TestScope.createDefaultUnifiedPushGatewayResolver( diff --git a/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/FakeUnifiedPushGatewayResolver.kt b/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/FakeUnifiedPushGatewayResolver.kt index 661c636d43..5d0cef8991 100644 --- a/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/FakeUnifiedPushGatewayResolver.kt +++ b/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/FakeUnifiedPushGatewayResolver.kt @@ -10,9 +10,9 @@ package io.element.android.libraries.pushproviders.unifiedpush import io.element.android.tests.testutils.lambda.lambdaError class FakeUnifiedPushGatewayResolver( - private val getGatewayResult: (String) -> String = { lambdaError() }, + private val getGatewayResult: (String) -> UnifiedPushGatewayResolverResult = { lambdaError() }, ) : UnifiedPushGatewayResolver { - override suspend fun getGateway(endpoint: String): String { + override suspend fun getGateway(endpoint: String): UnifiedPushGatewayResolverResult { return getGatewayResult(endpoint) } } diff --git a/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/VectorUnifiedPushMessagingReceiverTest.kt b/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/VectorUnifiedPushMessagingReceiverTest.kt index be5f66e0f0..2293f97c3e 100644 --- a/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/VectorUnifiedPushMessagingReceiverTest.kt +++ b/libraries/pushproviders/unifiedpush/src/test/kotlin/io/element/android/libraries/pushproviders/unifiedpush/VectorUnifiedPushMessagingReceiverTest.kt @@ -104,7 +104,7 @@ class VectorUnifiedPushMessagingReceiverTest { val vectorUnifiedPushMessagingReceiver = createVectorUnifiedPushMessagingReceiver( unifiedPushStore = unifiedPushStore, unifiedPushGatewayResolver = FakeUnifiedPushGatewayResolver( - getGatewayResult = { "aGateway" } + getGatewayResult = { UnifiedPushGatewayResolverResult.Success("aGateway") } ), endpointRegistrationHandler = endpointRegistrationHandler, unifiedPushNewGatewayHandler = unifiedPushNewGatewayHandler, @@ -144,7 +144,7 @@ class VectorUnifiedPushMessagingReceiverTest { val vectorUnifiedPushMessagingReceiver = createVectorUnifiedPushMessagingReceiver( unifiedPushStore = unifiedPushStore, unifiedPushGatewayResolver = FakeUnifiedPushGatewayResolver( - getGatewayResult = { "aGateway" } + getGatewayResult = { UnifiedPushGatewayResolverResult.Success("aGateway") } ), endpointRegistrationHandler = endpointRegistrationHandler, unifiedPushNewGatewayHandler = unifiedPushNewGatewayHandler,