Simplify push notification flow by using locally stored values for pending pushes (#6258)

* Create `PushRequest` in push history DB: this will be used to store requests for push notifications, either pending or completed ones.

* Rename `WorkManagerRequest` to `WorkManagerRequestBuilder`: make its `build` method return a list of `WorkManagerRequestWrapper`, which can be used to enqueue normal or unique workers.

* Rename `PerformDatabaseVacuumRequestBuilder` and adapt it to the new API.

* Adjust other components using `WorkManagerRequest`.

* Replace `SyncNotificationWorkManagerRequestBuilder` with `SyncPendingNotificationsRequestBuilder` and `FetchNotificationsWorker` with `FetchPendingNotificationsWorker`: this new pair of request builder and worker allow enqueuing requests for a session id and, once the worker runs, retrieve all the pending request data and use it to fetch the associated events. This simplifies quite a bit how this data had to be passed or grouped, since it's no longer necessary to do so

* Add new methods to `PushHistoryService` to modify the `PushDatabase`:

- insertOrUpdatePushRequest
- insertOrUpdatePushRequests
- getPendingPushRequests
- removeOldPushRequests

* Make `PushHandler` just handle incoming pushes: those will be inserted into the pending push request table in DB, then handled by the new worker. Once the process finished, a new `NotificationResultProcessor` will handle the results and what needs to be done with them (call ringing, displaying notifications, etc.)

* Add `requestType` optional parameter to `WorkManagerScheduler.cancel` so we can decide to only cancel some kinds of requests.

* Add migration to remove existing work manager requests for fetching notifications, since the previous worker class no longer exists.
This commit is contained in:
Jorge Martin Espinosa
2026-03-03 16:14:36 +01:00
committed by GitHub
parent a8c66381f2
commit 721add707c
48 changed files with 1715 additions and 1892 deletions

View File

@@ -23,6 +23,7 @@ import io.element.android.libraries.matrix.api.encryption.RecoveryState
import io.element.android.libraries.matrix.test.AN_EXCEPTION
import io.element.android.libraries.matrix.test.FakeMatrixClient
import io.element.android.libraries.matrix.test.encryption.FakeEncryptionService
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.test.FakeWorkManagerScheduler
import io.element.android.tests.testutils.WarmUpRule
import io.element.android.tests.testutils.lambda.lambdaRecorder
@@ -149,7 +150,7 @@ class LogoutPresenterTest {
@Test
fun `present - logout then confirm`() = runTest {
val cancelWorkManagerJobsLambda = lambdaRecorder<SessionId, Unit> {}
val cancelWorkManagerJobsLambda = lambdaRecorder<SessionId, WorkManagerRequestType?, Unit> { _, _ -> }
val workManagerScheduler = FakeWorkManagerScheduler(cancelLambda = cancelWorkManagerJobsLambda)
val presenter = createLogoutPresenter(workManagerScheduler = workManagerScheduler)
moleculeFlow(RecompositionMode.Immediate) {
@@ -238,7 +239,7 @@ class LogoutPresenterTest {
internal fun createLogoutPresenter(
matrixClient: MatrixClient = FakeMatrixClient(),
encryptionService: EncryptionService = FakeEncryptionService(),
workManagerScheduler: FakeWorkManagerScheduler = FakeWorkManagerScheduler(cancelLambda = {}),
workManagerScheduler: FakeWorkManagerScheduler = FakeWorkManagerScheduler(cancelLambda = { _, _ -> }),
): LogoutPresenter = LogoutPresenter(
matrixClient = matrixClient,
encryptionService = encryptionService,

View File

@@ -31,6 +31,7 @@ dependencies {
implementation(projects.libraries.matrix.api)
implementation(projects.libraries.sessionStorage.api)
implementation(projects.libraries.uiStrings)
implementation(projects.libraries.workmanager.api)
testCommonDependencies(libs)
testImplementation(projects.libraries.matrix.test)

View File

@@ -0,0 +1,39 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.features.migration.impl.migrations
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.ContributesIntoSet
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.sessionstorage.api.SessionStore
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.WorkManagerScheduler
/**
* Remove existing fetch notifications work manager requests since their format has changed.
*/
@ContributesIntoSet(AppScope::class)
class AppMigration10(
private val sessionStore: SessionStore,
private val workManagerScheduler: WorkManagerScheduler,
) : AppMigration {
override val order: Int = 10
override suspend fun migrate(isFreshInstall: Boolean) {
if (isFreshInstall) return
val sessions = sessionStore.getAllSessions()
for (session in sessions) {
workManagerScheduler.cancel(
sessionId = SessionId(session.userId),
requestType = WorkManagerRequestType.NOTIFICATION_SYNC
)
}
}
}

View File

@@ -83,7 +83,7 @@ import io.element.android.libraries.matrix.impl.util.SessionPathsProvider
import io.element.android.libraries.matrix.impl.util.cancelAndDestroy
import io.element.android.libraries.matrix.impl.util.mxCallbackFlow
import io.element.android.libraries.matrix.impl.verification.RustSessionVerificationService
import io.element.android.libraries.matrix.impl.workmanager.PerformDatabaseVacuumWorkManagerRequest
import io.element.android.libraries.matrix.impl.workmanager.PerformDatabaseVacuumRequestBuilder
import io.element.android.libraries.sessionstorage.api.SessionStore
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.WorkManagerScheduler
@@ -832,8 +832,8 @@ class RustMatrixClient(
if (workManagerScheduler.hasPendingWork(sessionId, WorkManagerRequestType.DB_VACUUM)) return
Timber.i("Scheduling periodic database vacuuming for session $sessionId")
val request = PerformDatabaseVacuumWorkManagerRequest(sessionId)
workManagerScheduler.submit(request)
val request = PerformDatabaseVacuumRequestBuilder(sessionId)
sessionCoroutineScope.launch { workManagerScheduler.submit(request) }
}
}

View File

@@ -10,18 +10,18 @@ package io.element.android.libraries.matrix.impl.workmanager
import androidx.work.Constraints
import androidx.work.Data
import androidx.work.PeriodicWorkRequest
import androidx.work.WorkRequest
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.impl.workmanager.VacuumDatabaseWorker.Companion.SESSION_ID_PARAM
import io.element.android.libraries.workmanager.api.WorkManagerRequest
import io.element.android.libraries.workmanager.api.WorkManagerRequestBuilder
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.WorkManagerRequestWrapper
import io.element.android.libraries.workmanager.api.workManagerTag
import java.util.concurrent.TimeUnit
class PerformDatabaseVacuumWorkManagerRequest(
class PerformDatabaseVacuumRequestBuilder(
private val sessionId: SessionId,
) : WorkManagerRequest {
override fun build(): Result<List<WorkRequest>> {
) : WorkManagerRequestBuilder {
override suspend fun build(): Result<List<WorkManagerRequestWrapper>> {
val data = Data.Builder().putString(SESSION_ID_PARAM, sessionId.value).build()
val workRequest = PeriodicWorkRequest.Builder(
workerClass = VacuumDatabaseWorker::class,
@@ -41,6 +41,6 @@ class PerformDatabaseVacuumWorkManagerRequest(
)
.build()
return Result.success(listOf(workRequest))
return Result.success(listOf(WorkManagerRequestWrapper(workRequest)))
}
}

View File

@@ -19,7 +19,7 @@ import io.element.android.libraries.network.useragent.SimpleUserAgentProvider
import io.element.android.libraries.sessionstorage.api.SessionStore
import io.element.android.libraries.sessionstorage.test.InMemorySessionStore
import io.element.android.libraries.sessionstorage.test.aSessionData
import io.element.android.libraries.workmanager.api.WorkManagerRequest
import io.element.android.libraries.workmanager.api.WorkManagerRequestBuilder
import io.element.android.libraries.workmanager.test.FakeWorkManagerScheduler
import io.element.android.services.analytics.test.FakeAnalyticsService
import io.element.android.services.toolbox.test.systemclock.FakeSystemClock
@@ -33,7 +33,7 @@ import java.io.File
class RustMatrixClientFactoryTest {
@Test
fun test() = runTest {
val scheduleVacuumLambda = lambdaRecorder<WorkManagerRequest, Unit> {}
val scheduleVacuumLambda = lambdaRecorder<WorkManagerRequestBuilder, Unit> {}
val workManagerScheduler = FakeWorkManagerScheduler(submitLambda = scheduleVacuumLambda)
val sut = createRustMatrixClientFactory(workManagerScheduler = workManagerScheduler)

View File

@@ -1,20 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.api.push
import io.element.android.libraries.matrix.api.core.EventId
import io.element.android.libraries.matrix.api.core.RoomId
import io.element.android.libraries.matrix.api.core.SessionId
data class NotificationEventRequest(
val sessionId: SessionId,
val roomId: RoomId,
val eventId: EventId,
val providerInfo: String,
)

View File

@@ -1,13 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.api.push
fun interface SyncOnNotifiableEvent {
suspend operator fun invoke(requests: List<NotificationEventRequest>)
}

View File

@@ -97,6 +97,7 @@ sqldelight {
databases {
create("PushDatabase") {
schemaOutputDirectory = File("src/main/sqldelight/databases")
verifyMigrations = true
}
}
}

View File

@@ -14,13 +14,16 @@ import android.os.PowerManager
import androidx.core.content.getSystemService
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.ContributesBinding
import io.element.android.libraries.core.extensions.runCatchingExceptions
import io.element.android.libraries.di.annotations.ApplicationContext
import io.element.android.libraries.matrix.api.core.EventId
import io.element.android.libraries.matrix.api.core.RoomId
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.push.impl.PushDatabase
import io.element.android.libraries.push.impl.db.PushHistory
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.services.toolbox.api.systemclock.SystemClock
import kotlin.time.Instant
@ContributesBinding(AppScope::class)
class DefaultPushHistoryService(
@@ -31,7 +34,37 @@ class DefaultPushHistoryService(
private val powerManager = context.getSystemService<PowerManager>()
private val packageName = context.packageName
override fun onPushReceived(
override suspend fun insertOrUpdatePushRequest(pushRequest: PushRequest): Result<Unit> {
return runCatchingExceptions { pushDatabase.pushRequestQueries.insertPushRequest(pushRequest).await() }
}
override suspend fun insertOrUpdatePushRequests(pushRequests: List<PushRequest>): Result<Unit> {
return runCatchingExceptions {
pushDatabase.transaction {
for (request in pushRequests) {
pushDatabase.pushRequestQueries.insertPushRequest(request)
}
}
}
}
override suspend fun getPendingPushRequests(sessionId: SessionId, since: Instant?): Result<List<PushRequest>> {
return runCatchingExceptions {
pushDatabase.transactionWithResult {
val sinceTimeMillis = since?.toEpochMilliseconds() ?: 0
pushDatabase.pushRequestQueries.selectAllPendingForSession(sessionId.value, sinceTimeMillis).executeAsList()
}
}
}
override suspend fun removeOldPushRequests(sessionId: SessionId): Result<Unit> {
return runCatchingExceptions {
val keepAmount = 100L
pushDatabase.pushRequestQueries.removeOldest(keepAmount)
}
}
override fun onPushResult(
providerInfo: String,
eventId: EventId?,
roomId: RoomId?,

View File

@@ -11,13 +11,16 @@ package io.element.android.libraries.push.impl.history
import io.element.android.libraries.matrix.api.core.EventId
import io.element.android.libraries.matrix.api.core.RoomId
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.push.PushRequestStatus
import kotlin.time.Instant
interface PushHistoryService {
/**
* Create a new push history entry.
* Do not use directly, prefer using the extension functions.
*/
fun onPushReceived(
fun onPushResult(
providerInfo: String,
eventId: EventId?,
roomId: RoomId?,
@@ -26,12 +29,33 @@ interface PushHistoryService {
includeDeviceState: Boolean,
comment: String?,
)
/**
* Adds or replaces an existing [PushRequest] in the local database.
*/
suspend fun insertOrUpdatePushRequest(pushRequest: PushRequest): Result<Unit>
/**
* Replace a list of [PushRequest] in the database.
*/
suspend fun insertOrUpdatePushRequests(pushRequests: List<PushRequest>): Result<Unit>
/**
* Gets [PushRequestStatus.PENDING] push requests from the local database for a [SessionId].
* A [since] param can optionally be provided to only return those received after that date.
*/
suspend fun getPendingPushRequests(sessionId: SessionId, since: Instant?): Result<List<PushRequest>>
/**
* Removes the oldest push requests for a [SessionId].
*/
suspend fun removeOldPushRequests(sessionId: SessionId): Result<Unit>
}
fun PushHistoryService.onInvalidPushReceived(
providerInfo: String,
data: String,
) = onPushReceived(
) = onPushResult(
providerInfo = providerInfo,
eventId = null,
roomId = null,
@@ -46,7 +70,7 @@ fun PushHistoryService.onUnableToRetrieveSession(
eventId: EventId,
roomId: RoomId,
reason: String,
) = onPushReceived(
) = onPushResult(
providerInfo = providerInfo,
eventId = eventId,
roomId = roomId,
@@ -62,7 +86,7 @@ fun PushHistoryService.onUnableToResolveEvent(
roomId: RoomId,
sessionId: SessionId,
reason: String,
) = onPushReceived(
) = onPushResult(
providerInfo = providerInfo,
eventId = eventId,
roomId = roomId,
@@ -78,7 +102,7 @@ fun PushHistoryService.onSuccess(
roomId: RoomId,
sessionId: SessionId,
comment: String?,
) = onPushReceived(
) = onPushResult(
providerInfo = providerInfo,
eventId = eventId,
roomId = roomId,
@@ -95,7 +119,7 @@ fun PushHistoryService.onSuccess(
fun PushHistoryService.onDiagnosticPush(
providerInfo: String,
) = onPushReceived(
) = onPushResult(
providerInfo = providerInfo,
eventId = null,
roomId = null,

View File

@@ -50,8 +50,8 @@ import io.element.android.libraries.matrix.api.timeline.item.event.TextMessageTy
import io.element.android.libraries.matrix.api.timeline.item.event.VideoMessageType
import io.element.android.libraries.matrix.api.timeline.item.event.VoiceMessageType
import io.element.android.libraries.matrix.ui.messages.toPlainText
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.impl.R
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.notifications.model.InviteNotifiableEvent
import io.element.android.libraries.push.impl.notifications.model.NotifiableMessageEvent
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
@@ -64,10 +64,10 @@ private val loggerTag = LoggerTag("DefaultNotifiableEventResolver", LoggerTag.No
/**
* Result of resolving a batch of push events.
* The outermost [Result] indicates whether the setup to resolve the events was successful.
* The results for each push notification will be a map of [NotificationEventRequest] to [Result] of [ResolvedPushEvent].
* The results for each push notification will be a map of [PushRequest] to [Result] of [ResolvedPushEvent].
* If the resolution of a specific event fails, the innermost [Result] will contain an exception.
*/
typealias ResolvePushEventsResult = Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>>
typealias ResolvePushEventsResult = Result<Map<PushRequest, Result<ResolvedPushEvent>>>
/**
* The notifiable event resolver is able to create a NotifiableEvent (view model for notifications) from an sdk Event.
@@ -78,7 +78,7 @@ typealias ResolvePushEventsResult = Result<Map<NotificationEventRequest, Result<
interface NotifiableEventResolver {
suspend fun resolveEvents(
sessionId: SessionId,
notificationEventRequests: List<NotificationEventRequest>
notificationEventRequests: List<PushRequest>
): ResolvePushEventsResult
}
@@ -96,15 +96,15 @@ class DefaultNotifiableEventResolver(
) : NotifiableEventResolver {
override suspend fun resolveEvents(
sessionId: SessionId,
notificationEventRequests: List<NotificationEventRequest>
notificationEventRequests: List<PushRequest>
): ResolvePushEventsResult {
Timber.d("Queueing notifications: $notificationEventRequests")
val client = matrixClientProvider.getOrRestore(sessionId).getOrElse {
return Result.failure(it)
}
val ids = notificationEventRequests.groupBy { it.roomId }
val ids = notificationEventRequests.groupBy { RoomId(it.roomId) }
.mapValues { (_, requests) ->
requests.map { it.eventId }
requests.map { EventId(it.eventId) }
}
// TODO this notificationData is not always valid at the moment, sometimes the Rust SDK can't fetch the matching event
@@ -125,7 +125,7 @@ class DefaultNotifiableEventResolver(
return Result.success(
notificationEventRequests.associate { request ->
val notificationDataResult = notificationDataMap[request.eventId]
val notificationDataResult = notificationDataMap[EventId(request.eventId)]
if (notificationDataResult == null) {
request to Result.failure(NotificationResolverException.UnknownError("No notification data for ${request.roomId} - ${request.eventId}"))
} else {

View File

@@ -1,125 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.notifications
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.ContributesBinding
import dev.zacsweers.metro.SingleIn
import io.element.android.libraries.di.annotations.AppCoroutineScope
import io.element.android.libraries.featureflag.api.FeatureFlagService
import io.element.android.libraries.featureflag.api.FeatureFlags
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
import io.element.android.libraries.push.impl.workmanager.SyncNotificationWorkManagerRequest
import io.element.android.libraries.push.impl.workmanager.SyncNotificationsWorkerDataConverter
import io.element.android.libraries.workmanager.api.WorkManagerScheduler
import io.element.android.services.toolbox.api.sdk.BuildVersionSdkIntProvider
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.Job
import kotlinx.coroutines.SupervisorJob
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.delay
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.SharedFlow
import kotlinx.coroutines.launch
import timber.log.Timber
import kotlin.time.Duration.Companion.milliseconds
interface NotificationResolverQueue {
val results: SharedFlow<Pair<List<NotificationEventRequest>, Map<NotificationEventRequest, Result<ResolvedPushEvent>>>>
suspend fun enqueue(request: NotificationEventRequest)
}
/**
* This class is responsible for periodically batching notification requests and resolving them in a single call,
* so that we can avoid having to resolve each notification individually in the SDK.
*/
@OptIn(ExperimentalCoroutinesApi::class)
@SingleIn(AppScope::class)
@ContributesBinding(AppScope::class)
class DefaultNotificationResolverQueue(
private val notifiableEventResolver: NotifiableEventResolver,
@AppCoroutineScope
private val appCoroutineScope: CoroutineScope,
private val workManagerScheduler: WorkManagerScheduler,
private val featureFlagService: FeatureFlagService,
private val workerDataConverter: SyncNotificationsWorkerDataConverter,
private val buildVersionSdkIntProvider: BuildVersionSdkIntProvider,
) : NotificationResolverQueue {
companion object {
private const val BATCH_WINDOW_MS = 250L
}
private val requestQueue = Channel<NotificationEventRequest>(capacity = 100)
private var currentProcessingJob: Job? = null
/**
* A flow that emits pairs of a list of notification event requests and a map of the resolved events.
* The map contains the original request as the key and the resolved event as the value.
*/
override val results = MutableSharedFlow<Pair<List<NotificationEventRequest>, Map<NotificationEventRequest, Result<ResolvedPushEvent>>>>()
/**
* Enqueues a notification event request to be resolved.
* The request will be processed in batches, so it may not be resolved immediately.
*
* @param request The notification event request to enqueue.
*/
override suspend fun enqueue(request: NotificationEventRequest) {
// Cancel previous processing job if it exists, acting as a debounce operation
Timber.d("Cancelling job: $currentProcessingJob")
currentProcessingJob?.cancel()
// Enqueue the request and start a delayed processing job
requestQueue.send(request)
currentProcessingJob = processQueue()
Timber.d("Starting processing job for request: $request")
}
private fun processQueue() = appCoroutineScope.launch(SupervisorJob()) {
delay(BATCH_WINDOW_MS.milliseconds)
// If this job is still active (so this is the latest job), we launch a separate one that won't be cancelled when enqueueing new items
// to process the existing queued items.
appCoroutineScope.launch {
val groupedRequestsById = buildList {
while (!requestQueue.isEmpty) {
requestQueue.receiveCatching().getOrNull()?.let(::add)
}
}.groupBy { it.sessionId }
if (featureFlagService.isFeatureEnabled(FeatureFlags.SyncNotificationsWithWorkManager)) {
for ((sessionId, requests) in groupedRequestsById) {
workManagerScheduler.submit(
SyncNotificationWorkManagerRequest(
sessionId = sessionId,
notificationEventRequests = requests,
workerDataConverter = workerDataConverter,
buildVersionSdkIntProvider = buildVersionSdkIntProvider,
)
)
}
} else {
val sessionIds = groupedRequestsById.keys
for (sessionId in sessionIds) {
val requests = groupedRequestsById[sessionId].orEmpty()
Timber.d("Fetching notifications for $sessionId: $requests. Pending requests: ${!requestQueue.isEmpty}")
// Resolving the events in parallel should improve performance since each session id will query a different Client
launch {
// No need for a Mutex since the SDK already has one internally
val notifications = notifiableEventResolver.resolveEvents(sessionId, requests).getOrNull().orEmpty()
results.emit(requests to notifications)
}
}
}
}
}
}

View File

@@ -0,0 +1,238 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.notifications
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.ContributesBinding
import dev.zacsweers.metro.SingleIn
import io.element.android.features.call.api.CallType
import io.element.android.features.call.api.ElementCallEntryPoint
import io.element.android.libraries.di.annotations.AppCoroutineScope
import io.element.android.libraries.featureflag.api.FeatureFlagService
import io.element.android.libraries.featureflag.api.FeatureFlags
import io.element.android.libraries.matrix.api.core.EventId
import io.element.android.libraries.matrix.api.core.RoomId
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.api.exception.NotificationResolverException
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.history.PushHistoryService
import io.element.android.libraries.push.impl.history.onSuccess
import io.element.android.libraries.push.impl.history.onUnableToResolveEvent
import io.element.android.libraries.push.impl.notifications.channels.NotificationChannels
import io.element.android.libraries.push.impl.notifications.model.FallbackNotifiableEvent
import io.element.android.libraries.push.impl.notifications.model.NotifiableEvent
import io.element.android.libraries.push.impl.notifications.model.NotifiableRingingCallEvent
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
import io.element.android.libraries.push.impl.push.MutableBatteryOptimizationStore
import io.element.android.libraries.push.impl.push.OnNotifiableEventReceived
import io.element.android.libraries.push.impl.push.OnRedactedEventReceived
import io.element.android.libraries.push.impl.push.SyncOnNotifiableEvent
import io.element.android.libraries.pushstore.api.UserPushStoreFactory
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.Job
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.launch
import timber.log.Timber
private const val TAG = "NotifResultProcessor"
interface NotificationResultProcessor {
suspend fun emit(results: Map<PushRequest, Result<ResolvedPushEvent>>)
fun start()
fun stop()
}
@ContributesBinding(AppScope::class)
@SingleIn(AppScope::class)
class DefaultNotificationResultProcessor(
private val pushHistoryService: PushHistoryService,
private val batteryOptimizationStore: MutableBatteryOptimizationStore,
private val fallbackNotificationFactory: FallbackNotificationFactory,
private val userPushStoreFactory: UserPushStoreFactory,
private val onRedactedEventReceived: OnRedactedEventReceived,
private val onNotifiableEventReceived: OnNotifiableEventReceived,
private val featureFlagService: FeatureFlagService,
private val syncOnNotifiableEvent: SyncOnNotifiableEvent,
private val elementCallEntryPoint: ElementCallEntryPoint,
private val notificationChannels: NotificationChannels,
@AppCoroutineScope private val coroutineScope: CoroutineScope,
) : NotificationResultProcessor {
private val resultFlow = MutableSharedFlow<Map<PushRequest, Result<ResolvedPushEvent>>>(extraBufferCapacity = Int.MAX_VALUE)
private var processJob: Job? = null
override suspend fun emit(results: Map<PushRequest, Result<ResolvedPushEvent>>) {
resultFlow.emit(results)
}
override fun start() {
if (processJob?.isActive == true) {
Timber.tag(TAG).w("Is already processing, not starting again")
return
}
processJob = resultFlow
.onEach(::processResults)
.launchIn(coroutineScope)
}
override fun stop() {
if (processJob?.isActive != true) {
Timber.tag(TAG).w("Is not processing, not stopping")
return
}
processJob?.cancel()
processJob = null
}
private suspend fun processResults(results: Map<PushRequest, Result<ResolvedPushEvent>>) {
// TODO what happens with items that weren't reported back?
for ((request, result) in results) {
result.fold(
onSuccess = {
if (it is ResolvedPushEvent.Event && it.notifiableEvent is FallbackNotifiableEvent) {
pushHistoryService.onUnableToResolveEvent(
providerInfo = request.providerInfo,
eventId = EventId(request.eventId),
roomId = RoomId(request.roomId),
sessionId = SessionId(request.sessionId),
reason = it.notifiableEvent.cause.orEmpty(),
)
} else {
pushHistoryService.onSuccess(
providerInfo = request.providerInfo,
eventId = EventId(request.eventId),
roomId = RoomId(request.roomId),
sessionId = SessionId(request.sessionId),
comment = "Push handled successfully",
)
}
},
onFailure = { exception ->
if (exception is NotificationResolverException.EventFilteredOut) {
pushHistoryService.onSuccess(
providerInfo = request.providerInfo,
eventId = EventId(request.eventId),
roomId = RoomId(request.roomId),
sessionId = SessionId(request.sessionId),
comment = "Push handled successfully but notification was filtered out",
)
} else if (exception is NotificationResolverException.EventRedacted) {
pushHistoryService.onSuccess(
providerInfo = request.providerInfo,
eventId = EventId(request.eventId),
roomId = RoomId(request.roomId),
sessionId = SessionId(request.sessionId),
comment = "Push handled successfully but event has been redacted",
)
} else {
val reason = when (exception) {
is NotificationResolverException.EventNotFound -> "Event not found"
else -> "Unknown error: ${exception.message}"
}
pushHistoryService.onUnableToResolveEvent(
providerInfo = request.providerInfo,
eventId = EventId(request.eventId),
roomId = RoomId(request.roomId),
sessionId = SessionId(request.sessionId),
reason = "$reason - Showing fallback notification",
)
batteryOptimizationStore.showBatteryOptimizationBanner()
}
}
)
}
val events = mutableListOf<NotifiableEvent>()
val redactions = mutableListOf<ResolvedPushEvent.Redaction>()
@Suppress("LoopWithTooManyJumpStatements")
for ((request, result) in results) {
val event = result.recover { exception ->
// If the event could not be resolved, we create a fallback notification
when (exception) {
is NotificationResolverException.EventFilteredOut -> {
// Do nothing, we don't want to show a notification for filtered out events
null
}
is NotificationResolverException.EventRedacted -> {
// Do nothing, we don't want to show a notification for redacted events
null
}
else -> {
Timber.tag(TAG).e(exception, "Failed to resolve push event")
ResolvedPushEvent.Event(
fallbackNotificationFactory.create(
sessionId = SessionId(request.sessionId),
roomId = RoomId(request.roomId),
eventId = EventId(request.eventId),
cause = exception.message,
)
)
}
}
}.getOrNull() ?: continue
val userPushStore = userPushStoreFactory.getOrCreate(event.sessionId)
val areNotificationsEnabled = userPushStore.getNotificationEnabledForDevice().first()
// If notifications are disabled for this session and device, we don't want to show the notification
// But if it's a ringing call, we want to show it anyway
val isRingingCall = (event as? ResolvedPushEvent.Event)?.notifiableEvent is NotifiableRingingCallEvent
if (!areNotificationsEnabled && !isRingingCall) continue
// We categorise each result into either a NotifiableEvent or a Redaction
when (event) {
is ResolvedPushEvent.Event -> {
events.add(event.notifiableEvent)
}
is ResolvedPushEvent.Redaction -> {
redactions.add(event)
}
}
}
// Process redactions of messages in background to not block operations with higher priority
if (redactions.isNotEmpty()) {
coroutineScope.launch { onRedactedEventReceived.onRedactedEventsReceived(redactions) }
}
// Find and process ringing call notifications separately
val (ringingCallEvents, nonRingingCallEvents) = events.partition { it is NotifiableRingingCallEvent }
for (ringingCallEvent in ringingCallEvents) {
Timber.tag(TAG).d("Ringing call event: $ringingCallEvent")
handleRingingCallEvent(ringingCallEvent as NotifiableRingingCallEvent)
}
// Finally, process other notifications (messages, invites, generic notifications, etc.)
if (nonRingingCallEvents.isNotEmpty()) {
onNotifiableEventReceived.onNotifiableEventsReceived(nonRingingCallEvents)
}
if (!featureFlagService.isFeatureEnabled(FeatureFlags.SyncNotificationsWithWorkManager)) {
syncOnNotifiableEvent(results.keys.toList())
}
}
private suspend fun handleRingingCallEvent(notifiableEvent: NotifiableRingingCallEvent) {
Timber.i("## handleInternal() : Incoming call.")
elementCallEntryPoint.handleIncomingCall(
callType = CallType.RoomCall(notifiableEvent.sessionId, notifiableEvent.roomId),
eventId = notifiableEvent.eventId,
senderId = notifiableEvent.senderId,
roomName = notifiableEvent.roomName,
senderName = notifiableEvent.senderDisambiguatedDisplayName,
avatarUrl = notifiableEvent.roomAvatarUrl,
timestamp = notifiableEvent.timestamp,
expirationTimestamp = notifiableEvent.expirationTimestamp,
notificationChannelId = notificationChannels.getChannelForIncomingCall(ring = true),
textContent = notifiableEvent.description,
)
}
}

View File

@@ -11,42 +11,28 @@ package io.element.android.libraries.push.impl.push
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.ContributesBinding
import dev.zacsweers.metro.SingleIn
import io.element.android.features.call.api.CallType
import io.element.android.features.call.api.ElementCallEntryPoint
import io.element.android.libraries.core.log.logger.LoggerTag
import io.element.android.libraries.core.meta.BuildMeta
import io.element.android.libraries.di.annotations.AppCoroutineScope
import io.element.android.libraries.featureflag.api.FeatureFlagService
import io.element.android.libraries.featureflag.api.FeatureFlags
import io.element.android.libraries.matrix.api.exception.NotificationResolverException
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.api.push.SyncOnNotifiableEvent
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.history.PushHistoryService
import io.element.android.libraries.push.impl.history.onDiagnosticPush
import io.element.android.libraries.push.impl.history.onInvalidPushReceived
import io.element.android.libraries.push.impl.history.onSuccess
import io.element.android.libraries.push.impl.history.onUnableToResolveEvent
import io.element.android.libraries.push.impl.history.onUnableToRetrieveSession
import io.element.android.libraries.push.impl.notifications.FallbackNotificationFactory
import io.element.android.libraries.push.impl.notifications.NotificationResolverQueue
import io.element.android.libraries.push.impl.notifications.channels.NotificationChannels
import io.element.android.libraries.push.impl.notifications.model.FallbackNotifiableEvent
import io.element.android.libraries.push.impl.notifications.model.NotifiableEvent
import io.element.android.libraries.push.impl.notifications.model.NotifiableRingingCallEvent
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
import io.element.android.libraries.push.impl.notifications.NotificationResultProcessor
import io.element.android.libraries.push.impl.test.DefaultTestPush
import io.element.android.libraries.push.impl.troubleshoot.DiagnosticPushHandler
import io.element.android.libraries.push.impl.workmanager.SyncPendingNotificationsRequestBuilder
import io.element.android.libraries.pushproviders.api.PushData
import io.element.android.libraries.pushproviders.api.PushHandler
import io.element.android.libraries.pushstore.api.UserPushStoreFactory
import io.element.android.libraries.pushstore.api.clientsecret.PushClientSecret
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.WorkManagerScheduler
import io.element.android.services.analytics.api.AnalyticsLongRunningTransaction
import io.element.android.services.analytics.api.AnalyticsService
import kotlinx.coroutines.CoroutineScope
import io.element.android.services.toolbox.api.sdk.BuildVersionSdkIntProvider
import io.element.android.services.toolbox.api.systemclock.SystemClock
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.map
import kotlinx.coroutines.launch
import timber.log.Timber
private val loggerTag = LoggerTag("PushHandler", LoggerTag.PushLoggerTag)
@@ -54,173 +40,20 @@ private val loggerTag = LoggerTag("PushHandler", LoggerTag.PushLoggerTag)
@SingleIn(AppScope::class)
@ContributesBinding(AppScope::class)
class DefaultPushHandler(
private val onNotifiableEventReceived: OnNotifiableEventReceived,
private val onRedactedEventReceived: OnRedactedEventReceived,
private val incrementPushDataStore: IncrementPushDataStore,
private val mutableBatteryOptimizationStore: MutableBatteryOptimizationStore,
private val userPushStoreFactory: UserPushStoreFactory,
private val pushClientSecret: PushClientSecret,
private val buildMeta: BuildMeta,
private val diagnosticPushHandler: DiagnosticPushHandler,
private val elementCallEntryPoint: ElementCallEntryPoint,
private val notificationChannels: NotificationChannels,
private val pushHistoryService: PushHistoryService,
private val resolverQueue: NotificationResolverQueue,
@AppCoroutineScope
private val appCoroutineScope: CoroutineScope,
private val fallbackNotificationFactory: FallbackNotificationFactory,
private val syncOnNotifiableEvent: SyncOnNotifiableEvent,
private val featureFlagService: FeatureFlagService,
private val userPushStoreFactory: UserPushStoreFactory,
private val analyticsService: AnalyticsService,
private val systemClock: SystemClock,
private val workManagerScheduler: WorkManagerScheduler,
private val buildVersionSdkIntProvider: BuildVersionSdkIntProvider,
resultProcessor: NotificationResultProcessor,
) : PushHandler {
init {
processPushEventResults()
}
/**
* Process the push notification event results emitted by the [resolverQueue].
*/
private fun processPushEventResults() {
resolverQueue.results
.map { (requests, resolvedEvents) ->
for (request in requests) {
// Log the result of the push notification event
val result = resolvedEvents[request]
if (result == null) {
pushHistoryService.onUnableToResolveEvent(
providerInfo = request.providerInfo,
eventId = request.eventId,
roomId = request.roomId,
sessionId = request.sessionId,
reason = "Push not handled: no result found for request",
)
} else {
result.fold(
onSuccess = {
if (it is ResolvedPushEvent.Event && it.notifiableEvent is FallbackNotifiableEvent) {
pushHistoryService.onUnableToResolveEvent(
providerInfo = request.providerInfo,
eventId = request.eventId,
roomId = request.roomId,
sessionId = request.sessionId,
reason = it.notifiableEvent.cause.orEmpty(),
)
} else {
pushHistoryService.onSuccess(
providerInfo = request.providerInfo,
eventId = request.eventId,
roomId = request.roomId,
sessionId = request.sessionId,
comment = "Push handled successfully",
)
}
},
onFailure = { exception ->
if (exception is NotificationResolverException.EventFilteredOut) {
pushHistoryService.onSuccess(
providerInfo = request.providerInfo,
eventId = request.eventId,
roomId = request.roomId,
sessionId = request.sessionId,
comment = "Push handled successfully but notification was filtered out",
)
} else if (exception is NotificationResolverException.EventRedacted) {
pushHistoryService.onSuccess(
providerInfo = request.providerInfo,
eventId = request.eventId,
roomId = request.roomId,
sessionId = request.sessionId,
comment = "Push handled successfully but event has been redacted",
)
} else {
val reason = when (exception) {
is NotificationResolverException.EventNotFound -> "Event not found"
else -> "Unknown error: ${exception.message}"
}
pushHistoryService.onUnableToResolveEvent(
providerInfo = request.providerInfo,
eventId = request.eventId,
roomId = request.roomId,
sessionId = request.sessionId,
reason = "$reason - Showing fallback notification",
)
mutableBatteryOptimizationStore.showBatteryOptimizationBanner()
}
}
)
}
}
val events = mutableListOf<NotifiableEvent>()
val redactions = mutableListOf<ResolvedPushEvent.Redaction>()
@Suppress("LoopWithTooManyJumpStatements")
for ((request, result) in resolvedEvents) {
val event = result.recover { exception ->
// If the event could not be resolved, we create a fallback notification
when (exception) {
is NotificationResolverException.EventFilteredOut -> {
// Do nothing, we don't want to show a notification for filtered out events
null
}
is NotificationResolverException.EventRedacted -> {
// Do nothing, we don't want to show a notification for redacted events
null
}
else -> {
Timber.tag(loggerTag.value).e(exception, "Failed to resolve push event")
ResolvedPushEvent.Event(
fallbackNotificationFactory.create(
sessionId = request.sessionId,
roomId = request.roomId,
eventId = request.eventId,
cause = exception.message,
)
)
}
}
}.getOrNull() ?: continue
val userPushStore = userPushStoreFactory.getOrCreate(event.sessionId)
val areNotificationsEnabled = userPushStore.getNotificationEnabledForDevice().first()
// If notifications are disabled for this session and device, we don't want to show the notification
// But if it's a ringing call, we want to show it anyway
val isRingingCall = (event as? ResolvedPushEvent.Event)?.notifiableEvent is NotifiableRingingCallEvent
if (!areNotificationsEnabled && !isRingingCall) continue
// We categorise each result into either a NotifiableEvent or a Redaction
when (event) {
is ResolvedPushEvent.Event -> {
events.add(event.notifiableEvent)
}
is ResolvedPushEvent.Redaction -> {
redactions.add(event)
}
}
}
// Process redactions of messages in background to not block operations with higher priority
if (redactions.isNotEmpty()) {
appCoroutineScope.launch { onRedactedEventReceived.onRedactedEventsReceived(redactions) }
}
// Find and process ringing call notifications separately
val (ringingCallEvents, nonRingingCallEvents) = events.partition { it is NotifiableRingingCallEvent }
for (ringingCallEvent in ringingCallEvents) {
Timber.tag(loggerTag.value).d("Ringing call event: $ringingCallEvent")
handleRingingCallEvent(ringingCallEvent as NotifiableRingingCallEvent)
}
// Finally, process other notifications (messages, invites, generic notifications, etc.)
if (nonRingingCallEvents.isNotEmpty()) {
onNotifiableEventReceived.onNotifiableEventsReceived(nonRingingCallEvents)
}
if (!featureFlagService.isFeatureEnabled(FeatureFlags.SyncNotificationsWithWorkManager)) {
syncOnNotifiableEvent(requests)
}
}
.launchIn(appCoroutineScope)
resultProcessor.start()
}
/**
@@ -233,9 +66,7 @@ class DefaultPushHandler(
// Start measuring how long it takes to display a notification from when the push is received
Timber.d("Calculating push-to-notification for event ${pushData.eventId}")
val parent = analyticsService.startLongRunningTransaction(AnalyticsLongRunningTransaction.PushToNotification(pushData.eventId.value))
if (featureFlagService.isFeatureEnabled(FeatureFlags.SyncNotificationsWithWorkManager)) {
analyticsService.startLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(pushData.eventId.value), parent)
}
analyticsService.startLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(pushData.eventId.value), parent)
Timber.tag(loggerTag.value).d("## handling pushData: ${pushData.roomId}/${pushData.eventId}")
if (buildMeta.lowPrivacyLoggingEnabled) {
@@ -282,34 +113,56 @@ class DefaultPushHandler(
return
}
appCoroutineScope.launch {
val notificationEventRequest = NotificationEventRequest(
sessionId = userId,
roomId = pushData.roomId,
eventId = pushData.eventId,
providerInfo = providerInfo,
val areNotificationsEnabled = userPushStoreFactory.getOrCreate(userId).getNotificationEnabledForDevice().first()
if (!areNotificationsEnabled) {
Timber.w("Push notification received when push notifications are disabled.")
return
}
val pushRequest = PushRequest(
pushDate = systemClock.epochMillis(),
providerInfo = providerInfo,
eventId = pushData.eventId.value,
roomId = pushData.roomId.value,
sessionId = userId.value,
status = PushRequestStatus.PENDING.value,
retries = 0L,
)
Timber.d("Queueing notification: $pushRequest")
pushHistoryService.insertOrUpdatePushRequest(pushRequest)
if (!workManagerScheduler.hasPendingWork(userId, WorkManagerRequestType.NOTIFICATION_SYNC)) {
Timber.d("No pending worker for push notifications found")
workManagerScheduler.submit(
SyncPendingNotificationsRequestBuilder(
sessionId = userId,
buildVersionSdkIntProvider = buildVersionSdkIntProvider,
)
)
Timber.d("Queueing notification: $notificationEventRequest")
resolverQueue.enqueue(notificationEventRequest)
}
} catch (e: Exception) {
Timber.tag(loggerTag.value).e(e, "## handleInternal() failed")
}
}
private suspend fun handleRingingCallEvent(notifiableEvent: NotifiableRingingCallEvent) {
Timber.i("## handleInternal() : Incoming call.")
elementCallEntryPoint.handleIncomingCall(
callType = CallType.RoomCall(notifiableEvent.sessionId, notifiableEvent.roomId),
eventId = notifiableEvent.eventId,
senderId = notifiableEvent.senderId,
roomName = notifiableEvent.roomName,
senderName = notifiableEvent.senderDisambiguatedDisplayName,
avatarUrl = notifiableEvent.roomAvatarUrl,
timestamp = notifiableEvent.timestamp,
expirationTimestamp = notifiableEvent.expirationTimestamp,
notificationChannelId = notificationChannels.getChannelForIncomingCall(ring = true),
textContent = notifiableEvent.description,
)
}
}
/**
* Represents the status of a [PushRequest].
*/
enum class PushRequestStatus(val value: Long) {
/**
* Either it was enqueued, and we never tried to fetch it, or it failed with a recoverable error.
*/
PENDING(0),
/**
* The event for the [PushRequest] was fetched successfully.
*/
SUCCESS(1),
/**
* Fetching the event for the [PushRequest] failed with an unrecoverable error, and it won't be retried.
*/
FAILED(2),
}

View File

@@ -14,8 +14,9 @@ import io.element.android.libraries.core.coroutine.CoroutineDispatchers
import io.element.android.libraries.featureflag.api.FeatureFlagService
import io.element.android.libraries.featureflag.api.FeatureFlags
import io.element.android.libraries.matrix.api.MatrixClientProvider
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.api.push.SyncOnNotifiableEvent
import io.element.android.libraries.matrix.api.core.RoomId
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.services.appnavstate.api.AppForegroundStateService
import kotlinx.coroutines.delay
import kotlinx.coroutines.withContext
@@ -29,7 +30,7 @@ class DefaultSyncOnNotifiableEvent(
private val appForegroundStateService: AppForegroundStateService,
private val dispatchers: CoroutineDispatchers,
) : SyncOnNotifiableEvent {
override suspend operator fun invoke(requests: List<NotificationEventRequest>) = withContext(dispatchers.io) {
override suspend operator fun invoke(requests: List<PushRequest>) = withContext(dispatchers.io) {
if (!featureFlagService.isFeatureEnabled(FeatureFlags.SyncOnPush)) {
return@withContext
}
@@ -41,8 +42,8 @@ class DefaultSyncOnNotifiableEvent(
Timber.d("Starting opportunistic room list sync | In foreground: ${appForegroundStateService.isInForeground.value}")
for ((sessionId, events) in eventsBySession) {
val client = matrixClientProvider.getOrRestore(sessionId).getOrNull() ?: continue
val roomIds = events.map { it.roomId }.distinct()
val client = matrixClientProvider.getOrRestore(SessionId(sessionId)).getOrNull() ?: continue
val roomIds = events.map { RoomId(it.roomId) }.distinct()
client.roomListService.subscribeToVisibleRooms(roomIds)

View File

@@ -0,0 +1,14 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.push
import io.element.android.libraries.push.impl.db.PushRequest
fun interface SyncOnNotifiableEvent {
suspend operator fun invoke(requests: List<PushRequest>)
}

View File

@@ -1,191 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.workmanager
import android.content.Context
import androidx.work.CoroutineWorker
import androidx.work.WorkerParameters
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.Assisted
import dev.zacsweers.metro.AssistedFactory
import dev.zacsweers.metro.AssistedInject
import dev.zacsweers.metro.ContributesIntoMap
import dev.zacsweers.metro.binding
import io.element.android.features.networkmonitor.api.NetworkMonitor
import io.element.android.features.networkmonitor.api.NetworkStatus
import io.element.android.libraries.core.extensions.runCatchingExceptions
import io.element.android.libraries.di.annotations.ApplicationContext
import io.element.android.libraries.matrix.api.auth.SessionRestorationException
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.api.push.SyncOnNotifiableEvent
import io.element.android.libraries.push.impl.notifications.NotifiableEventResolver
import io.element.android.libraries.push.impl.notifications.NotificationResolverQueue
import io.element.android.libraries.workmanager.api.WorkManagerScheduler
import io.element.android.libraries.workmanager.api.di.MetroWorkerFactory
import io.element.android.libraries.workmanager.api.di.WorkerKey
import io.element.android.services.analytics.api.AnalyticsLongRunningTransaction
import io.element.android.services.analytics.api.AnalyticsService
import io.element.android.services.analytics.api.finishLongRunningTransaction
import io.element.android.services.analytics.api.recordTransaction
import io.element.android.services.analyticsproviders.api.AnalyticsTransaction
import io.element.android.services.toolbox.api.sdk.BuildVersionSdkIntProvider
import kotlinx.coroutines.flow.MutableSharedFlow
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.withTimeoutOrNull
import timber.log.Timber
import kotlin.time.Duration.Companion.seconds
@AssistedInject
class FetchNotificationsWorker(
@Assisted params: WorkerParameters,
@ApplicationContext private val context: Context,
private val networkMonitor: NetworkMonitor,
private val eventResolver: NotifiableEventResolver,
private val queue: NotificationResolverQueue,
private val workManagerScheduler: WorkManagerScheduler,
private val syncOnNotifiableEvent: SyncOnNotifiableEvent,
private val workerDataConverter: SyncNotificationsWorkerDataConverter,
private val buildVersionSdkIntProvider: BuildVersionSdkIntProvider,
private val analyticsService: AnalyticsService,
) : CoroutineWorker(context, params) {
override suspend fun doWork(): Result {
Timber.d("FetchNotificationsWorker started")
val requests = workerDataConverter.deserialize(inputData) ?: return Result.failure()
val networkTimeoutSpans = requests.mapNotNull { request ->
val parent = analyticsService.getLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(request.eventId.value))
parent?.startChild("Waiting for network connectivity", "await_network")
}
// Wait for network to be available, but not more than 10 seconds
val hasNetwork = withTimeoutOrNull(10.seconds) {
networkMonitor.connectivity.first { it == NetworkStatus.Connected }
} != null
networkTimeoutSpans.finish()
// If there is a problem with the updated network values, report it and retry if needed
if (reportConnectivityError(requests = requests, hasNetwork = hasNetwork, isNetworkBlocked = networkMonitor.isNetworkBlocked())) {
return Result.retry()
}
val pendingAnalyticTransactions = requests.mapNotNull { request ->
analyticsService.finishLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(request.eventId.value))
val parent = analyticsService.getLongRunningTransaction(AnalyticsLongRunningTransaction.PushToNotification(request.eventId.value))
val transactionName = "WorkManager to event fetched"
parent?.startChild(transactionName)?.let { request.eventId to it }
}.toMap()
val failedSyncForSessions = mutableMapOf<SessionId, Throwable>()
val groupedRequests = requests.groupBy { it.sessionId }.toMutableMap()
for ((sessionId, notificationRequests) in groupedRequests) {
Timber.d("Processing notification requests for session $sessionId")
eventResolver.resolveEvents(sessionId, notificationRequests)
.fold(
onSuccess = { result ->
for ((_, transaction) in pendingAnalyticTransactions) {
transaction.finish()
}
// Update the resolved results in the queue
(queue.results as MutableSharedFlow).emit(requests to result)
},
onFailure = {
for ((_, transaction) in pendingAnalyticTransactions) {
transaction.attachError(it)
transaction.finish()
}
failedSyncForSessions[sessionId] = it
Timber.e(it, "Failed to resolve notification events for session $sessionId")
}
)
}
// If there were failures for whole sessions, we retry all their requests
if (failedSyncForSessions.isNotEmpty()) {
@Suppress("LoopWithTooManyJumpStatements")
for ((failedSessionId, exception) in failedSyncForSessions) {
if (exception.cause is SessionRestorationException) {
Timber.e(exception, "Session $failedSessionId could not be restored, not retrying notification fetching")
groupedRequests.remove(failedSessionId)
continue
}
val requestsToRetry = groupedRequests[failedSessionId] ?: continue
for (request in requestsToRetry) {
val failedTransaction = pendingAnalyticTransactions[request.eventId]
failedTransaction?.attachError(exception)
failedTransaction?.finish()
val eventId = request.eventId.value
val parent = analyticsService.getLongRunningTransaction(AnalyticsLongRunningTransaction.PushToNotification(eventId))
// Since we're retrying, start a new transaction
analyticsService.startLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(eventId), parent)
}
Timber.d("Re-scheduling ${requestsToRetry.size} failed notification requests for session $failedSessionId")
workManagerScheduler.submit(
SyncNotificationWorkManagerRequest(
sessionId = failedSessionId,
notificationEventRequests = requestsToRetry,
workerDataConverter = workerDataConverter,
buildVersionSdkIntProvider = buildVersionSdkIntProvider,
)
)
}
}
Timber.d("Notifications processed successfully")
analyticsService.recordTransaction("Opportunistic sync", "opportunistic_sync") {
performOpportunisticSyncIfNeeded(groupedRequests)
}
return Result.success()
}
private fun reportConnectivityError(requests: List<NotificationEventRequest>, hasNetwork: Boolean, isNetworkBlocked: Boolean): Boolean {
return if (!hasNetwork || isNetworkBlocked) {
for (request in requests) {
val eventId = request.eventId.value
analyticsService.finishLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(eventId)) {
it.putExtraData("has_network_connection", hasNetwork.toString())
it.putExtraData("is_network_blocked", isNetworkBlocked.toString())
}
val parent = analyticsService.getLongRunningTransaction(AnalyticsLongRunningTransaction.PushToNotification(eventId))
// Since we're retrying, start a new transaction
analyticsService.startLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(eventId), parent)
}
Timber.w("FetchNotificationsWorker will retry. Has network connectivity: $hasNetwork. Is network blocked: $isNetworkBlocked")
true
} else {
false
}
}
private suspend fun performOpportunisticSyncIfNeeded(
groupedRequests: Map<SessionId, List<NotificationEventRequest>>,
) {
for ((sessionId, notificationRequests) in groupedRequests) {
runCatchingExceptions {
syncOnNotifiableEvent(notificationRequests)
}.onFailure {
Timber.e(it, "Failed to sync on notifiable events for session $sessionId")
}
}
}
@ContributesIntoMap(AppScope::class, binding = binding<MetroWorkerFactory.WorkerInstanceFactory<*>>())
@WorkerKey(FetchNotificationsWorker::class)
@AssistedFactory
interface Factory : MetroWorkerFactory.WorkerInstanceFactory<FetchNotificationsWorker>
}
private fun <T : AnalyticsTransaction> Collection<T>.finish() = forEach { it.finish() }

View File

@@ -0,0 +1,239 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.workmanager
import android.content.Context
import androidx.work.CoroutineWorker
import androidx.work.WorkerParameters
import dev.zacsweers.metro.AppScope
import dev.zacsweers.metro.Assisted
import dev.zacsweers.metro.AssistedFactory
import dev.zacsweers.metro.AssistedInject
import dev.zacsweers.metro.ContributesIntoMap
import dev.zacsweers.metro.binding
import io.element.android.features.networkmonitor.api.NetworkMonitor
import io.element.android.features.networkmonitor.api.NetworkStatus
import io.element.android.libraries.core.extensions.runCatchingExceptions
import io.element.android.libraries.di.annotations.ApplicationContext
import io.element.android.libraries.matrix.api.auth.SessionRestorationException
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.api.exception.ClientException
import io.element.android.libraries.matrix.api.exception.isNetworkError
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.history.PushHistoryService
import io.element.android.libraries.push.impl.notifications.NotifiableEventResolver
import io.element.android.libraries.push.impl.notifications.NotificationResultProcessor
import io.element.android.libraries.push.impl.push.PushRequestStatus
import io.element.android.libraries.push.impl.push.SyncOnNotifiableEvent
import io.element.android.libraries.workmanager.api.di.MetroWorkerFactory
import io.element.android.libraries.workmanager.api.di.WorkerKey
import io.element.android.services.analytics.api.AnalyticsLongRunningTransaction
import io.element.android.services.analytics.api.AnalyticsService
import io.element.android.services.analytics.api.finishLongRunningTransaction
import io.element.android.services.analytics.api.recordTransaction
import io.element.android.services.analyticsproviders.api.AnalyticsTransaction
import io.element.android.services.toolbox.api.systemclock.SystemClock
import kotlinx.coroutines.flow.first
import kotlinx.coroutines.withTimeoutOrNull
import timber.log.Timber
import kotlin.time.Duration.Companion.days
import kotlin.time.Duration.Companion.seconds
import kotlin.time.Instant
@AssistedInject
class FetchPendingNotificationsWorker(
@Assisted private val params: WorkerParameters,
@ApplicationContext private val context: Context,
private val pushHistoryService: PushHistoryService,
private val networkMonitor: NetworkMonitor,
private val eventResolver: NotifiableEventResolver,
private val syncOnNotifiableEvent: SyncOnNotifiableEvent,
private val resultProcessor: NotificationResultProcessor,
private val analyticsService: AnalyticsService,
private val systemClock: SystemClock,
) : CoroutineWorker(context, params) {
override suspend fun doWork(): Result {
Timber.d("FetchNotificationsWorker started")
// RunCatching for test in debug mode
val sessionId = runCatchingExceptions {
inputData.getString(SyncPendingNotificationsRequestBuilder.SESSION_ID)?.let(::SessionId)
}.getOrNull() ?: return Result.failure()
// Fetch pending requests in the last 24 hours
val fetchSince = Instant.fromEpochMilliseconds(systemClock.epochMillis()).minus(1.days)
val requests = pushHistoryService.getPendingPushRequests(sessionId, fetchSince).getOrNull() ?: return Result.failure()
pushHistoryService.removeOldPushRequests(sessionId).onFailure {
Timber.e(it, "Could not remove outdated push requests")
}
if (requests.isEmpty()) {
Timber.d("No pending notifications to fetch, returning early")
return Result.success()
}
checkNetworkConnection(requests)?.let { failure -> return failure }
Timber.d("Fetching ${requests.size} push requests")
val pendingAnalyticTransactions = requests.mapNotNull { request ->
analyticsService.finishLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(request.eventId))
val parent = analyticsService.getLongRunningTransaction(AnalyticsLongRunningTransaction.PushToNotification(request.eventId))
val transactionName = "WorkManager to event fetched"
parent?.startChild(transactionName)?.let { request.eventId to it }
}.toMap()
Timber.d("Processing notification requests for session $sessionId")
val results = eventResolver.resolveEvents(sessionId, requests)
.fold(
onSuccess = { results ->
for ((_, transaction) in pendingAnalyticTransactions) {
transaction.finish()
}
// Update the resolved results in the queue
resultProcessor.emit(results)
results
},
onFailure = {
// This is a failure at the fetch notification setup, not a failure for a single fetch notification operation
return handleSetupError(sessionId, requests, pendingAnalyticTransactions, it)
}
)
val updatedRequests = mutableListOf<PushRequest>()
for (request in requests) {
val result = results[request] ?: continue
result.fold(
onSuccess = { updatedRequests.add(request.copy(status = PushRequestStatus.SUCCESS.value)) },
onFailure = { exception ->
if (exception is ClientException && exception.isNetworkError()) {
// Reset to pending so we can retry it later
updatedRequests.add(request.copy(status = PushRequestStatus.PENDING.value))
} else {
updatedRequests.add(request.copy(status = PushRequestStatus.FAILED.value))
}
}
)
}
Timber.d("Notifications processed successfully")
pushHistoryService.insertOrUpdatePushRequests(updatedRequests)
analyticsService.recordTransaction("Opportunistic sync", "opportunistic_sync") {
performOpportunisticSyncIfNeeded(mapOf(sessionId to requests))
}
return if (updatedRequests.any { it.status == PushRequestStatus.PENDING.value }) Result.retry() else Result.success()
}
private suspend fun performOpportunisticSyncIfNeeded(
groupedRequests: Map<SessionId, List<PushRequest>>,
) {
for ((sessionId, notificationRequests) in groupedRequests) {
runCatchingExceptions {
syncOnNotifiableEvent(notificationRequests)
}.onFailure {
Timber.e(it, "Failed to sync on notifiable events for session $sessionId")
}
}
}
private suspend fun checkNetworkConnection(requests: List<PushRequest>): Result? {
val networkTimeoutSpans = requests.mapNotNull { request ->
val parent = analyticsService.getLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(request.eventId))
parent?.startChild("Waiting for network connectivity", "await_network")
}
// Wait for network to be available, but not more than 10 seconds
val hasNetwork = withTimeoutOrNull(10.seconds) {
networkMonitor.connectivity.first { it == NetworkStatus.Connected }
} != null
networkTimeoutSpans.finish()
// If there is a problem with the updated network values, report it and retry if needed
if (reportConnectivityError(requests = requests, hasNetwork = hasNetwork, isNetworkBlocked = networkMonitor.isNetworkBlocked())) {
pushHistoryService.insertOrUpdatePushRequests(requests.map { request ->
request.copy(retries = request.retries + 1)
})
return Result.retry()
}
return null
}
private fun reportConnectivityError(requests: List<PushRequest>, hasNetwork: Boolean, isNetworkBlocked: Boolean): Boolean {
return if (!hasNetwork || isNetworkBlocked) {
for (request in requests) {
val eventId = request.eventId
analyticsService.finishLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(eventId)) {
it.putExtraData("has_network_connection", hasNetwork.toString())
it.putExtraData("is_network_blocked", isNetworkBlocked.toString())
}
val parent = analyticsService.getLongRunningTransaction(AnalyticsLongRunningTransaction.PushToNotification(eventId))
// Since we're retrying, start a new transaction
analyticsService.startLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(eventId), parent)
}
Timber.w("FetchNotificationsWorker will retry. Has network connectivity: $hasNetwork. Is network blocked: $isNetworkBlocked")
true
} else {
false
}
}
private suspend fun handleSetupError(
sessionId: SessionId,
requests: List<PushRequest>,
pendingAnalyticTransactions: Map<String, AnalyticsTransaction>,
throwable: Throwable,
): Result {
for ((_, transaction) in pendingAnalyticTransactions) {
transaction.attachError(throwable)
transaction.finish()
}
// If there were failures on the setup step and they weren't recoverable, update the requests and fail
if (throwable.cause is SessionRestorationException) {
Timber.e(throwable, "Session $sessionId could not be restored, not retrying notification fetching")
pushHistoryService.insertOrUpdatePushRequests(requests.map { request ->
request.copy(status = PushRequestStatus.FAILED.value)
})
return Result.failure()
}
// If the failure is recoverable, retry
for (request in requests) {
val failedTransaction = pendingAnalyticTransactions[request.eventId]
failedTransaction?.attachError(throwable)
failedTransaction?.finish()
val eventId = request.eventId
val parent = analyticsService.getLongRunningTransaction(AnalyticsLongRunningTransaction.PushToNotification(eventId))
// Since we're retrying, start a new transaction
analyticsService.startLongRunningTransaction(AnalyticsLongRunningTransaction.PushToWorkManager(eventId), parent)
}
Timber.d("Re-scheduling ${requests.size} failed notification requests for session $sessionId")
pushHistoryService.insertOrUpdatePushRequests(requests.map { request ->
request.copy(retries = request.retries + 1)
})
return Result.retry()
}
@ContributesIntoMap(AppScope::class, binding = binding<MetroWorkerFactory.WorkerInstanceFactory<*>>())
@WorkerKey(FetchPendingNotificationsWorker::class)
@AssistedFactory
interface Factory : MetroWorkerFactory.WorkerInstanceFactory<FetchPendingNotificationsWorker>
}
private fun <T : AnalyticsTransaction> Collection<T>.finish() = forEach { it.finish() }

View File

@@ -1,68 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.workmanager
import android.os.Build
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.OutOfQuotaPolicy
import androidx.work.WorkRequest
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.workmanager.api.WorkManagerRequest
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.workManagerTag
import io.element.android.services.toolbox.api.sdk.BuildVersionSdkIntProvider
import kotlinx.serialization.SerialName
import kotlinx.serialization.Serializable
import timber.log.Timber
import java.security.InvalidParameterException
class SyncNotificationWorkManagerRequest(
private val sessionId: SessionId,
private val notificationEventRequests: List<NotificationEventRequest>,
private val workerDataConverter: SyncNotificationsWorkerDataConverter,
private val buildVersionSdkIntProvider: BuildVersionSdkIntProvider,
) : WorkManagerRequest {
override fun build(): Result<List<WorkRequest>> {
if (notificationEventRequests.isEmpty()) {
return Result.failure(InvalidParameterException("notificationEventRequests cannot be empty"))
}
Timber.d("Scheduling ${notificationEventRequests.size} notification requests with WorkManager for $sessionId")
return workerDataConverter.serialize(notificationEventRequests).map { dataList ->
dataList.map { data ->
OneTimeWorkRequestBuilder<FetchNotificationsWorker>()
.setInputData(data)
.apply {
// Expedited workers aren't needed on Android 12 or lower:
// They force displaying a foreground sync notification for no good reason, since they sync almost immediately anyway
// See https://developer.android.com/develop/background-work/background-tasks/persistent/getting-started/define-work#backwards-compat
if (buildVersionSdkIntProvider.isAtLeast(Build.VERSION_CODES.TIRAMISU)) {
setExpedited(OutOfQuotaPolicy.RUN_AS_NON_EXPEDITED_WORK_REQUEST)
}
}
.setTraceTag(workManagerTag(sessionId, WorkManagerRequestType.NOTIFICATION_SYNC))
// TODO investigate using this instead of the resolver queue
// .setInputMerger()
.build()
}
}
}
@Serializable
data class Data(
@SerialName("session_id")
val sessionId: String,
@SerialName("room_id")
val roomId: String,
@SerialName("event_id")
val eventId: String,
@SerialName("provider_info")
val providerInfo: String,
)
}

View File

@@ -1,129 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.workmanager
import androidx.work.Data
import androidx.work.workDataOf
import dev.zacsweers.metro.Inject
import io.element.android.libraries.androidutils.json.JsonProvider
import io.element.android.libraries.core.extensions.mapCatchingExceptions
import io.element.android.libraries.core.extensions.runCatchingExceptions
import io.element.android.libraries.matrix.api.core.EventId
import io.element.android.libraries.matrix.api.core.RoomId
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.push.api.push.NotificationEventRequest
import timber.log.Timber
@Inject
class SyncNotificationsWorkerDataConverter(
private val json: JsonProvider,
) {
fun serialize(notificationEventRequests: List<NotificationEventRequest>): Result<List<Data>> {
// First try to serialize all requests at once. In the vast majority of cases this will work.
return serializeRequests(notificationEventRequests)
.map { listOf(it) }
.recoverCatching { t ->
if (t is DataForWorkManagerIsTooBig) {
// Perform serialization on sublists, workDataOf have failed because of size limit
Timber.w(t, "Failed to serialize ${notificationEventRequests.size} notification requests, split the requests per room.")
// Group the requests per rooms
val requestsSortedPerRoom = notificationEventRequests.groupBy { it.roomId }.values
// Build a list of sublist with size at most CHUNK_SIZE, and with all rooms kept together
buildList {
val currentChunk = mutableListOf<NotificationEventRequest>()
for (requests in requestsSortedPerRoom) {
if (currentChunk.size + requests.size <= CHUNK_SIZE) {
// Can add the whole room requests to the current chunk
currentChunk.addAll(requests)
} else {
// Add the current chunk
add(currentChunk.toList())
// Start a new chunk with the current room requests
currentChunk.clear()
// If a room has more requests than CHUNK_SIZE, we need to split them
requests.chunked(CHUNK_SIZE) { chunk ->
if (chunk.size == CHUNK_SIZE) {
add(chunk.toList())
} else {
currentChunk.addAll(chunk)
}
}
}
}
// Add any remaining requests
add(currentChunk.toList())
}
.filter { it.isNotEmpty() }
.also {
Timber.d("Split notification requests into ${it.size} chunks for WorkManager serialization")
it.forEach { requests ->
Timber.d(" - Chunk with ${requests.size} requests")
}
}
.mapNotNull { serializeRequests(it).getOrNull() }
} else {
throw t
}
}
}
private fun serializeRequests(notificationEventRequests: List<NotificationEventRequest>): Result<Data> {
return runCatchingExceptions { json().encodeToString(notificationEventRequests.map { it.toData() }) }
.onFailure {
Timber.e(it, "Failed to serialize notification requests")
}
.mapCatchingExceptions { str ->
// Note: workDataOf can fail if the data is too large
try {
workDataOf(REQUESTS_KEY to str)
} catch (_: IllegalStateException) {
throw DataForWorkManagerIsTooBig()
}
}
}
fun deserialize(data: Data): List<NotificationEventRequest>? {
val rawRequestsJson = data.getString(REQUESTS_KEY) ?: return null
return runCatchingExceptions {
json().decodeFromString<List<SyncNotificationWorkManagerRequest.Data>>(rawRequestsJson).map { it.toRequest() }
}.fold(
onSuccess = {
Timber.d("Deserialized ${it.size} requests")
it
},
onFailure = {
Timber.e(it, "Failed to deserialize notification requests")
null
}
)
}
companion object {
private const val REQUESTS_KEY = "requests"
internal const val CHUNK_SIZE = 20
}
}
private fun NotificationEventRequest.toData(): SyncNotificationWorkManagerRequest.Data {
return SyncNotificationWorkManagerRequest.Data(
sessionId = sessionId.value,
roomId = roomId.value,
eventId = eventId.value,
providerInfo = providerInfo,
)
}
private fun SyncNotificationWorkManagerRequest.Data.toRequest(): NotificationEventRequest {
return NotificationEventRequest(
sessionId = SessionId(sessionId),
roomId = RoomId(roomId),
eventId = EventId(eventId),
providerInfo = providerInfo,
)
}

View File

@@ -0,0 +1,51 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.workmanager
import android.os.Build
import androidx.work.ExistingWorkPolicy
import androidx.work.OneTimeWorkRequestBuilder
import androidx.work.OutOfQuotaPolicy
import androidx.work.workDataOf
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.workmanager.api.WorkManagerRequestBuilder
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.WorkManagerRequestWrapper
import io.element.android.libraries.workmanager.api.WorkManagerWorkerType
import io.element.android.libraries.workmanager.api.workManagerTag
import io.element.android.services.toolbox.api.sdk.BuildVersionSdkIntProvider
class SyncPendingNotificationsRequestBuilder(
private val sessionId: SessionId,
private val buildVersionSdkIntProvider: BuildVersionSdkIntProvider,
) : WorkManagerRequestBuilder {
companion object {
const val SESSION_ID = "session_id"
}
override suspend fun build(): Result<List<WorkManagerRequestWrapper>> {
val type = WorkManagerWorkerType.Unique(
name = workManagerTag(sessionId = sessionId, requestType = WorkManagerRequestType.NOTIFICATION_SYNC),
policy = ExistingWorkPolicy.APPEND_OR_REPLACE,
)
val request = OneTimeWorkRequestBuilder<FetchPendingNotificationsWorker>()
.setInputData(workDataOf(SESSION_ID to sessionId.value))
.apply {
// Expedited workers aren't needed on Android 12 or lower:
// They force displaying a foreground sync notification for no good reason, since they sync almost immediately anyway
// See https://developer.android.com/develop/background-work/background-tasks/persistent/getting-started/define-work#backwards-compat
if (buildVersionSdkIntProvider.isAtLeast(Build.VERSION_CODES.TIRAMISU)) {
setExpedited(OutOfQuotaPolicy.RUN_AS_NON_EXPEDITED_WORK_REQUEST)
}
}
.setTraceTag(workManagerTag(sessionId, WorkManagerRequestType.NOTIFICATION_SYNC))
.build()
return Result.success(listOf(WorkManagerRequestWrapper(request, type)))
}
}

Binary file not shown.

Binary file not shown.

View File

@@ -17,6 +17,5 @@ INSERT INTO PushHistory VALUES ?;
removeAll:
DELETE FROM PushHistory;
-- add query to keep only the last x entries
removeOldest:
DELETE FROM PushHistory WHERE rowid NOT IN (SELECT rowid FROM PushHistory ORDER BY pushDate DESC LIMIT ?);

View File

@@ -0,0 +1,24 @@
CREATE TABLE PushRequest (
pushDate INTEGER NOT NULL,
providerInfo TEXT NOT NULL,
eventId TEXT NOT NULL,
roomId TEXT NOT NULL,
sessionId TEXT NOT NULL,
status INTEGER NOT NULL DEFAULT 0,
retries INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY(sessionId, eventId)
);
CREATE INDEX PushRequestSessionAndStatus ON PushRequest (sessionId, status);
selectAllPendingForSession:
SELECT * FROM PushRequest WHERE status = 0 AND sessionId = ? AND pushDate > ? ORDER BY pushDate ASC;
insertPushRequest:
INSERT OR REPLACE INTO PushRequest VALUES ?;
removeAll:
DELETE FROM PushRequest;
removeOldest:
DELETE FROM PushRequest WHERE rowid NOT IN (SELECT rowid FROM PushRequest ORDER BY pushDate DESC LIMIT ?);

View File

@@ -0,0 +1,14 @@
-- Migrate DB from version 1
CREATE TABLE PushRequest (
pushDate INTEGER NOT NULL,
providerInfo TEXT NOT NULL,
eventId TEXT NOT NULL,
roomId TEXT NOT NULL,
sessionId TEXT NOT NULL,
status INTEGER NOT NULL DEFAULT 0,
retries INTEGER NOT NULL DEFAULT 0,
PRIMARY KEY(sessionId, eventId)
);
CREATE INDEX PushRequestSessionAndStatus ON PushRequest (sessionId, status);

View File

@@ -11,7 +11,9 @@ package io.element.android.libraries.push.impl.history
import io.element.android.libraries.matrix.api.core.EventId
import io.element.android.libraries.matrix.api.core.RoomId
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.tests.testutils.lambda.lambdaError
import kotlin.time.Instant
class FakePushHistoryService(
private val onPushReceivedResult: (
@@ -22,9 +24,13 @@ class FakePushHistoryService(
Boolean,
Boolean,
String?
) -> Unit = { _, _, _, _, _, _, _ -> lambdaError() }
) -> Unit = { _, _, _, _, _, _, _ -> lambdaError() },
private val enqueuePushRequest: (PushRequest) -> Result<Unit> = { lambdaError() },
private val replacePushRequests: (List<PushRequest>) -> Result<Unit> = { lambdaError() },
private val getPendingPushRequests: (SessionId, Instant?) -> Result<List<PushRequest>> = { _, _ -> lambdaError() },
private val removeOldPushRequests: (SessionId) -> Result<Unit> = { lambdaError() },
) : PushHistoryService {
override fun onPushReceived(
override fun onPushResult(
providerInfo: String,
eventId: EventId?,
roomId: RoomId?,
@@ -43,4 +49,20 @@ class FakePushHistoryService(
comment
)
}
override suspend fun insertOrUpdatePushRequest(pushRequest: PushRequest): Result<Unit> {
return enqueuePushRequest.invoke(pushRequest)
}
override suspend fun insertOrUpdatePushRequests(pushRequests: List<PushRequest>): Result<Unit> {
return replacePushRequests.invoke(pushRequests)
}
override suspend fun getPendingPushRequests(sessionId: SessionId, since: Instant?): Result<List<PushRequest>> {
return getPendingPushRequests.invoke(sessionId, since)
}
override suspend fun removeOldPushRequests(sessionId: SessionId): Result<Unit> {
return removeOldPushRequests.invoke(sessionId)
}
}

View File

@@ -47,9 +47,10 @@ import io.element.android.libraries.matrix.test.FakeMatrixClientProvider
import io.element.android.libraries.matrix.test.notification.FakeNotificationService
import io.element.android.libraries.matrix.test.notification.aNotificationData
import io.element.android.libraries.matrix.test.permalink.FakePermalinkParser
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.notifications.fake.FakeNotificationMediaRepo
import io.element.android.libraries.push.impl.notifications.fixtures.aNotifiableMessageEvent
import io.element.android.libraries.push.impl.notifications.fixtures.aPushRequest
import io.element.android.libraries.push.impl.notifications.model.FallbackNotifiableEvent
import io.element.android.libraries.push.impl.notifications.model.InviteNotifiableEvent
import io.element.android.libraries.push.impl.notifications.model.NotifiableMessageEvent
@@ -71,7 +72,7 @@ class DefaultNotifiableEventResolverTest {
@Test
fun `resolve event no session`() = runTest {
val sut = createDefaultNotifiableEventResolver(notificationService = null)
val result = sut.resolveEvents(A_SESSION_ID, listOf(NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")))
val result = sut.resolveEvents(A_SESSION_ID, listOf(aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")))
assertThat(result.isFailure).isTrue()
}
@@ -80,7 +81,7 @@ class DefaultNotifiableEventResolverTest {
val sut = createDefaultNotifiableEventResolver(
notificationResult = Result.failure(AN_EXCEPTION)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
assertThat(result.isFailure).isTrue()
}
@@ -90,7 +91,7 @@ class DefaultNotifiableEventResolverTest {
val sut = createDefaultNotifiableEventResolver(
notificationResult = Result.success(mapOf(AN_EVENT_ID to Result.failure(AN_EXCEPTION)))
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
assertThat(result.getEvent(request)?.isFailure).isTrue()
}
@@ -109,7 +110,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Hello world")
@@ -133,7 +134,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Hello world", hasMentionOrReply = true)
@@ -161,7 +162,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Hello world")
@@ -189,7 +190,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Hello world")
@@ -211,7 +212,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Audio")
@@ -233,7 +234,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Video")
@@ -255,7 +256,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Voice message")
@@ -277,7 +278,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Image")
@@ -299,7 +300,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Sticker")
@@ -321,7 +322,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "File")
@@ -343,7 +344,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Location")
@@ -365,7 +366,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Notice")
@@ -387,7 +388,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "* Bob is happy")
@@ -409,7 +410,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
aNotifiableMessageEvent(body = "Poll: A question")
@@ -432,7 +433,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
assertThat(result.getEvent(request)?.getOrNull()).isNull()
}
@@ -451,7 +452,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
InviteNotifiableEvent(
@@ -490,7 +491,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
InviteNotifiableEvent(
@@ -527,7 +528,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
InviteNotifiableEvent(
@@ -565,7 +566,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
InviteNotifiableEvent(
@@ -605,7 +606,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
InviteNotifiableEvent(
@@ -642,7 +643,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
assertThat(result.getEvent(request)?.getOrNull()).isNull()
}
@@ -654,7 +655,7 @@ class DefaultNotifiableEventResolverTest {
mapOf(AN_EVENT_ID to Result.success(aNotificationData(content = NotificationContent.MessageLike.RoomEncrypted)))
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
FallbackNotifiableEvent(
@@ -680,7 +681,7 @@ class DefaultNotifiableEventResolverTest {
mapOf(AN_EVENT_ID to Result.failure(NotificationResolverException.EventNotFound))
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
assertThat(result.getEvent(request)).isEqualTo(Result.failure<ResolvedPushEvent?>(NotificationResolverException.EventNotFound))
}
@@ -698,7 +699,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
val expectedResult = ResolvedPushEvent.Event(
NotifiableMessageEvent(
@@ -766,7 +767,7 @@ class DefaultNotifiableEventResolverTest {
)
)
callNotificationEventResolver.resolveEventLambda = { _, _, _ -> Result.success(expectedResult.notifiableEvent) }
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
assertThat(result.getEvent(request)).isEqualTo(Result.success(expectedResult))
}
@@ -791,7 +792,7 @@ class DefaultNotifiableEventResolverTest {
redactedEventId = AN_EVENT_ID_2,
reason = A_REDACTION_REASON,
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
assertThat(result.getEvent(request)).isEqualTo(Result.success(expectedResult))
}
@@ -810,7 +811,7 @@ class DefaultNotifiableEventResolverTest {
)
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
assertThat(result.getEvent(request)?.getOrNull()).isNull()
}
@@ -857,13 +858,13 @@ class DefaultNotifiableEventResolverTest {
mapOf(AN_EVENT_ID to Result.success(aNotificationData(content = content)))
)
)
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val request = aPushRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, "firebase")
val result = sut.resolveEvents(A_SESSION_ID, listOf(request))
assertThat(result.getEvent(request)?.getOrNull()).isNull()
}
private fun Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>>.getEvent(
request: NotificationEventRequest
private fun Result<Map<PushRequest, Result<ResolvedPushEvent>>>.getEvent(
request: PushRequest
): Result<ResolvedPushEvent>? {
return getOrNull()?.get(request)
}

View File

@@ -0,0 +1,310 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.notifications
import com.google.common.truth.Truth.assertThat
import io.element.android.features.call.api.CallType
import io.element.android.features.call.test.FakeElementCallEntryPoint
import io.element.android.libraries.featureflag.test.FakeFeatureFlagService
import io.element.android.libraries.matrix.api.core.EventId
import io.element.android.libraries.matrix.api.core.RoomId
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.api.core.UserId
import io.element.android.libraries.matrix.api.exception.NotificationResolverException
import io.element.android.libraries.matrix.api.timeline.item.event.EventType
import io.element.android.libraries.matrix.test.AN_EVENT_ID
import io.element.android.libraries.matrix.test.AN_EVENT_ID_2
import io.element.android.libraries.matrix.test.A_ROOM_ID
import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.matrix.test.A_USER_ID
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.history.FakePushHistoryService
import io.element.android.libraries.push.impl.notifications.channels.FakeNotificationChannels
import io.element.android.libraries.push.impl.notifications.fixtures.aFallbackNotifiableEvent
import io.element.android.libraries.push.impl.notifications.fixtures.aNotifiableCallEvent
import io.element.android.libraries.push.impl.notifications.fixtures.aNotifiableMessageEvent
import io.element.android.libraries.push.impl.notifications.fixtures.aPushRequest
import io.element.android.libraries.push.impl.notifications.model.NotifiableEvent
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
import io.element.android.libraries.push.impl.push.FakeMutableBatteryOptimizationStore
import io.element.android.libraries.push.impl.push.FakeOnNotifiableEventReceived
import io.element.android.libraries.push.impl.push.FakeOnRedactedEventReceived
import io.element.android.libraries.push.impl.push.SyncOnNotifiableEvent
import io.element.android.libraries.pushstore.test.userpushstore.FakeUserPushStoreFactory
import io.element.android.services.toolbox.test.systemclock.FakeSystemClock
import io.element.android.tests.testutils.lambda.any
import io.element.android.tests.testutils.lambda.lambdaRecorder
import io.element.android.tests.testutils.lambda.value
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.runCurrent
import kotlinx.coroutines.test.runTest
import org.junit.Test
import kotlin.time.Duration.Companion.milliseconds
@OptIn(ExperimentalCoroutinesApi::class)
class DefaultNotificationResultProcessorTest {
@Test
fun `when not able to resolve the event, the banner to disable battery optimization will be displayed`() {
`test notification resolver failure`(
notificationResolveResult = { requests: List<PushRequest> ->
Result.success(
requests.associateWith { Result.failure(NotificationResolverException.UnknownError("Unable to resolve event")) }
)
},
shouldSetOptimizationBatteryBanner = true,
)
}
private fun `test notification resolver failure`(
notificationResolveResult: (List<PushRequest>) -> Result<Map<PushRequest, Result<ResolvedPushEvent>>>,
shouldSetOptimizationBatteryBanner: Boolean,
) {
runTest {
val notifiableEventResult =
lambdaRecorder<SessionId, List<PushRequest>, Result<Map<PushRequest, Result<ResolvedPushEvent>>>> { _, requests ->
notificationResolveResult(requests)
}
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val showBatteryOptimizationBannerResult = lambdaRecorder<Unit> {}
val processor = createDefaultNotificationResultProcessor(
mutableBatteryOptimizationStore = FakeMutableBatteryOptimizationStore(
showBatteryOptimizationBannerResult = showBatteryOptimizationBannerResult,
),
pushHistoryService = pushHistoryService,
)
runningProcessor(processor) {
emit(mapOf(aPushRequest() to Result.failure(IllegalStateException("boom"))))
}
notifiableEventResult.assertions()
.isNeverCalled()
onPushReceivedResult.assertions()
.isCalledOnce()
.with(any(), value(AN_EVENT_ID), value(A_ROOM_ID), value(A_USER_ID), value(false), value(true), any())
showBatteryOptimizationBannerResult.assertions().let {
if (shouldSetOptimizationBatteryBanner) {
it.isCalledOnce()
} else {
it.isNeverCalled()
}
}
}
}
@Test
fun `when ringing call PushData is received, the incoming call will be handled`() = runTest {
val handleIncomingCallLambda = lambdaRecorder<
CallType.RoomCall,
EventId,
UserId,
String?,
String?,
String?,
String,
String?,
Unit,
> { _, _, _, _, _, _, _, _ -> }
val elementCallEntryPoint = FakeElementCallEntryPoint(handleIncomingCallResult = handleIncomingCallLambda)
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val processor = createDefaultNotificationResultProcessor(
elementCallEntryPoint = elementCallEntryPoint,
onNotifiableEventsReceived = onNotifiableEventsReceived,
pushHistoryService = pushHistoryService,
)
runningProcessor(processor) {
emit(mapOf(aPushRequest() to Result.success(ResolvedPushEvent.Event(aNotifiableCallEvent()))))
}
advanceTimeBy(300.milliseconds)
handleIncomingCallLambda.assertions().isCalledOnce()
onNotifiableEventsReceived.assertions().isNeverCalled()
onPushReceivedResult.assertions().isCalledOnce()
}
@Test
fun `when notify call PushData is received, the incoming call will be treated as a normal notification`() = runTest {
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
val handleIncomingCallLambda = lambdaRecorder<
CallType.RoomCall,
EventId,
UserId,
String?,
String?,
String?,
String,
String?,
Unit,
> { _, _, _, _, _, _, _, _ -> }
val elementCallEntryPoint = FakeElementCallEntryPoint(handleIncomingCallResult = handleIncomingCallLambda)
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val processor = createDefaultNotificationResultProcessor(
elementCallEntryPoint = elementCallEntryPoint,
onNotifiableEventsReceived = onNotifiableEventsReceived,
pushHistoryService = pushHistoryService,
)
runningProcessor(processor) {
processor.emit(mapOf(aPushRequest() to Result.success(ResolvedPushEvent.Event(aNotifiableMessageEvent(type = EventType.RTC_NOTIFICATION)))))
}
advanceTimeBy(300.milliseconds)
handleIncomingCallLambda.assertions().isNeverCalled()
onNotifiableEventsReceived.assertions().isCalledOnce()
onPushReceivedResult.assertions().isCalledOnce()
}
@Test
fun `when notify call PushData is received, the incoming call will be treated as a normal notification even if notification are disabled`() = runTest {
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
val handleIncomingCallLambda = lambdaRecorder<
CallType.RoomCall,
EventId,
UserId,
String?,
String?,
String?,
String,
String?,
Unit,
> { _, _, _, _, _, _, _, _ -> }
val elementCallEntryPoint = FakeElementCallEntryPoint(handleIncomingCallResult = handleIncomingCallLambda)
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val processor = createDefaultNotificationResultProcessor(
elementCallEntryPoint = elementCallEntryPoint,
onNotifiableEventsReceived = onNotifiableEventsReceived,
pushHistoryService = pushHistoryService,
)
runningProcessor(processor) {
processor.emit(mapOf(aPushRequest() to Result.success(ResolvedPushEvent.Event(aNotifiableCallEvent()))))
}
advanceTimeBy(300.milliseconds)
handleIncomingCallLambda.assertions().isCalledOnce()
onNotifiableEventsReceived.assertions().isNeverCalled()
onPushReceivedResult.assertions().isCalledOnce()
}
@Test
fun `when a redaction is received, the onRedactedEventReceived is informed`() = runTest {
val aRedaction = ResolvedPushEvent.Redaction(
sessionId = A_SESSION_ID,
roomId = A_ROOM_ID,
redactedEventId = AN_EVENT_ID_2,
reason = null
)
val onRedactedEventReceived = lambdaRecorder<List<ResolvedPushEvent.Redaction>, Unit> { }
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val processor = createDefaultNotificationResultProcessor(
onRedactedEventReceived = onRedactedEventReceived,
pushHistoryService = pushHistoryService,
)
runningProcessor(processor) {
emit(mapOf(aPushRequest() to Result.success(aRedaction)))
}
advanceTimeBy(300.milliseconds)
onRedactedEventReceived.assertions().isCalledOnce()
.with(value(listOf(aRedaction)))
onPushReceivedResult.assertions()
.isCalledOnce()
}
@Test
fun `when receiving a fallback event, we notify the push history service about it not being resolved`() = runTest {
val aNotifiableFallbackEvent = aFallbackNotifiableEvent()
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
var receivedFallbackEvent = false
val onPushReceivedResult =
lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, isResolved, _, comment ->
receivedFallbackEvent = !isResolved && comment == "Unable to resolve event: ${aNotifiableFallbackEvent.cause}"
}
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val processor = createDefaultNotificationResultProcessor(
onNotifiableEventsReceived = onNotifiableEventsReceived,
pushHistoryService = pushHistoryService,
)
runningProcessor(processor) {
emit(mapOf(aPushRequest() to Result.success(ResolvedPushEvent.Event(aNotifiableFallbackEvent))))
}
advanceTimeBy(300.milliseconds)
onNotifiableEventsReceived.assertions().isCalledOnce()
assertThat(receivedFallbackEvent).isTrue()
}
private suspend fun TestScope.runningProcessor(processor: NotificationResultProcessor, block: suspend NotificationResultProcessor.() -> Unit) {
processor.start()
runCurrent()
block(processor)
runCurrent()
processor.stop()
}
private fun TestScope.createDefaultNotificationResultProcessor(
systemClock: FakeSystemClock = FakeSystemClock(),
pushHistoryService: FakePushHistoryService = FakePushHistoryService(),
mutableBatteryOptimizationStore: FakeMutableBatteryOptimizationStore = FakeMutableBatteryOptimizationStore(),
fallbackNotificationFactory: FallbackNotificationFactory = FallbackNotificationFactory(systemClock),
userPushStoreFactory: FakeUserPushStoreFactory = FakeUserPushStoreFactory(),
onRedactedEventReceived: (List<ResolvedPushEvent.Redaction>) -> Unit = {},
onNotifiableEventsReceived: (List<NotifiableEvent>) -> Unit = {},
featureFlagService: FakeFeatureFlagService = FakeFeatureFlagService(),
syncOnNotifiableEvent: SyncOnNotifiableEvent = {},
elementCallEntryPoint: FakeElementCallEntryPoint = FakeElementCallEntryPoint(),
notificationChannels: FakeNotificationChannels = FakeNotificationChannels(),
coroutineScope: CoroutineScope = backgroundScope,
) = DefaultNotificationResultProcessor(
pushHistoryService = pushHistoryService,
batteryOptimizationStore = mutableBatteryOptimizationStore,
fallbackNotificationFactory = fallbackNotificationFactory,
userPushStoreFactory = userPushStoreFactory,
onRedactedEventReceived = FakeOnRedactedEventReceived(onRedactedEventReceived),
onNotifiableEventReceived = FakeOnNotifiableEventReceived(onNotifiableEventsReceived),
featureFlagService = featureFlagService,
syncOnNotifiableEvent = syncOnNotifiableEvent,
elementCallEntryPoint = elementCallEntryPoint,
notificationChannels = notificationChannels,
coroutineScope = coroutineScope,
)
}

View File

@@ -9,18 +9,18 @@
package io.element.android.libraries.push.impl.notifications
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
import io.element.android.tests.testutils.lambda.lambdaError
class FakeNotifiableEventResolver(
private val resolveEventsResult: (SessionId, List<NotificationEventRequest>) -> Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>> =
private val resolveEventsResult: (SessionId, List<PushRequest>) -> Result<Map<PushRequest, Result<ResolvedPushEvent>>> =
{ _, _ -> lambdaError() }
) : NotifiableEventResolver {
override suspend fun resolveEvents(
sessionId: SessionId,
notificationEventRequests: List<NotificationEventRequest>
): Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>> {
notificationEventRequests: List<PushRequest>
): Result<Map<PushRequest, Result<ResolvedPushEvent>>> {
return resolveEventsResult(sessionId, notificationEventRequests)
}
}

View File

@@ -0,0 +1,30 @@
/*
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.notifications
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
import io.element.android.tests.testutils.lambda.lambdaError
class FakeNotificationResultProcessor(
private val emit: (Map<PushRequest, Result<ResolvedPushEvent>>) -> Unit = { lambdaError() },
private val start: () -> Unit = { lambdaError() },
private val stop: () -> Unit = { lambdaError() },
) : NotificationResultProcessor {
override suspend fun emit(results: Map<PushRequest, Result<ResolvedPushEvent>>) {
return emit.invoke(results)
}
override fun start() {
start.invoke()
}
override fun stop() {
stop.invoke()
}
}

View File

@@ -1,6 +1,5 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
* Copyright (c) 2026 Element Creations Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
@@ -14,16 +13,22 @@ import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.test.AN_EVENT_ID
import io.element.android.libraries.matrix.test.A_ROOM_ID
import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.push.PushRequestStatus
fun aNotificationEventRequest(
fun aPushRequest(
sessionId: SessionId = A_SESSION_ID,
roomId: RoomId = A_ROOM_ID,
eventId: EventId = AN_EVENT_ID,
providerInfo: String = "providerInfo",
) = NotificationEventRequest(
sessionId = sessionId,
roomId = roomId,
eventId = eventId,
providerInfo: String = "firebase",
status: PushRequestStatus = PushRequestStatus.PENDING,
retries: Int = 0,
) = PushRequest(
pushDate = System.currentTimeMillis(),
providerInfo = providerInfo,
eventId = eventId.value,
roomId = roomId.value,
sessionId = sessionId.value,
status = status.value,
retries = retries.toLong(),
)

View File

@@ -11,65 +11,38 @@
package io.element.android.libraries.push.impl.push
import app.cash.turbine.test
import com.google.common.truth.Truth.assertThat
import io.element.android.features.call.api.CallType
import io.element.android.features.call.test.FakeElementCallEntryPoint
import io.element.android.libraries.androidutils.json.DefaultJsonProvider
import io.element.android.libraries.core.meta.BuildMeta
import io.element.android.libraries.featureflag.api.FeatureFlags
import io.element.android.libraries.featureflag.test.FakeFeatureFlagService
import io.element.android.libraries.matrix.api.core.EventId
import io.element.android.libraries.matrix.api.core.RoomId
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.api.core.UserId
import io.element.android.libraries.matrix.api.exception.NotificationResolverException
import io.element.android.libraries.matrix.api.notification.RtcNotificationType
import io.element.android.libraries.matrix.api.timeline.item.event.EventType
import io.element.android.libraries.matrix.test.AN_EVENT_ID
import io.element.android.libraries.matrix.test.AN_EVENT_ID_2
import io.element.android.libraries.matrix.test.A_ROOM_ID
import io.element.android.libraries.matrix.test.A_SECRET
import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.matrix.test.A_USER_ID
import io.element.android.libraries.matrix.test.core.aBuildMeta
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.api.push.SyncOnNotifiableEvent
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.history.FakePushHistoryService
import io.element.android.libraries.push.impl.history.PushHistoryService
import io.element.android.libraries.push.impl.notifications.DefaultNotificationResolverQueue
import io.element.android.libraries.push.impl.notifications.FakeNotifiableEventResolver
import io.element.android.libraries.push.impl.notifications.FallbackNotificationFactory
import io.element.android.libraries.push.impl.notifications.channels.FakeNotificationChannels
import io.element.android.libraries.push.impl.notifications.fixtures.aFallbackNotifiableEvent
import io.element.android.libraries.push.impl.notifications.fixtures.aNotifiableCallEvent
import io.element.android.libraries.push.impl.notifications.fixtures.aNotifiableMessageEvent
import io.element.android.libraries.push.impl.notifications.model.NotifiableEvent
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
import io.element.android.libraries.push.impl.notifications.FakeNotificationResultProcessor
import io.element.android.libraries.push.impl.test.DefaultTestPush
import io.element.android.libraries.push.impl.troubleshoot.DiagnosticPushHandler
import io.element.android.libraries.push.impl.workmanager.SyncNotificationsWorkerDataConverter
import io.element.android.libraries.pushproviders.api.PushData
import io.element.android.libraries.pushstore.api.UserPushStore
import io.element.android.libraries.pushstore.api.clientsecret.PushClientSecret
import io.element.android.libraries.pushstore.test.userpushstore.FakeUserPushStore
import io.element.android.libraries.pushstore.test.userpushstore.FakeUserPushStoreFactory
import io.element.android.libraries.pushstore.test.userpushstore.clientsecret.FakePushClientSecret
import io.element.android.libraries.workmanager.api.WorkManagerRequest
import io.element.android.libraries.workmanager.api.WorkManagerRequestBuilder
import io.element.android.libraries.workmanager.test.FakeWorkManagerScheduler
import io.element.android.services.analytics.test.FakeAnalyticsService
import io.element.android.services.toolbox.test.sdk.FakeBuildVersionSdkIntProvider
import io.element.android.services.toolbox.test.systemclock.FakeSystemClock
import io.element.android.tests.testutils.lambda.any
import io.element.android.tests.testutils.lambda.lambdaError
import io.element.android.tests.testutils.lambda.lambdaRecorder
import io.element.android.tests.testutils.lambda.matching
import io.element.android.tests.testutils.lambda.value
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.runTest
import org.junit.Test
import java.time.Instant
import kotlin.time.Duration.Companion.milliseconds
private const val A_PUSHER_INFO = "info"
@@ -96,84 +69,36 @@ class DefaultPushHandlerTest {
}
@Test
fun `when classical PushData is received, the notification drawer is informed`() = runTest {
val aNotifiableMessageEvent = aNotifiableMessageEvent()
val notifiableEventResult =
lambdaRecorder<SessionId, List<NotificationEventRequest>, Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>>> { _, _ ->
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, A_PUSHER_INFO)
Result.success(mapOf(request to Result.success(ResolvedPushEvent.Event(aNotifiableMessageEvent))))
}
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
fun `when classical PushData is received, the work is scheduled`() = runTest {
val incrementPushCounterResult = lambdaRecorder<Unit> {}
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val aPushData = PushData(
eventId = AN_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val enqueuePushRequestResult = lambdaRecorder<PushRequest, Result<Unit>> { Result.success(Unit) }
val pushHistoryService = FakePushHistoryService(
enqueuePushRequest = enqueuePushRequestResult,
)
val submitWorkLambda = lambdaRecorder<WorkManagerRequestBuilder, Unit> {}
val workManagerScheduler = FakeWorkManagerScheduler(submitLambda = submitWorkLambda)
val defaultPushHandler = createDefaultPushHandler(
onNotifiableEventsReceived = onNotifiableEventsReceived,
notifiableEventsResult = notifiableEventResult,
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { A_USER_ID }
),
incrementPushCounterResult = incrementPushCounterResult,
workManagerScheduler = workManagerScheduler,
pushHistoryService = pushHistoryService,
)
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
advanceTimeBy(300.milliseconds)
incrementPushCounterResult.assertions()
submitWorkLambda.assertions()
.isCalledOnce()
notifiableEventResult.assertions()
.isCalledOnce()
.with(value(A_USER_ID), any())
onNotifiableEventsReceived.assertions()
.isCalledOnce()
.with(value(listOf(aNotifiableMessageEvent)))
onPushReceivedResult.assertions()
.isCalledOnce()
}
@Test
fun `when classical PushData is received and the workmanager flag is enabled, the work is scheduled`() = runTest {
val aNotifiableMessageEvent = aNotifiableMessageEvent()
val notifiableEventResult =
lambdaRecorder<SessionId, List<NotificationEventRequest>, Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>>> { _, _ ->
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, A_PUSHER_INFO)
Result.success(mapOf(request to Result.success(ResolvedPushEvent.Event(aNotifiableMessageEvent))))
}
val incrementPushCounterResult = lambdaRecorder<Unit> {}
val aPushData = PushData(
eventId = AN_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val featureFlagService = FakeFeatureFlagService(mapOf(FeatureFlags.SyncNotificationsWithWorkManager.key to true))
val submitWorkLambda = lambdaRecorder<WorkManagerRequest, Unit> {}
val workManagerScheduler = FakeWorkManagerScheduler(submitLambda = submitWorkLambda)
val defaultPushHandler = createDefaultPushHandler(
notifiableEventsResult = notifiableEventResult,
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { A_USER_ID }
),
incrementPushCounterResult = incrementPushCounterResult,
featureFlagService = featureFlagService,
workManagerScheduler = workManagerScheduler,
)
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
advanceTimeBy(300.milliseconds)
submitWorkLambda.assertions().isCalledOnce()
incrementPushCounterResult.assertions()
.isCalledOnce()
@@ -182,13 +107,6 @@ class DefaultPushHandlerTest {
@Test
fun `when classical PushData is received, but notifications are disabled, nothing happen`() =
runTest {
val aNotifiableMessageEvent = aNotifiableMessageEvent()
val notifiableEventResult =
lambdaRecorder<SessionId, List<NotificationEventRequest>, Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>>> { _, _ ->
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, A_PUSHER_INFO)
Result.success(mapOf(request to Result.success(ResolvedPushEvent.Event(aNotifiableMessageEvent))))
}
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
val incrementPushCounterResult = lambdaRecorder<Unit> {}
val aPushData = PushData(
eventId = AN_EVENT_ID,
@@ -197,12 +115,15 @@ class DefaultPushHandlerTest {
clientSecret = A_SECRET,
)
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val enqueuePushRequestResult = lambdaRecorder<PushRequest, Result<Unit>> { Result.success(Unit) }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
enqueuePushRequest = enqueuePushRequestResult,
)
val submitWorkLambda = lambdaRecorder<WorkManagerRequestBuilder, Unit> {}
val workManagerScheduler = FakeWorkManagerScheduler(submitLambda = submitWorkLambda)
val defaultPushHandler = createDefaultPushHandler(
onNotifiableEventsReceived = onNotifiableEventsReceived,
notifiableEventsResult = notifiableEventResult,
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { A_USER_ID }
),
@@ -211,31 +132,24 @@ class DefaultPushHandlerTest {
},
incrementPushCounterResult = incrementPushCounterResult,
pushHistoryService = pushHistoryService,
workManagerScheduler = workManagerScheduler,
)
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
advanceTimeBy(300.milliseconds)
submitWorkLambda.assertions()
.isNeverCalled()
enqueuePushRequestResult.assertions()
.isNeverCalled()
incrementPushCounterResult.assertions()
.isCalledOnce()
notifiableEventResult.assertions()
.isCalledOnce()
onNotifiableEventsReceived.assertions()
.isNeverCalled()
onPushReceivedResult.assertions()
.isCalledOnce()
.isNeverCalled()
}
@Test
fun `when PushData is received, but client secret is not known, nothing happen`() =
runTest {
val aNotifiableMessageEvent = aNotifiableMessageEvent()
val notifiableEventResult =
lambdaRecorder<SessionId, List<NotificationEventRequest>, Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>>> { _, _ ->
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, A_PUSHER_INFO)
Result.success(mapOf(request to Result.success(ResolvedPushEvent.Event(aNotifiableMessageEvent))))
}
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
fun `when PushData is received, but client secret is not known, nothing happen`() = runTest {
val incrementPushCounterResult = lambdaRecorder<Unit> {}
val aPushData = PushData(
eventId = AN_EVENT_ID,
@@ -247,477 +161,85 @@ class DefaultPushHandlerTest {
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val submitWorkLambda = lambdaRecorder<WorkManagerRequestBuilder, Unit> {}
val workManagerScheduler = FakeWorkManagerScheduler(submitLambda = submitWorkLambda)
val defaultPushHandler = createDefaultPushHandler(
onNotifiableEventsReceived = onNotifiableEventsReceived,
notifiableEventsResult = notifiableEventResult,
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { null }
),
incrementPushCounterResult = incrementPushCounterResult,
pushHistoryService = pushHistoryService,
workManagerScheduler = workManagerScheduler,
)
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
submitWorkLambda.assertions()
.isNeverCalled()
incrementPushCounterResult.assertions()
.isCalledOnce()
notifiableEventResult.assertions()
.isNeverCalled()
onNotifiableEventsReceived.assertions()
.isNeverCalled()
onPushReceivedResult.assertions()
.isCalledOnce()
}
@Test
fun `when classical PushData is received, but a failure occurs (session not found), nothing happen`() {
`test notification resolver failure`(
notificationResolveResult = { _ ->
Result.failure(NotificationResolverException.UnknownError("Unable to restore session"))
},
shouldSetOptimizationBatteryBanner = false,
fun `when diagnostic PushData is received, the diagnostic push handler is informed`() = runTest {
val aPushData = PushData(
eventId = DefaultTestPush.TEST_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
}
@Test
fun `when classical PushData is received, but not able to resolve the event, the banner to disable battery optimization will be displayed`() {
`test notification resolver failure`(
notificationResolveResult = { requests: List<NotificationEventRequest> ->
Result.success(
requests.associateWith { Result.failure(NotificationResolverException.UnknownError("Unable to resolve event")) }
)
},
shouldSetOptimizationBatteryBanner = true,
val diagnosticPushHandler = DiagnosticPushHandler()
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
}
private fun `test notification resolver failure`(
notificationResolveResult: (List<NotificationEventRequest>) -> Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>>,
shouldSetOptimizationBatteryBanner: Boolean,
) {
runTest {
val notifiableEventResult =
lambdaRecorder<SessionId, List<NotificationEventRequest>, Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>>> { _, requests ->
notificationResolveResult(requests)
}
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
val incrementPushCounterResult = lambdaRecorder<Unit> {}
val aPushData = PushData(
eventId = AN_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val showBatteryOptimizationBannerResult = lambdaRecorder<Unit> {}
val defaultPushHandler = createDefaultPushHandler(
onNotifiableEventsReceived = onNotifiableEventsReceived,
notifiableEventsResult = notifiableEventResult,
buildMeta = aBuildMeta(
// Also test `lowPrivacyLoggingEnabled = false` here
lowPrivacyLoggingEnabled = false
),
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { A_USER_ID }
),
incrementPushCounterResult = incrementPushCounterResult,
mutableBatteryOptimizationStore = FakeMutableBatteryOptimizationStore(
showBatteryOptimizationBannerResult = showBatteryOptimizationBannerResult,
),
pushHistoryService = pushHistoryService,
)
val defaultPushHandler = createDefaultPushHandler(
diagnosticPushHandler = diagnosticPushHandler,
incrementPushCounterResult = { },
pushHistoryService = pushHistoryService,
)
diagnosticPushHandler.state.test {
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
advanceTimeBy(300.milliseconds)
incrementPushCounterResult.assertions()
.isCalledOnce()
notifiableEventResult.assertions()
.isCalledOnce()
.with(value(A_USER_ID), any())
onPushReceivedResult.assertions()
.isCalledOnce()
.with(any(), value(AN_EVENT_ID), value(A_ROOM_ID), value(A_USER_ID), value(false), value(true), any())
showBatteryOptimizationBannerResult.assertions().let {
if (shouldSetOptimizationBatteryBanner) {
it.isCalledOnce()
} else {
it.isNeverCalled()
}
}
awaitItem()
}
}
@Test
fun `when ringing call PushData is received, the incoming call will be handled`() = runTest {
val aPushData = PushData(
eventId = AN_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val handleIncomingCallLambda = lambdaRecorder<
CallType.RoomCall,
EventId,
UserId,
String?,
String?,
String?,
String,
String?,
Unit,
> { _, _, _, _, _, _, _, _ -> }
val elementCallEntryPoint = FakeElementCallEntryPoint(handleIncomingCallResult = handleIncomingCallLambda)
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val defaultPushHandler = createDefaultPushHandler(
elementCallEntryPoint = elementCallEntryPoint,
notifiableEventsResult = { _, _ ->
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, A_PUSHER_INFO)
Result.success(
mapOf(
request to Result.success(
ResolvedPushEvent.Event(
aNotifiableCallEvent(rtcNotificationType = RtcNotificationType.RING, timestamp = Instant.now().toEpochMilli())
)
)
)
)
},
incrementPushCounterResult = {},
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { A_USER_ID }
),
onNotifiableEventsReceived = onNotifiableEventsReceived,
pushHistoryService = pushHistoryService,
)
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
advanceTimeBy(300.milliseconds)
handleIncomingCallLambda.assertions().isCalledOnce()
onNotifiableEventsReceived.assertions().isNeverCalled()
onPushReceivedResult.assertions().isCalledOnce()
}
@Test
fun `when notify call PushData is received, the incoming call will be treated as a normal notification`() = runTest {
val aPushData = PushData(
eventId = AN_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
val handleIncomingCallLambda = lambdaRecorder<
CallType.RoomCall,
EventId,
UserId,
String?,
String?,
String?,
String,
String?,
Unit,
> { _, _, _, _, _, _, _, _ -> }
val elementCallEntryPoint = FakeElementCallEntryPoint(handleIncomingCallResult = handleIncomingCallLambda)
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val defaultPushHandler = createDefaultPushHandler(
elementCallEntryPoint = elementCallEntryPoint,
onNotifiableEventsReceived = onNotifiableEventsReceived,
notifiableEventsResult = { _, _ ->
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, A_PUSHER_INFO)
Result.success(mapOf(request to Result.success(ResolvedPushEvent.Event(aNotifiableMessageEvent(type = EventType.RTC_NOTIFICATION)))))
},
incrementPushCounterResult = {},
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { A_USER_ID }
),
pushHistoryService = pushHistoryService,
)
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
advanceTimeBy(300.milliseconds)
handleIncomingCallLambda.assertions().isNeverCalled()
onNotifiableEventsReceived.assertions().isCalledOnce()
onPushReceivedResult.assertions().isCalledOnce()
}
@Test
fun `when notify call PushData is received, the incoming call will be treated as a normal notification even if notification are disabled`() = runTest {
val aPushData = PushData(
eventId = AN_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
val handleIncomingCallLambda = lambdaRecorder<
CallType.RoomCall,
EventId,
UserId,
String?,
String?,
String?,
String,
String?,
Unit,
> { _, _, _, _, _, _, _, _ -> }
val elementCallEntryPoint = FakeElementCallEntryPoint(handleIncomingCallResult = handleIncomingCallLambda)
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val defaultPushHandler = createDefaultPushHandler(
elementCallEntryPoint = elementCallEntryPoint,
onNotifiableEventsReceived = onNotifiableEventsReceived,
notifiableEventsResult = { _, _ ->
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, A_PUSHER_INFO)
Result.success(mapOf(request to Result.success(ResolvedPushEvent.Event(aNotifiableCallEvent()))))
},
incrementPushCounterResult = {},
userPushStore = FakeUserPushStore().apply {
setNotificationEnabledForDevice(false)
},
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { A_USER_ID }
),
pushHistoryService = pushHistoryService,
)
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
advanceTimeBy(300.milliseconds)
handleIncomingCallLambda.assertions().isCalledOnce()
onNotifiableEventsReceived.assertions().isNeverCalled()
onPushReceivedResult.assertions().isCalledOnce()
}
@Test
fun `when a redaction is received, the onRedactedEventReceived is informed`() = runTest {
val aPushData = PushData(
eventId = AN_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val aRedaction = ResolvedPushEvent.Redaction(
sessionId = A_SESSION_ID,
roomId = A_ROOM_ID,
redactedEventId = AN_EVENT_ID_2,
reason = null
)
val onRedactedEventReceived = lambdaRecorder<List<ResolvedPushEvent.Redaction>, Unit> { }
val incrementPushCounterResult = lambdaRecorder<Unit> {}
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val defaultPushHandler = createDefaultPushHandler(
onRedactedEventsReceived = onRedactedEventReceived,
incrementPushCounterResult = incrementPushCounterResult,
notifiableEventsResult = { _, _ ->
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, A_PUSHER_INFO)
Result.success(mapOf(request to Result.success(aRedaction)))
},
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { A_USER_ID }
),
pushHistoryService = pushHistoryService,
)
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
advanceTimeBy(300.milliseconds)
incrementPushCounterResult.assertions()
.isCalledOnce()
onRedactedEventReceived.assertions().isCalledOnce()
.with(value(listOf(aRedaction)))
onPushReceivedResult.assertions()
.isCalledOnce()
}
@Test
fun `when diagnostic PushData is received, the diagnostic push handler is informed`() =
runTest {
val aPushData = PushData(
eventId = DefaultTestPush.TEST_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val diagnosticPushHandler = DiagnosticPushHandler()
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val defaultPushHandler = createDefaultPushHandler(
diagnosticPushHandler = diagnosticPushHandler,
incrementPushCounterResult = { },
pushHistoryService = pushHistoryService,
)
diagnosticPushHandler.state.test {
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
awaitItem()
}
onPushReceivedResult.assertions()
.isCalledOnce()
}
@Test
fun `when receiving several push notifications at the same time, those are batched before being processed`() = runTest {
val aNotifiableMessageEvent = aNotifiableMessageEvent()
val notifiableEventResult =
lambdaRecorder<SessionId, List<NotificationEventRequest>, Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>>> { _, _ ->
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, A_PUSHER_INFO)
Result.success(mapOf(request to Result.success(ResolvedPushEvent.Event(aNotifiableMessageEvent))))
}
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
val incrementPushCounterResult = lambdaRecorder<Unit> {}
val onPushReceivedResult = lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, _, _, _ -> }
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val aPushData = PushData(
eventId = AN_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val anotherPushData = PushData(
eventId = AN_EVENT_ID_2,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val defaultPushHandler = createDefaultPushHandler(
onNotifiableEventsReceived = onNotifiableEventsReceived,
notifiableEventsResult = notifiableEventResult,
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { A_USER_ID }
),
incrementPushCounterResult = incrementPushCounterResult,
pushHistoryService = pushHistoryService,
)
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
defaultPushHandler.handle(anotherPushData, A_PUSHER_INFO)
advanceTimeBy(300.milliseconds)
incrementPushCounterResult.assertions()
.isCalledExactly(2)
notifiableEventResult.assertions()
.isCalledOnce()
.with(value(A_USER_ID), matching<List<NotificationEventRequest>> { requests ->
requests.size == 2 && requests.first().eventId == AN_EVENT_ID && requests.last().eventId == AN_EVENT_ID_2
})
onNotifiableEventsReceived.assertions()
.isCalledOnce()
onPushReceivedResult.assertions()
.isCalledExactly(2)
}
@Test
fun `when receiving a fallback event, we notify the push history service about it not being resolved`() = runTest {
val aNotifiableFallbackEvent = aFallbackNotifiableEvent()
val notifiableEventResult =
lambdaRecorder<SessionId, List<NotificationEventRequest>, Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>>> { _, _ ->
val request = NotificationEventRequest(A_SESSION_ID, A_ROOM_ID, AN_EVENT_ID, A_PUSHER_INFO)
Result.success(mapOf(request to Result.success(ResolvedPushEvent.Event(aNotifiableFallbackEvent))))
}
val onNotifiableEventsReceived = lambdaRecorder<List<NotifiableEvent>, Unit> {}
val incrementPushCounterResult = lambdaRecorder<Unit> {}
var receivedFallbackEvent = false
val onPushReceivedResult =
lambdaRecorder<String, EventId?, RoomId?, SessionId?, Boolean, Boolean, String?, Unit> { _, _, _, _, isResolved, _, comment ->
receivedFallbackEvent = !isResolved && comment == "Unable to resolve event: ${aNotifiableFallbackEvent.cause}"
}
val pushHistoryService = FakePushHistoryService(
onPushReceivedResult = onPushReceivedResult,
)
val aPushData = PushData(
eventId = AN_EVENT_ID,
roomId = A_ROOM_ID,
unread = 0,
clientSecret = A_SECRET,
)
val defaultPushHandler = createDefaultPushHandler(
onNotifiableEventsReceived = onNotifiableEventsReceived,
notifiableEventsResult = notifiableEventResult,
pushClientSecret = FakePushClientSecret(
getUserIdFromSecretResult = { A_USER_ID }
),
incrementPushCounterResult = incrementPushCounterResult,
pushHistoryService = pushHistoryService,
)
defaultPushHandler.handle(aPushData, A_PUSHER_INFO)
advanceTimeBy(300.milliseconds)
onNotifiableEventsReceived.assertions().isCalledOnce()
assertThat(receivedFallbackEvent).isTrue()
}
private fun TestScope.createDefaultPushHandler(
onNotifiableEventsReceived: (List<NotifiableEvent>) -> Unit = { lambdaError() },
onRedactedEventsReceived: (List<ResolvedPushEvent.Redaction>) -> Unit = { lambdaError() },
notifiableEventsResult: (SessionId, List<NotificationEventRequest>) -> Result<Map<NotificationEventRequest, Result<ResolvedPushEvent>>> =
{ _, _ -> lambdaError() },
private fun createDefaultPushHandler(
incrementPushCounterResult: () -> Unit = { lambdaError() },
mutableBatteryOptimizationStore: MutableBatteryOptimizationStore = FakeMutableBatteryOptimizationStore(),
userPushStore: UserPushStore = FakeUserPushStore(),
userPushStore: FakeUserPushStore = FakeUserPushStore(),
pushClientSecret: PushClientSecret = FakePushClientSecret(),
buildMeta: BuildMeta = aBuildMeta(),
diagnosticPushHandler: DiagnosticPushHandler = DiagnosticPushHandler(),
elementCallEntryPoint: FakeElementCallEntryPoint = FakeElementCallEntryPoint(),
notificationChannels: FakeNotificationChannels = FakeNotificationChannels(),
pushHistoryService: PushHistoryService = FakePushHistoryService(),
syncOnNotifiableEvent: SyncOnNotifiableEvent = SyncOnNotifiableEvent {},
featureFlagService: FakeFeatureFlagService = FakeFeatureFlagService(initialState = mapOf(FeatureFlags.SyncNotificationsWithWorkManager.key to false)),
workManagerScheduler: FakeWorkManagerScheduler = FakeWorkManagerScheduler(),
analyticsService: FakeAnalyticsService = FakeAnalyticsService(),
systemClock: FakeSystemClock = FakeSystemClock(),
buildVersionSdkIntProvider: FakeBuildVersionSdkIntProvider = FakeBuildVersionSdkIntProvider(33),
resultProcessor: FakeNotificationResultProcessor = FakeNotificationResultProcessor(
emit = { Result.success(Unit) },
start = {},
stop = {},
),
): DefaultPushHandler {
return DefaultPushHandler(
onNotifiableEventReceived = FakeOnNotifiableEventReceived(onNotifiableEventsReceived),
onRedactedEventReceived = FakeOnRedactedEventReceived(onRedactedEventsReceived),
incrementPushDataStore = object : IncrementPushDataStore {
override suspend fun incrementPushCounter() {
incrementPushCounterResult()
}
},
mutableBatteryOptimizationStore = mutableBatteryOptimizationStore,
userPushStoreFactory = FakeUserPushStoreFactory { userPushStore },
pushClientSecret = pushClientSecret,
buildMeta = buildMeta,
diagnosticPushHandler = diagnosticPushHandler,
elementCallEntryPoint = elementCallEntryPoint,
notificationChannels = notificationChannels,
pushHistoryService = pushHistoryService,
// We don't use a fake here so we can perform tests that are a bit more end to end
resolverQueue = DefaultNotificationResolverQueue(
notifiableEventResolver = FakeNotifiableEventResolver(notifiableEventsResult),
appCoroutineScope = backgroundScope,
workManagerScheduler = workManagerScheduler,
featureFlagService = featureFlagService,
workerDataConverter = SyncNotificationsWorkerDataConverter(DefaultJsonProvider()),
buildVersionSdkIntProvider = FakeBuildVersionSdkIntProvider(33),
),
appCoroutineScope = backgroundScope,
fallbackNotificationFactory = FallbackNotificationFactory(
clock = FakeSystemClock(),
),
syncOnNotifiableEvent = syncOnNotifiableEvent,
featureFlagService = featureFlagService,
analyticsService = analyticsService,
systemClock = systemClock,
workManagerScheduler = workManagerScheduler,
buildVersionSdkIntProvider = buildVersionSdkIntProvider,
resultProcessor = resultProcessor,
)
}
}

View File

@@ -20,8 +20,7 @@ import io.element.android.libraries.matrix.test.FakeMatrixClientProvider
import io.element.android.libraries.matrix.test.room.FakeBaseRoom
import io.element.android.libraries.matrix.test.room.FakeJoinedRoom
import io.element.android.libraries.matrix.test.sync.FakeSyncService
import io.element.android.libraries.push.api.push.SyncOnNotifiableEvent
import io.element.android.libraries.push.impl.notifications.fixtures.aNotificationEventRequest
import io.element.android.libraries.push.impl.notifications.fixtures.aPushRequest
import io.element.android.services.appnavstate.test.FakeAppForegroundStateService
import io.element.android.tests.testutils.lambda.assert
import io.element.android.tests.testutils.lambda.lambdaRecorder
@@ -53,7 +52,7 @@ class SyncOnNotifiableEventTest {
givenGetRoomResult(A_ROOM_ID, room)
}
private val notificationRequest = aNotificationEventRequest()
private val notificationRequest = aPushRequest()
@Test
fun `when feature flag is disabled, nothing happens`() = runTest {

View File

@@ -1,210 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.workmanager
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import androidx.work.Data
import androidx.work.ListenableWorker
import androidx.work.WorkerParameters
import androidx.work.impl.utils.taskexecutor.WorkManagerTaskExecutor
import androidx.work.workDataOf
import com.google.common.truth.Truth.assertThat
import com.google.common.util.concurrent.ListenableFuture
import io.element.android.features.networkmonitor.api.NetworkStatus
import io.element.android.features.networkmonitor.test.FakeNetworkMonitor
import io.element.android.libraries.androidutils.json.DefaultJsonProvider
import io.element.android.libraries.push.api.push.SyncOnNotifiableEvent
import io.element.android.libraries.push.impl.notifications.FakeNotifiableEventResolver
import io.element.android.libraries.push.impl.notifications.NotificationResolverQueue
import io.element.android.libraries.push.impl.notifications.fixtures.aNotifiableMessageEvent
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
import io.element.android.libraries.push.test.notifications.FakeNotificationResolverQueue
import io.element.android.libraries.workmanager.api.WorkManagerRequest
import io.element.android.libraries.workmanager.api.di.MetroWorkerFactory
import io.element.android.libraries.workmanager.test.FakeWorkManagerScheduler
import io.element.android.services.analytics.test.FakeAnalyticsService
import io.element.android.services.toolbox.test.sdk.FakeBuildVersionSdkIntProvider
import io.element.android.tests.testutils.lambda.lambdaRecorder
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.runTest
import org.junit.Test
import org.junit.runner.RunWith
import java.util.UUID
import java.util.concurrent.Executor
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.time.Duration.Companion.seconds
@OptIn(ExperimentalCoroutinesApi::class)
@RunWith(AndroidJUnit4::class)
class FetchNotificationWorkerTest {
@Test
fun `test - success`() = runTest {
var synced = false
val syncOnNotifiableEventLambda = SyncOnNotifiableEvent { synced = true }
val queue = FakeNotificationResolverQueue(
processingLambda = { Result.success(ResolvedPushEvent.Event(aNotifiableMessageEvent())) }
)
val worker = createWorker(
input = """
[
{
"session_id": "@alice:matrix.org",
"room_id": "!roomid:matrix.org",
"event_id": "$1436ebk:matrix.org",
"provider_info": "some_info"
}
]
""".trimIndent(),
queue = queue,
syncOnNotifiableEvent = syncOnNotifiableEventLambda,
)
val result = worker.doWork()
// The process finished successfully
assertThat(result).isEqualTo(ListenableWorker.Result.success())
// A result was emitted
assertThat(queue.results.replayCache).isNotEmpty()
// An opportunistic sync was triggered
assertThat(synced).isTrue()
}
@Test
fun `test - invalid input fails the work`() = runTest {
val worker = createWorker(
input = """
[
{
"session_id": "!alice:matrix.org",
"room_id": "!roomid:matrix.org",
"event_id": "$1436ebk:matrix.org",
"provider_info": "some_info"
}
]
""".trimIndent(),
)
val result = worker.doWork()
// The process failed
assertThat(result).isEqualTo(ListenableWorker.Result.failure())
}
@Test
fun `test - no network connectivity fails the work`() = runTest {
val networkMonitor = FakeNetworkMonitor(initialStatus = NetworkStatus.Disconnected)
val worker = createWorker(
input = """
[
{
"session_id": "@alice:matrix.org",
"room_id": "!roomid:matrix.org",
"event_id": "$1436ebk:matrix.org",
"provider_info": "some_info"
}
]
""".trimIndent(),
networkMonitor = networkMonitor,
)
val result = worker.doWork()
advanceTimeBy(10.seconds)
// The process failed due to a timeout in getting the network connectivity, a retry is scheduled
assertThat(result).isEqualTo(ListenableWorker.Result.retry())
}
@Test
fun `test - failing to resolve events re-schedules the work`() = runTest {
val submitWorkerLambda = lambdaRecorder<WorkManagerRequest, Unit> {}
val scheduler = FakeWorkManagerScheduler(submitLambda = submitWorkerLambda)
val resolver = FakeNotifiableEventResolver(
resolveEventsResult = { _, _ -> Result.failure(Exception("Failed to resolve events")) }
)
val worker = createWorker(
input = """
[
{
"session_id": "@alice:matrix.org",
"room_id": "!roomid:matrix.org",
"event_id": "$1436ebk:matrix.org",
"provider_info": "some_info"
}
]
""".trimIndent(),
eventResolver = resolver,
workManagerScheduler = scheduler,
)
val result = worker.doWork()
// The process was considered successful, but a retry was scheduled due to the failure to resolve events
assertThat(result).isEqualTo(ListenableWorker.Result.success())
submitWorkerLambda.assertions().isCalledOnce()
}
private fun TestScope.createWorker(
input: String,
networkMonitor: FakeNetworkMonitor = FakeNetworkMonitor(),
eventResolver: FakeNotifiableEventResolver = FakeNotifiableEventResolver(resolveEventsResult = { _, _ -> Result.success(emptyMap()) }),
queue: NotificationResolverQueue = FakeNotificationResolverQueue(
processingLambda = { Result.success(ResolvedPushEvent.Event(aNotifiableMessageEvent())) }
),
workManagerScheduler: FakeWorkManagerScheduler = FakeWorkManagerScheduler(),
syncOnNotifiableEvent: SyncOnNotifiableEvent = SyncOnNotifiableEvent {},
analyticsService: FakeAnalyticsService = FakeAnalyticsService(),
) = FetchNotificationsWorker(
params = createWorkerParams(workDataOf("requests" to input)),
context = InstrumentationRegistry.getInstrumentation().context,
networkMonitor = networkMonitor,
eventResolver = eventResolver,
queue = queue,
workManagerScheduler = workManagerScheduler,
syncOnNotifiableEvent = syncOnNotifiableEvent,
workerDataConverter = SyncNotificationsWorkerDataConverter(DefaultJsonProvider()),
buildVersionSdkIntProvider = FakeBuildVersionSdkIntProvider(33),
analyticsService = analyticsService,
)
private fun TestScope.createWorkerParams(
inputData: Data = Data.EMPTY,
): WorkerParameters = WorkerParameters(
UUID.randomUUID(),
inputData,
emptySet(),
WorkerParameters.RuntimeExtras(),
0,
0,
Executors.newSingleThreadExecutor(),
backgroundScope.coroutineContext,
WorkManagerTaskExecutor(Executors.newSingleThreadExecutor()),
MetroWorkerFactory(emptyMap()),
{ context, id, data -> FakeListenableFuture() },
{ context, id, foregroundInfo -> FakeListenableFuture() },
)
}
class FakeListenableFuture<T> : ListenableFuture<T> {
override fun addListener(listener: Runnable, executor: Executor) = Unit
override fun cancel(mayInterruptIfRunning: Boolean): Boolean = true
override fun get(): T? = null
override fun get(timeout: Long, unit: TimeUnit?): T? = null
override fun isCancelled(): Boolean = false
override fun isDone(): Boolean = false
}

View File

@@ -0,0 +1,278 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.workmanager
import androidx.test.ext.junit.runners.AndroidJUnit4
import androidx.test.platform.app.InstrumentationRegistry
import androidx.work.Data
import androidx.work.ListenableWorker
import androidx.work.WorkerParameters
import androidx.work.impl.utils.taskexecutor.WorkManagerTaskExecutor
import androidx.work.workDataOf
import com.google.common.truth.Truth.assertThat
import com.google.common.util.concurrent.ListenableFuture
import io.element.android.features.networkmonitor.api.NetworkStatus
import io.element.android.features.networkmonitor.test.FakeNetworkMonitor
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.api.exception.ClientException
import io.element.android.libraries.push.impl.db.PushRequest
import io.element.android.libraries.push.impl.history.FakePushHistoryService
import io.element.android.libraries.push.impl.notifications.FakeNotifiableEventResolver
import io.element.android.libraries.push.impl.notifications.FakeNotificationResultProcessor
import io.element.android.libraries.push.impl.notifications.fixtures.aPushRequest
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
import io.element.android.libraries.push.impl.push.SyncOnNotifiableEvent
import io.element.android.libraries.workmanager.api.WorkManagerRequestBuilder
import io.element.android.libraries.workmanager.api.di.MetroWorkerFactory
import io.element.android.services.analytics.test.FakeAnalyticsService
import io.element.android.services.toolbox.test.systemclock.FakeSystemClock
import io.element.android.tests.testutils.lambda.lambdaRecorder
import kotlinx.coroutines.ExperimentalCoroutinesApi
import kotlinx.coroutines.test.TestScope
import kotlinx.coroutines.test.advanceTimeBy
import kotlinx.coroutines.test.runTest
import org.junit.Test
import org.junit.runner.RunWith
import java.util.UUID
import java.util.concurrent.Executor
import java.util.concurrent.Executors
import java.util.concurrent.TimeUnit
import kotlin.time.Duration.Companion.seconds
import kotlin.time.Instant
@OptIn(ExperimentalCoroutinesApi::class)
@RunWith(AndroidJUnit4::class)
class FetchPendingNotificationWorkerTest {
@Test
fun `test - success`() = runTest {
var synced = false
val syncOnNotifiableEventLambda = SyncOnNotifiableEvent { synced = true }
val emitResultLambda = lambdaRecorder<Map<PushRequest, Result<ResolvedPushEvent>>, Unit> {}
val processor = FakeNotificationResultProcessor(emit = emitResultLambda)
val getPendingResultsLambda = lambdaRecorder<SessionId, Instant?, Result<List<PushRequest>>> { _, _ -> Result.success(listOf(aPushRequest())) }
val replacePushRequestsLambda = lambdaRecorder<List<PushRequest>, Result<Unit>> { Result.success(Unit) }
val removeOldPushRequestsLambda = lambdaRecorder<SessionId, Result<Unit>> { Result.success(Unit) }
val pushHistoryService = FakePushHistoryService(
getPendingPushRequests = getPendingResultsLambda,
replacePushRequests = replacePushRequestsLambda,
removeOldPushRequests = removeOldPushRequestsLambda,
)
val worker = createWorker(
input = "@alice:matrix.org",
pushHistoryService = pushHistoryService,
resultProcessor = processor,
syncOnNotifiableEvent = syncOnNotifiableEventLambda,
)
val result = worker.doWork()
// The expected data is fetched and replaced from the service
getPendingResultsLambda.assertions().isCalledOnce()
replacePushRequestsLambda.assertions().isCalledOnce()
removeOldPushRequestsLambda.assertions().isCalledOnce()
// The process finished successfully
assertThat(result).isEqualTo(ListenableWorker.Result.success())
// A result was emitted
emitResultLambda.assertions().isCalledOnce()
// An opportunistic sync was triggered
assertThat(synced).isTrue()
}
@Test
fun `test - invalid input fails the work`() = runTest {
val worker = createWorker(input = "!alice:matrix.org")
val result = worker.doWork()
// The process failed
assertThat(result).isEqualTo(ListenableWorker.Result.failure())
}
@Test
fun `test - no network connectivity fails the work`() = runTest {
val networkMonitor = FakeNetworkMonitor(initialStatus = NetworkStatus.Disconnected)
val emitResultLambda = lambdaRecorder<Map<PushRequest, Result<ResolvedPushEvent>>, Unit> {}
val processor = FakeNotificationResultProcessor(emit = emitResultLambda)
val pushHistoryService = FakePushHistoryService(
getPendingPushRequests = { _, _ -> Result.success(listOf(aPushRequest())) },
replacePushRequests = { Result.success(Unit) },
removeOldPushRequests = { Result.success(Unit) },
)
val worker = createWorker(
input = "@alice:matrix.org",
networkMonitor = networkMonitor,
resultProcessor = processor,
pushHistoryService = pushHistoryService,
)
val result = worker.doWork()
advanceTimeBy(10.seconds)
// The process failed due to a timeout in getting the network connectivity, a retry is scheduled
assertThat(result).isEqualTo(ListenableWorker.Result.retry())
}
@Test
fun `test - failing to setup retries the work`() = runTest {
val submitWorkerLambda = lambdaRecorder<WorkManagerRequestBuilder, Unit> {}
val emitResultLambda = lambdaRecorder<Map<PushRequest, Result<ResolvedPushEvent>>, Unit> {}
val processor = FakeNotificationResultProcessor(emit = emitResultLambda)
val pushHistoryService = FakePushHistoryService(
getPendingPushRequests = { _, _ -> Result.success(listOf(aPushRequest())) },
replacePushRequests = { Result.success(Unit) },
removeOldPushRequests = { Result.success(Unit) },
)
val resolver = FakeNotifiableEventResolver(
resolveEventsResult = { _, _ -> Result.failure(Exception("Failed to resolve events")) }
)
val worker = createWorker(
input = "@alice:matrix.org",
eventResolver = resolver,
resultProcessor = processor,
pushHistoryService = pushHistoryService,
)
val result = worker.doWork()
assertThat(result).isEqualTo(ListenableWorker.Result.retry())
// Never called since we don't need to re-submit
submitWorkerLambda.assertions().isNeverCalled()
}
@Test
fun `test - failing to resolve events with recoverable error retries the work`() {
val pushRequest = aPushRequest()
runTest {
val submitWorkerLambda = lambdaRecorder<WorkManagerRequestBuilder, Unit> {}
val emitResultLambda = lambdaRecorder<Map<PushRequest, Result<ResolvedPushEvent>>, Unit> {}
val processor = FakeNotificationResultProcessor(emit = emitResultLambda)
val pushHistoryService = FakePushHistoryService(
getPendingPushRequests = { _, _ -> Result.success(listOf(pushRequest)) },
replacePushRequests = { Result.success(Unit) },
removeOldPushRequests = { Result.success(Unit) },
)
val resolver = FakeNotifiableEventResolver(
resolveEventsResult = { _, _ ->
Result.success(mapOf(pushRequest to Result.failure(ClientException.Generic("error sending request for url", null))))
}
)
val worker = createWorker(
input = "@alice:matrix.org",
eventResolver = resolver,
resultProcessor = processor,
pushHistoryService = pushHistoryService,
)
val result = worker.doWork()
assertThat(result).isEqualTo(ListenableWorker.Result.retry())
// Never called since we don't need to re-submit
submitWorkerLambda.assertions().isNeverCalled()
// We do save the updated events to the push DB
emitResultLambda.assertions().isCalledOnce()
}
}
@Test
fun `test - failing to resolve events with unrecoverable error saves the new state and ends as success`() {
val pushRequest = aPushRequest()
runTest {
val submitWorkerLambda = lambdaRecorder<WorkManagerRequestBuilder, Unit> {}
val emitResultLambda = lambdaRecorder<Map<PushRequest, Result<ResolvedPushEvent>>, Unit> {}
val processor = FakeNotificationResultProcessor(emit = emitResultLambda)
val pushHistoryService = FakePushHistoryService(
getPendingPushRequests = { _, _ -> Result.success(listOf(pushRequest)) },
replacePushRequests = { Result.success(Unit) },
removeOldPushRequests = { Result.success(Unit) },
)
val resolver = FakeNotifiableEventResolver(
resolveEventsResult = { _, _ ->
Result.success(mapOf(pushRequest to Result.failure(IllegalStateException("Unrecoverable"))))
}
)
val worker = createWorker(
input = "@alice:matrix.org",
eventResolver = resolver,
resultProcessor = processor,
pushHistoryService = pushHistoryService,
)
val result = worker.doWork()
assertThat(result).isEqualTo(ListenableWorker.Result.success())
// Never called since we don't need to re-submit
submitWorkerLambda.assertions().isNeverCalled()
// We do save the updated events to the push DB
emitResultLambda.assertions().isCalledOnce()
}
}
private fun TestScope.createWorker(
input: String,
networkMonitor: FakeNetworkMonitor = FakeNetworkMonitor(),
eventResolver: FakeNotifiableEventResolver = FakeNotifiableEventResolver(resolveEventsResult = { _, _ -> Result.success(emptyMap()) }),
syncOnNotifiableEvent: SyncOnNotifiableEvent = SyncOnNotifiableEvent {},
analyticsService: FakeAnalyticsService = FakeAnalyticsService(),
pushHistoryService: FakePushHistoryService = FakePushHistoryService(),
resultProcessor: FakeNotificationResultProcessor = FakeNotificationResultProcessor(),
systemClock: FakeSystemClock = FakeSystemClock(),
) = FetchPendingNotificationsWorker(
params = createWorkerParams(workDataOf("session_id" to input)),
context = InstrumentationRegistry.getInstrumentation().context,
networkMonitor = networkMonitor,
eventResolver = eventResolver,
syncOnNotifiableEvent = syncOnNotifiableEvent,
analyticsService = analyticsService,
pushHistoryService = pushHistoryService,
resultProcessor = resultProcessor,
systemClock = systemClock,
)
private fun TestScope.createWorkerParams(
inputData: Data = Data.EMPTY,
): WorkerParameters = WorkerParameters(
UUID.randomUUID(),
inputData,
emptySet(),
WorkerParameters.RuntimeExtras(),
0,
0,
Executors.newSingleThreadExecutor(),
backgroundScope.coroutineContext,
WorkManagerTaskExecutor(Executors.newSingleThreadExecutor()),
MetroWorkerFactory(emptyMap()),
{ context, id, data -> FakeListenableFuture() },
{ context, id, foregroundInfo -> FakeListenableFuture() },
)
}
class FakeListenableFuture<T> : ListenableFuture<T> {
override fun addListener(listener: Runnable, executor: Executor) = Unit
override fun cancel(mayInterruptIfRunning: Boolean): Boolean = true
override fun get(): T? = null
override fun get(timeout: Long, unit: TimeUnit?): T? = null
override fun isCancelled(): Boolean = false
override fun isDone(): Boolean = false
}

View File

@@ -1,98 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.workmanager
import androidx.work.OneTimeWorkRequest
import androidx.work.hasKeyWithValueOfType
import com.google.common.truth.Truth.assertThat
import io.element.android.libraries.androidutils.json.DefaultJsonProvider
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.impl.notifications.fixtures.aNotificationEventRequest
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.workManagerTag
import io.element.android.services.toolbox.test.sdk.FakeBuildVersionSdkIntProvider
import kotlinx.coroutines.test.runTest
import org.junit.Test
import kotlin.collections.first
class SyncNotificationWorkManagerRequestTest {
@Test
fun `build - success API 33`() = runTest {
val request = createSyncNotificationWorkManagerRequest(
sessionId = A_SESSION_ID,
notificationEventRequests = listOf(aNotificationEventRequest()),
sdkVersion = 33,
)
val result = request.build()
assertThat(result.isSuccess).isTrue()
result.getOrNull()!!.first().run {
assertThat(this).isInstanceOf(OneTimeWorkRequest::class.java)
assertThat(workSpec.input.hasKeyWithValueOfType<String>("requests")).isTrue()
// True in API 33+
assertThat(workSpec.expedited).isTrue()
assertThat(workSpec.traceTag).isEqualTo(workManagerTag(A_SESSION_ID, WorkManagerRequestType.NOTIFICATION_SYNC))
}
}
@Test
fun `build - success API 32 and lower`() = runTest {
val request = createSyncNotificationWorkManagerRequest(
sessionId = A_SESSION_ID,
notificationEventRequests = listOf(aNotificationEventRequest()),
sdkVersion = 32,
)
val result = request.build()
assertThat(result.isSuccess).isTrue()
result.getOrNull()!!.first().run {
assertThat(this).isInstanceOf(OneTimeWorkRequest::class.java)
assertThat(workSpec.input.hasKeyWithValueOfType<String>("requests")).isTrue()
// False before API 33
assertThat(workSpec.expedited).isFalse()
assertThat(workSpec.traceTag).isEqualTo(workManagerTag(A_SESSION_ID, WorkManagerRequestType.NOTIFICATION_SYNC))
}
}
@Test
fun `build - empty list of requests fails`() = runTest {
val request = createSyncNotificationWorkManagerRequest(
sessionId = A_SESSION_ID,
notificationEventRequests = emptyList()
)
val result = request.build()
assertThat(result.isFailure).isTrue()
}
@Test
fun `build - invalid serialization`() = runTest {
val request = createSyncNotificationWorkManagerRequest(
sessionId = A_SESSION_ID,
notificationEventRequests = listOf(aNotificationEventRequest()),
workerDataConverter = SyncNotificationsWorkerDataConverter({ error("error during serialization") })
)
val result = request.build()
assertThat(result.isFailure).isTrue()
}
}
private fun createSyncNotificationWorkManagerRequest(
sessionId: SessionId,
notificationEventRequests: List<NotificationEventRequest>,
workerDataConverter: SyncNotificationsWorkerDataConverter = SyncNotificationsWorkerDataConverter(DefaultJsonProvider()),
sdkVersion: Int = 33,
) = SyncNotificationWorkManagerRequest(
sessionId = sessionId,
notificationEventRequests = notificationEventRequests,
workerDataConverter = workerDataConverter,
buildVersionSdkIntProvider = FakeBuildVersionSdkIntProvider(sdkVersion),
)

View File

@@ -0,0 +1,74 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.workmanager
import androidx.work.OneTimeWorkRequest
import androidx.work.hasKeyWithValueOfType
import com.google.common.truth.Truth.assertThat
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.WorkManagerWorkerType
import io.element.android.libraries.workmanager.api.workManagerTag
import io.element.android.services.toolbox.test.sdk.FakeBuildVersionSdkIntProvider
import kotlinx.coroutines.test.runTest
import org.junit.Test
class SyncPendingNotificationsRequestBuilderTest {
@Test
fun `build - success API 33`() = runTest {
val request = createSyncPendingNotificationsRequestBuilder(
sessionId = A_SESSION_ID,
sdkVersion = 33,
)
val results = request.build()
assertThat(results.isSuccess).isTrue()
results.getOrNull()!!.first().let { result ->
assertThat(result.type).isInstanceOf(WorkManagerWorkerType.Unique::class.java)
result.request.run {
assertThat(this).isInstanceOf(OneTimeWorkRequest::class.java)
assertThat(workSpec.input.hasKeyWithValueOfType<String>(SyncPendingNotificationsRequestBuilder.SESSION_ID)).isTrue()
// True in API 33+
assertThat(workSpec.expedited).isTrue()
assertThat(workSpec.traceTag).isEqualTo(workManagerTag(A_SESSION_ID, WorkManagerRequestType.NOTIFICATION_SYNC))
}
}
}
@Test
fun `build - success API 32 and lower`() = runTest {
val request = createSyncPendingNotificationsRequestBuilder(
sessionId = A_SESSION_ID,
sdkVersion = 32,
)
val results = request.build()
assertThat(results.isSuccess).isTrue()
results.getOrNull()!!.first().let { result ->
assertThat(result.type).isInstanceOf(WorkManagerWorkerType.Unique::class.java)
result.request.run {
assertThat(this).isInstanceOf(OneTimeWorkRequest::class.java)
assertThat(workSpec.input.hasKeyWithValueOfType<String>(SyncPendingNotificationsRequestBuilder.SESSION_ID)).isTrue()
// False before API 33
assertThat(workSpec.expedited).isFalse()
assertThat(workSpec.traceTag).isEqualTo(workManagerTag(A_SESSION_ID, WorkManagerRequestType.NOTIFICATION_SYNC))
}
}
}
}
private fun createSyncPendingNotificationsRequestBuilder(
sessionId: SessionId,
sdkVersion: Int = 33,
) = SyncPendingNotificationsRequestBuilder(
sessionId = sessionId,
buildVersionSdkIntProvider = FakeBuildVersionSdkIntProvider(sdkVersion),
)

View File

@@ -1,141 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.impl.workmanager
import com.google.common.truth.Truth.assertThat
import io.element.android.libraries.androidutils.json.DefaultJsonProvider
import io.element.android.libraries.matrix.api.core.EventId
import io.element.android.libraries.matrix.test.AN_EVENT_ID
import io.element.android.libraries.matrix.test.AN_EVENT_ID_2
import io.element.android.libraries.matrix.test.A_ROOM_ID
import io.element.android.libraries.matrix.test.A_ROOM_ID_2
import io.element.android.libraries.matrix.test.A_ROOM_ID_3
import io.element.android.libraries.matrix.test.A_SESSION_ID
import io.element.android.libraries.matrix.test.A_SESSION_ID_2
import io.element.android.libraries.push.api.push.NotificationEventRequest
import org.junit.Test
class WorkerDataConverterTest {
@Test
fun `ensure identity when serializing - deserializing an empty list`() {
testIdentity(emptyList())
}
@Test
fun `ensure identity when serializing - deserializing a list`() {
testIdentity(
listOf(
NotificationEventRequest(
sessionId = A_SESSION_ID,
roomId = A_ROOM_ID,
eventId = AN_EVENT_ID,
providerInfo = "info1",
),
NotificationEventRequest(
sessionId = A_SESSION_ID_2,
roomId = A_ROOM_ID_2,
eventId = AN_EVENT_ID_2,
providerInfo = "info2",
),
)
)
}
@Test
fun `serializing lots of data leads to several work data generated - one room - 100 events should be split in 5 chunks`() {
val data = List(100) {
NotificationEventRequest(
sessionId = A_SESSION_ID,
roomId = A_ROOM_ID,
eventId = EventId(AN_EVENT_ID.value + it),
providerInfo = "info$it",
)
}
val sut = SyncNotificationsWorkerDataConverter(DefaultJsonProvider())
val serialized = sut.serialize(data)
assertThat(serialized.getOrNull()?.size).isGreaterThan(1)
assertThat(serialized.getOrNull()?.size).isEqualTo(100 / SyncNotificationsWorkerDataConverter.CHUNK_SIZE)
// All the items are present
val deserialized = serialized.getOrNull()?.flatMap { sut.deserialize(it)!! }
assertThat(deserialized).containsExactlyElementsIn(data)
}
@Test
fun `serializing lots of data leads to several work data generated - one room - 101 events should be split in 6 chunks`() {
val data = List(101) {
NotificationEventRequest(
sessionId = A_SESSION_ID,
roomId = A_ROOM_ID,
eventId = EventId(AN_EVENT_ID.value + it),
providerInfo = "info$it",
)
}
val sut = SyncNotificationsWorkerDataConverter(DefaultJsonProvider())
val serialized = sut.serialize(data)
assertThat(serialized.getOrNull()?.size).isGreaterThan(1)
assertThat(serialized.getOrNull()?.size).isEqualTo(100 / SyncNotificationsWorkerDataConverter.CHUNK_SIZE + 1)
// All the items are present
val deserialized = serialized.getOrNull()?.flatMap { sut.deserialize(it)!! }
assertThat(deserialized).containsExactlyElementsIn(data)
}
@Test
fun `serializing lots of data leads to several work data generated - 3 rooms - 25 events should be split in 2 chunks and room not mixed`() {
val data1 = List(15) {
NotificationEventRequest(
sessionId = A_SESSION_ID,
roomId = A_ROOM_ID,
eventId = EventId(AN_EVENT_ID.value + it),
providerInfo = "info".repeat(100) + it,
)
}
val data2 = List(3) {
NotificationEventRequest(
sessionId = A_SESSION_ID,
roomId = A_ROOM_ID_2,
eventId = EventId(AN_EVENT_ID.value + it),
providerInfo = "info".repeat(100) + it,
)
}
val data3 = List(7) {
NotificationEventRequest(
sessionId = A_SESSION_ID,
roomId = A_ROOM_ID_3,
eventId = EventId(AN_EVENT_ID.value + it),
providerInfo = "info".repeat(100) + it,
)
}
val data = (data1 + data2 + data3).shuffled()
val sut = SyncNotificationsWorkerDataConverter(DefaultJsonProvider())
val serialized = sut.serialize(data)
assertThat(serialized.getOrNull()?.size).isEqualTo(2)
// All the items are present
val deserialized = serialized.getOrNull()?.flatMap { sut.deserialize(it)!! }
assertThat(deserialized).containsExactlyElementsIn(data)
// Rooms are not mixed between the chunks
val setsOfRooms = serialized.getOrNull()!!
.map { workData -> sut.deserialize(workData)!! }
.map {
it.map { request -> request.roomId }.toSet()
}
// Ensure that all sets are distinct
assertThat(setsOfRooms.size).isEqualTo(2)
// 3 roomId are present
assertThat(setsOfRooms.flatten().toSet()).containsExactly(A_ROOM_ID, A_ROOM_ID_2, A_ROOM_ID_3)
// No intersection between sets
assertThat(setsOfRooms[0].intersect(setsOfRooms[1])).isEmpty()
}
private fun testIdentity(data: List<NotificationEventRequest>) {
val sut = SyncNotificationsWorkerDataConverter(DefaultJsonProvider())
val serialized = sut.serialize(data).getOrThrow()
val result = sut.deserialize(serialized.first())
assertThat(result).isEqualTo(data)
}
}

View File

@@ -1,24 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.push.test.notifications
import io.element.android.libraries.push.api.push.NotificationEventRequest
import io.element.android.libraries.push.impl.notifications.NotificationResolverQueue
import io.element.android.libraries.push.impl.notifications.model.ResolvedPushEvent
import kotlinx.coroutines.flow.MutableSharedFlow
class FakeNotificationResolverQueue(
private val processingLambda: suspend (NotificationEventRequest) -> Result<ResolvedPushEvent>,
) : NotificationResolverQueue {
override val results = MutableSharedFlow<Pair<List<NotificationEventRequest>, Map<NotificationEventRequest, Result<ResolvedPushEvent>>>>(replay = 1)
override suspend fun enqueue(request: NotificationEventRequest) {
results.emit(listOf(request) to mapOf(request to processingLambda(request)))
}
}

View File

@@ -1,4 +1,4 @@
-- This file is not striclty necessary, since the first
-- This file is not strictly necessary, since the first
-- version of the DB is 1, so we will never migrate from 0
CREATE TABLE SessionData (

View File

@@ -1,15 +0,0 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.workmanager.api
import androidx.work.WorkRequest
interface WorkManagerRequest {
fun build(): Result<List<WorkRequest>>
}

View File

@@ -0,0 +1,45 @@
/*
* Copyright (c) 2025 Element Creations Ltd.
* Copyright 2025 New Vector Ltd.
*
* SPDX-License-Identifier: AGPL-3.0-only OR LicenseRef-Element-Commercial.
* Please see LICENSE files in the repository root for full details.
*/
package io.element.android.libraries.workmanager.api
import androidx.work.ExistingWorkPolicy
import androidx.work.WorkRequest
/**
* A base class that can be customized to [build] work requests to schedule in `WorkManager`.
*/
interface WorkManagerRequestBuilder {
/**
* Builds a work request wrapper using the provided data.
*/
suspend fun build(): Result<List<WorkManagerRequestWrapper>>
}
/**
* A wrapper that allows us to avoid using Android APIs directly when scheduling workers.
*/
data class WorkManagerRequestWrapper(
val request: WorkRequest,
val type: WorkManagerWorkerType = WorkManagerWorkerType.Default,
)
/**
* The type of worker to use when scheduling the task.
*/
sealed interface WorkManagerWorkerType {
/**
* This allows a single worker instance with the [name] id to run at the same time. Its [policy] can be customized.
*/
data class Unique(val name: String, val policy: ExistingWorkPolicy) : WorkManagerWorkerType
/**
* The default worker type, with no custom rules.
*/
data object Default : WorkManagerWorkerType
}

View File

@@ -11,9 +11,21 @@ package io.element.android.libraries.workmanager.api
import io.element.android.libraries.matrix.api.core.SessionId
interface WorkManagerScheduler {
fun submit(workManagerRequest: WorkManagerRequest)
/**
* Submits a new work request built from [workManagerRequestBuilder] to run in `WorkManager`.
*/
suspend fun submit(workManagerRequestBuilder: WorkManagerRequestBuilder)
/**
* Checks if there are any pending requests scheduled for the provided [sessionId] and [requestType].
*/
fun hasPendingWork(sessionId: SessionId, requestType: WorkManagerRequestType): Boolean
fun cancel(sessionId: SessionId)
/**
* Cancel pending work requests for the session [SessionId].
* If [requestType] is provided, it will only cancel requests for that type, otherwise it will cancel all requests.
*/
fun cancel(sessionId: SessionId, requestType: WorkManagerRequestType? = null)
}
fun workManagerTag(sessionId: SessionId, requestType: WorkManagerRequestType): String {

View File

@@ -8,6 +8,7 @@
package io.element.android.libraries.workmanager.impl
import androidx.work.OneTimeWorkRequest
import androidx.work.WorkInfo
import androidx.work.WorkManager
import dev.zacsweers.metro.AppScope
@@ -16,9 +17,10 @@ import dev.zacsweers.metro.SingleIn
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.sessionstorage.api.observer.SessionListener
import io.element.android.libraries.sessionstorage.api.observer.SessionObserver
import io.element.android.libraries.workmanager.api.WorkManagerRequest
import io.element.android.libraries.workmanager.api.WorkManagerRequestBuilder
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.WorkManagerScheduler
import io.element.android.libraries.workmanager.api.WorkManagerWorkerType
import io.element.android.libraries.workmanager.api.workManagerTag
import timber.log.Timber
@@ -41,13 +43,22 @@ class DefaultWorkManagerScheduler(
})
}
override fun submit(workManagerRequest: WorkManagerRequest) {
workManagerRequest.build().fold(
onSuccess = { workRequests ->
workManager.enqueue(workRequests)
override suspend fun submit(workManagerRequestBuilder: WorkManagerRequestBuilder) {
workManagerRequestBuilder.build().fold(
onSuccess = { wrappers ->
for (wrapper in wrappers) {
when (wrapper.type) {
WorkManagerWorkerType.Default -> workManager.enqueue(wrapper.request)
is WorkManagerWorkerType.Unique -> {
val type = wrapper.type as WorkManagerWorkerType.Unique
val requests = wrapper.request as OneTimeWorkRequest
workManager.enqueueUniqueWork(type.name, type.policy, requests)
}
}
}
},
onFailure = {
Timber.e(it, "Failed to build WorkManager request $workManagerRequest")
Timber.e(it, "Failed to build WorkManager request $workManagerRequestBuilder")
}
)
}
@@ -64,10 +75,15 @@ class DefaultWorkManagerScheduler(
}
}
override fun cancel(sessionId: SessionId) {
override fun cancel(sessionId: SessionId, requestType: WorkManagerRequestType?) {
Timber.d("Cancelling work for sessionId: $sessionId")
for (requestType in WorkManagerRequestType.entries) {
if (requestType != null) {
workManager.cancelAllWorkByTag(workManagerTag(sessionId, requestType))
} else {
for (requestType in WorkManagerRequestType.entries) {
workManager.cancelAllWorkByTag(workManagerTag(sessionId, requestType))
}
}
}
}

View File

@@ -7,12 +7,17 @@
package io.element.android.libraries.workmanager.impl
import android.content.Context
import androidx.work.OneTimeWorkRequest
import androidx.work.WorkManager
import androidx.work.WorkRequest
import androidx.work.Worker
import androidx.work.WorkerParameters
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.sessionstorage.test.observer.FakeSessionObserver
import io.element.android.libraries.workmanager.api.WorkManagerRequest
import io.element.android.libraries.workmanager.api.WorkManagerRequestBuilder
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.WorkManagerRequestWrapper
import io.element.android.libraries.workmanager.api.workManagerTag
import io.mockk.every
import io.mockk.mockk
@@ -55,7 +60,7 @@ class DefaultWorkManagerSchedulerTest {
sessionObserver = FakeSessionObserver(),
)
scheduler.submit(FakeWorkManagerRequest())
scheduler.submit(FakeWorkManagerRequestBuilder())
verify { workManager.enqueue(any<List<WorkRequest>>()) }
}
@@ -69,7 +74,7 @@ class DefaultWorkManagerSchedulerTest {
sessionObserver = FakeSessionObserver(),
)
scheduler.submit(FakeWorkManagerRequest(result = Result.failure(IllegalStateException("Test error"))))
scheduler.submit(FakeWorkManagerRequestBuilder(result = Result.failure(IllegalStateException("Test error"))))
verify(exactly = 0) { workManager.enqueue(any<List<WorkRequest>>()) }
}
@@ -88,7 +93,7 @@ class DefaultWorkManagerSchedulerTest {
val mockSessionA = mockk<WorkRequest> {
every { tags } returns setOf(tagToRemove)
}
scheduler.submit(FakeWorkManagerRequest(result = Result.success(listOf(mockSessionA))))
scheduler.submit(FakeWorkManagerRequestBuilder(result = Result.success(listOf(WorkManagerRequestWrapper(mockSessionA)))))
scheduler.cancel(sessionId)
@@ -96,10 +101,16 @@ class DefaultWorkManagerSchedulerTest {
}
}
private class FakeWorkManagerRequest(
private val result: Result<List<WorkRequest>> = Result.success(listOf()),
) : WorkManagerRequest {
override fun build(): Result<List<WorkRequest>> {
private val workRequest = OneTimeWorkRequest.Builder(FakeWorker::class.java).build()
private class FakeWorkManagerRequestBuilder(
private val result: Result<List<WorkManagerRequestWrapper>> = Result.success(listOf(WorkManagerRequestWrapper(workRequest))),
) : WorkManagerRequestBuilder {
override suspend fun build(): Result<List<WorkManagerRequestWrapper>> {
return result
}
}
internal class FakeWorker(context: Context, params: WorkerParameters) : Worker(context, params) {
override fun doWork(): Result = Result.success()
}

View File

@@ -9,25 +9,25 @@
package io.element.android.libraries.workmanager.test
import io.element.android.libraries.matrix.api.core.SessionId
import io.element.android.libraries.workmanager.api.WorkManagerRequest
import io.element.android.libraries.workmanager.api.WorkManagerRequestBuilder
import io.element.android.libraries.workmanager.api.WorkManagerRequestType
import io.element.android.libraries.workmanager.api.WorkManagerScheduler
import io.element.android.tests.testutils.lambda.lambdaError
class FakeWorkManagerScheduler(
private val submitLambda: (WorkManagerRequest) -> Unit = { lambdaError() },
private val submitLambda: (WorkManagerRequestBuilder) -> Unit = { lambdaError() },
private val hasPendingWorkLambda: (SessionId, WorkManagerRequestType) -> Boolean = { _, _ -> false },
private val cancelLambda: (SessionId) -> Unit = { lambdaError() },
private val cancelLambda: (SessionId, WorkManagerRequestType?) -> Unit = { _, _ -> lambdaError() },
) : WorkManagerScheduler {
override fun submit(workManagerRequest: WorkManagerRequest) {
submitLambda(workManagerRequest)
override suspend fun submit(workManagerRequestBuilder: WorkManagerRequestBuilder) {
submitLambda(workManagerRequestBuilder)
}
override fun hasPendingWork(sessionId: SessionId, requestType: WorkManagerRequestType): Boolean {
return hasPendingWorkLambda(sessionId, requestType)
}
override fun cancel(sessionId: SessionId) {
cancelLambda(sessionId)
override fun cancel(sessionId: SessionId, requestType: WorkManagerRequestType?) {
cancelLambda(sessionId, requestType)
}
}