Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,8 @@ public final class ClientSessionComponent {
private let coreCryptoProvider: any CoreCryptoProviderProtocol
private let completionHandlers: CompletionHandlers

private let faultyMLSRemovalKeysByDomain: [String: [String]]

public init(
selfUserID: UUID,
selfClientID: String,
Expand All @@ -81,7 +83,8 @@ public final class ClientSessionComponent {
mlsDecryptionService: any MLSDecryptionServiceInterface,
proteusService: any ProteusServiceInterface,
coreCryptoProvider: any CoreCryptoProviderProtocol,
completionHandlers: CompletionHandlers
completionHandlers: CompletionHandlers,
faultyMLSRemovalKeysByDomain: [String: [String]]
) {
self.selfUserID = selfUserID
self.selfClientID = selfClientID
Expand All @@ -99,6 +102,7 @@ public final class ClientSessionComponent {
self.isMLSEnabled = isMLSEnabled
self.coreCryptoProvider = coreCryptoProvider
self.completionHandlers = completionHandlers
self.faultyMLSRemovalKeysByDomain = faultyMLSRemovalKeysByDomain
}

public private(set) lazy var authenticationManager = AuthenticationManager(
Expand Down Expand Up @@ -793,6 +797,15 @@ public final class ClientSessionComponent {
userID: selfUserID
)

public lazy var repairFaultyRemovalKeysUsecase = RepairRemovalKeysUseCase(
faultyMLSRemovalKeysByDomain: faultyMLSRemovalKeysByDomain,
context: syncContext,
mlsService: mlsService,
conversationsAPI: conversationsAPI,
conversationLocalStore: conversationLocalStore,
initiateResetUseCase: initiateResetMLSConversationUseCase
)

public lazy var initiateResetMLSConversationUseCase = InitiateResetMLSConversationUseCase(
api: mlsAPI,
mlsService: mlsService,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,8 @@ public final class UserSessionComponent {
private let proteusService: any ProteusServiceInterface
private let coreCryptoProvider: any CoreCryptoProviderProtocol

private let faultyMLSRemovalKeysByDomain: [String: [String]]

public init(
currentBuildNumber: String,
selfUserID: UUID,
Expand All @@ -59,7 +61,8 @@ public final class UserSessionComponent {
mlsService: any MLSServiceInterface,
mlsDecryptionService: any MLSDecryptionServiceInterface,
proteusService: any ProteusServiceInterface,
coreCryptoProvider: any CoreCryptoProviderProtocol
coreCryptoProvider: any CoreCryptoProviderProtocol,
faultyMLSRemovalKeysByDomain: [String: [String]]
) {
self.currentBuildNumber = currentBuildNumber
self.selfUserID = selfUserID
Expand All @@ -77,6 +80,7 @@ public final class UserSessionComponent {
self.proteusService = proteusService
self.coreCryptoProvider = coreCryptoProvider
self.sharedContainerURL = sharedContainerURL
self.faultyMLSRemovalKeysByDomain = faultyMLSRemovalKeysByDomain
}

private let cookieStorage: any CookieStorageProtocol
Expand All @@ -103,7 +107,8 @@ public final class UserSessionComponent {
mlsDecryptionService: mlsDecryptionService,
proteusService: proteusService,
coreCryptoProvider: coreCryptoProvider,
completionHandlers: completionHandlers
completionHandlers: completionHandlers,
faultyMLSRemovalKeysByDomain: faultyMLSRemovalKeysByDomain
)
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import WireDataModel
import WireLogging
import WireNetwork

// sourcery: AutoMockable
public protocol InitiateResetMLSConversationUseCaseProtocol {
func invoke(groupID: WireDataModel.MLSGroupID, epoch: UInt64) async
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,16 @@ public final class ConversationLocalStore: ConversationLocalStoreProtocol {
}
}

public func fetchAllMLSConversations(domain: String?) async throws -> [ZMConversation] {
try await context.perform { [context] in
try ZMConversation.fetchConversationsWithMLSGroupStatus(
mlsGroupStatus: .ready,
domain: domain,
in: context
)
}
}

public func fetchConversation(
id: UUID,
domain: String?
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,21 @@ public protocol ConversationLocalStoreProtocol {
mlsGroupID: MLSGroupID
) async

/// Fetches all MLS conversations that are ready.
///
/// This method retrieves all conversations that have MLS group IDs and are in a ready state,
/// optionally filtered by domain.
///
/// - Parameter domain: The domain to filter conversations by. If `nil`, fetches conversations
/// from all domains.
///
/// - Returns: An array of `ZMConversation` objects that are MLS-ready. Returns an empty array
/// if no conversations match the criteria.
///
/// - Throws: An error if the fetch operation fails.

func fetchAllMLSConversations(domain: String?) async throws -> [ZMConversation]

/// Fetches a MLS conversation locally.
///
/// - parameters:
Expand Down
199 changes: 199 additions & 0 deletions WireDomain/Sources/WireDomain/UseCases/RepairRemovalKeysUseCase.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
//
// Wire
// Copyright (C) 2025 Wire Swiss GmbH
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, either version 3 of the License, or
// (at your option) any later version.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see http://www.gnu.org/licenses/.
//

import WireDataModel
import WireLogging
import WireNetwork

// sourcery: AutoMockable
/// Repairs conversations with faulty removal keys
public protocol RepairRemovalKeysUseCaseProtocol {
func invoke() async throws
}

public struct RepairRemovalKeysUseCase: RepairRemovalKeysUseCaseProtocol {

let faultyMLSRemovalKeysByDomain: [String: [String]]

private let context: NSManagedObjectContext
private let mlsService: MLSServiceInterface
private let conversationsAPI: ConversationsAPI
private let conversationLocalStore: ConversationLocalStoreProtocol
private let initiateResetUseCase: InitiateResetMLSConversationUseCaseProtocol

init(
faultyMLSRemovalKeysByDomain: [String: [String]],
context: NSManagedObjectContext,
mlsService: MLSServiceInterface,
conversationsAPI: ConversationsAPI,
conversationLocalStore: ConversationLocalStoreProtocol,
initiateResetUseCase: InitiateResetMLSConversationUseCaseProtocol
) {
self.faultyMLSRemovalKeysByDomain = faultyMLSRemovalKeysByDomain
self.context = context
self.mlsService = mlsService
self.conversationsAPI = conversationsAPI
self.conversationLocalStore = conversationLocalStore
self.initiateResetUseCase = initiateResetUseCase
}

public func invoke() async throws {
WireLogger.mls.info(
"initiating repair of faulty removal keys",
attributes: .safePublic
)

guard !faultyMLSRemovalKeysByDomain.isEmpty else {
WireLogger.mls.info(
"no faulty removal keys to repair, aborting",
attributes: .safePublic
)
return
}

// Process each domain
for (domain, faultyKeyHexStrings) in faultyMLSRemovalKeysByDomain {
try await processDomain(
domain: domain,
faultyKeyHexStrings: faultyKeyHexStrings
)
}
}

// MARK: - Private

private func processDomain(
domain: String,
faultyKeyHexStrings: [String]
) async throws {
WireLogger.mls.info(
"checking domain for \(faultyKeyHexStrings.count) faulty key(s)",
attributes: .safePublic
)

// Convert hex strings to Data
let faultyKeyDataList = faultyKeyHexStrings.compactMap(Data.init(hexString:))
guard faultyKeyDataList.count == faultyKeyHexStrings.count else {
WireLogger.mls.error(
"failed to decode some faulty removal key hex strings",
attributes: .safePublic
)
return
}

let allMLSConversations = try await conversationLocalStore.fetchAllMLSConversations(
domain: domain
)

// Find faulty conversations for this domain
let faultyConversations = await findFaultyConversations(
in: allMLSConversations,
faultyKeys: faultyKeyDataList
)

WireLogger.mls.info(
"detected \(faultyConversations.count)/\(allMLSConversations.count) affected conversations",
attributes: .safePublic
)

// Repair each faulty conversation in parallel
await withTaskGroup(of: Void.self) { group in
for (groupID, qualifiedID) in faultyConversations {
group.addTask {
await repairConversation(
groupID: groupID,
qualifiedID: qualifiedID
)
}
}
}
}

private func findFaultyConversations(
in conversations: [ZMConversation],
faultyKeys: [Data]
) async -> [(MLSGroupID, WireDataModel.QualifiedID)] {
var faultyConversations: [(MLSGroupID, WireDataModel.QualifiedID)] = []

for conversation in conversations {
let (groupID, qualifiedID) = await context.perform {
(conversation.mlsGroupID, conversation.qualifiedID)
}

guard let groupID, let qualifiedID else {
continue
}

let currentRemovalKey: Data
do {
currentRemovalKey = try await mlsService.externalSenderKey(groupID: groupID)
} catch {
WireLogger.mls.error(
"failed to get current removal key for a group, skipping: \(String(describing: error))",
attributes: .safePublic
)
continue
}

// Check if the current removal key matches any of the faulty keys
if faultyKeys.contains(currentRemovalKey) {
faultyConversations.append((
groupID,
qualifiedID
))
}
}

return faultyConversations
}

private func repairConversation(
groupID: MLSGroupID,
qualifiedID: WireDataModel.QualifiedID
) async {
let remoteConversation: WireNetwork.Conversation?
do {
remoteConversation = try await conversationsAPI.getConversations(
for: [qualifiedID.toAPIModel()]
).found.first
} catch {
WireLogger.mls.error(
"failed to get epoch for a group, skipping: \(String(describing: error))",
attributes: .safePublic, [.conversationId: qualifiedID.safeForLoggingDescription]
)
return
}

guard let remoteConversation else {
WireLogger.mls.error(
"remote conversation for a group not found, skipping",
attributes: .safePublic, [.conversationId: qualifiedID.safeForLoggingDescription]
)
return
}

WireLogger.mls.info(
"initiating reset for faulty conversation: \(qualifiedID)",
attributes: .safePublic, [.conversationId: qualifiedID.safeForLoggingDescription]
)

let epoch = UInt64(remoteConversation.epoch ?? 0)
await initiateResetUseCase.invoke(groupID: groupID, epoch: epoch)
}

}
Loading
Loading