Ensure that when no Matrix gateway exists, the default one is used.

This commit is contained in:
Benoit Marty
2024-12-30 13:18:37 +01:00
parent b9225b4a6f
commit 86a44b9035
5 changed files with 51 additions and 20 deletions

View File

@@ -12,12 +12,21 @@ import io.element.android.libraries.core.coroutine.CoroutineDispatchers
import io.element.android.libraries.core.data.tryOrNull
import io.element.android.libraries.di.AppScope
import kotlinx.coroutines.withContext
import retrofit2.HttpException
import timber.log.Timber
import java.net.HttpURLConnection
import java.net.URL
import javax.inject.Inject
sealed interface UnifiedPushGatewayResolverResult {
data class Success(val gateway: String) : UnifiedPushGatewayResolverResult
data class Error(val gateway: String) : UnifiedPushGatewayResolverResult
data object NoMatrixGateway : UnifiedPushGatewayResolverResult
data object ErrorInvalidUrl : UnifiedPushGatewayResolverResult
}
interface UnifiedPushGatewayResolver {
suspend fun getGateway(endpoint: String): String
suspend fun getGateway(endpoint: String): UnifiedPushGatewayResolverResult
}
@ContributesBinding(AppScope::class)
@@ -27,7 +36,7 @@ class DefaultUnifiedPushGatewayResolver @Inject constructor(
) : UnifiedPushGatewayResolver {
private val logger = Timber.tag("DefaultUnifiedPushGatewayResolver")
override suspend fun getGateway(endpoint: String): String {
override suspend fun getGateway(endpoint: String): UnifiedPushGatewayResolverResult {
val url = tryOrNull(
onError = { logger.d(it, "Cannot parse endpoint as an URL") }
) {
@@ -35,7 +44,7 @@ class DefaultUnifiedPushGatewayResolver @Inject constructor(
}
return if (url == null) {
logger.d("Using default gateway")
UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL
UnifiedPushGatewayResolverResult.ErrorInvalidUrl
} else {
val port = if (url.port != -1) ":${url.port}" else ""
val customBase = "${url.protocol}://${url.host}$port"
@@ -47,14 +56,21 @@ class DefaultUnifiedPushGatewayResolver @Inject constructor(
val discoveryResponse = api.discover()
if (discoveryResponse.unifiedpush.gateway == "matrix") {
logger.d("The endpoint seems to be a valid UnifiedPush gateway")
UnifiedPushGatewayResolverResult.Success(customUrl)
} else {
logger.e("The endpoint does not seem to be a valid UnifiedPush gateway")
// The endpoint returned a 200 OK but didn't promote an actual matrix gateway, which means it doesn't have any
logger.w("The endpoint does not seem to be a valid UnifiedPush gateway, using fallback")
UnifiedPushGatewayResolverResult.NoMatrixGateway
}
} catch (throwable: Throwable) {
logger.e(throwable, "Error checking for UnifiedPush endpoint")
if ((throwable as? HttpException)?.code() == HttpURLConnection.HTTP_NOT_FOUND) {
logger.i("Checking for UnifiedPush endpoint yielded 404, using fallback")
UnifiedPushGatewayResolverResult.NoMatrixGateway
} else {
logger.e(throwable, "Error checking for UnifiedPush endpoint")
UnifiedPushGatewayResolverResult.Error(customUrl)
}
}
// Always return the custom url.
customUrl
}
}
}

View File

@@ -64,6 +64,21 @@ class VectorUnifiedPushMessagingReceiver : MessagingReceiver() {
Timber.tag(loggerTag.value).i("onNewEndpoint: $endpoint")
coroutineScope.launch {
val gateway = unifiedPushGatewayResolver.getGateway(endpoint)
.let { gatewayResult ->
when (gatewayResult) {
is UnifiedPushGatewayResolverResult.Error -> {
// Use previous gateway if any, or the provided one
unifiedPushStore.getPushGateway(instance) ?: gatewayResult.gateway
}
UnifiedPushGatewayResolverResult.ErrorInvalidUrl,
UnifiedPushGatewayResolverResult.NoMatrixGateway -> {
UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL
}
is UnifiedPushGatewayResolverResult.Success -> {
gatewayResult.gateway
}
}
}
unifiedPushStore.storePushGateway(instance, gateway)
val result = newGatewayHandler.handle(endpoint, gateway, instance)
.onFailure {

View File

@@ -43,7 +43,7 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("https://custom.url")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url")
assertThat(result).isEqualTo("https://custom.url/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("https://custom.url/_matrix/push/v1/notify"))
}
@Test
@@ -56,7 +56,7 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("https://custom.url:123")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url:123")
assertThat(result).isEqualTo("https://custom.url:123/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("https://custom.url:123/_matrix/push/v1/notify"))
}
@Test
@@ -69,7 +69,7 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("https://custom.url:123/some/path")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url:123")
assertThat(result).isEqualTo("https://custom.url:123/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("https://custom.url:123/_matrix/push/v1/notify"))
}
@Test
@@ -82,7 +82,7 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("http://custom.url:123/some/path")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("http://custom.url:123")
assertThat(result).isEqualTo("http://custom.url:123/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Success("http://custom.url:123/_matrix/push/v1/notify"))
}
@Test
@@ -95,11 +95,11 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("http://custom.url")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("http://custom.url")
assertThat(result).isEqualTo("http://custom.url/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.Error("http://custom.url/_matrix/push/v1/notify"))
}
@Test
fun `when a custom url is invalid, the default url is returned`() = runTest {
fun `when a custom url is invalid, ErrorInvalidUrl is returned`() = runTest {
val unifiedPushApiFactory = FakeUnifiedPushApiFactory(
discoveryResponse = matrixDiscoveryResponse
)
@@ -108,11 +108,11 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("invalid")
assertThat(unifiedPushApiFactory.baseUrlParameter).isNull()
assertThat(result).isEqualTo(UnifiedPushConfig.DEFAULT_PUSH_GATEWAY_HTTP_URL)
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.ErrorInvalidUrl)
}
@Test
fun `when a custom url provides a invalid matrix gateway, the custom url is still returned`() = runTest {
fun `when a custom url provides a invalid matrix gateway, NoMatrixGateway is returned`() = runTest {
val unifiedPushApiFactory = FakeUnifiedPushApiFactory(
discoveryResponse = invalidDiscoveryResponse
)
@@ -121,7 +121,7 @@ class DefaultUnifiedPushGatewayResolverTest {
)
val result = sut.getGateway("https://custom.url")
assertThat(unifiedPushApiFactory.baseUrlParameter).isEqualTo("https://custom.url")
assertThat(result).isEqualTo("https://custom.url/_matrix/push/v1/notify")
assertThat(result).isEqualTo(UnifiedPushGatewayResolverResult.NoMatrixGateway)
}
private fun TestScope.createDefaultUnifiedPushGatewayResolver(

View File

@@ -10,9 +10,9 @@ package io.element.android.libraries.pushproviders.unifiedpush
import io.element.android.tests.testutils.lambda.lambdaError
class FakeUnifiedPushGatewayResolver(
private val getGatewayResult: (String) -> String = { lambdaError() },
private val getGatewayResult: (String) -> UnifiedPushGatewayResolverResult = { lambdaError() },
) : UnifiedPushGatewayResolver {
override suspend fun getGateway(endpoint: String): String {
override suspend fun getGateway(endpoint: String): UnifiedPushGatewayResolverResult {
return getGatewayResult(endpoint)
}
}

View File

@@ -104,7 +104,7 @@ class VectorUnifiedPushMessagingReceiverTest {
val vectorUnifiedPushMessagingReceiver = createVectorUnifiedPushMessagingReceiver(
unifiedPushStore = unifiedPushStore,
unifiedPushGatewayResolver = FakeUnifiedPushGatewayResolver(
getGatewayResult = { "aGateway" }
getGatewayResult = { UnifiedPushGatewayResolverResult.Success("aGateway") }
),
endpointRegistrationHandler = endpointRegistrationHandler,
unifiedPushNewGatewayHandler = unifiedPushNewGatewayHandler,
@@ -144,7 +144,7 @@ class VectorUnifiedPushMessagingReceiverTest {
val vectorUnifiedPushMessagingReceiver = createVectorUnifiedPushMessagingReceiver(
unifiedPushStore = unifiedPushStore,
unifiedPushGatewayResolver = FakeUnifiedPushGatewayResolver(
getGatewayResult = { "aGateway" }
getGatewayResult = { UnifiedPushGatewayResolverResult.Success("aGateway") }
),
endpointRegistrationHandler = endpointRegistrationHandler,
unifiedPushNewGatewayHandler = unifiedPushNewGatewayHandler,