diff --git a/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureApiClients.kt b/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureApiClients.kt index 00489d4..05a1e4a 100644 --- a/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureApiClients.kt +++ b/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureApiClients.kt @@ -15,6 +15,7 @@ import kotlinx.serialization.json.JsonClassDiscriminator import kotlinx.serialization.modules.SerializersModule import kotlinx.serialization.modules.polymorphic import kotlinx.serialization.modules.subclass +import com.google.ai.sample.network.MistralRequestCoordinator import okhttp3.MediaType.Companion.toMediaType import okhttp3.OkHttpClient import okhttp3.Request @@ -70,7 +71,13 @@ data class ServiceMistralResponseMessage( val content: String ) -internal suspend fun callMistralApi(modelName: String, apiKey: String, chatHistory: List, inputContent: Content): Pair { +internal suspend fun callMistralApi( + modelName: String, + apiKey: String, + chatHistory: List, + inputContent: Content, + availableApiKeys: List = listOf(apiKey) +): Pair { var responseText: String? = null var errorMessage: String? 
= null @@ -129,7 +136,16 @@ internal suspend fun callMistralApi(modelName: String, apiKey: String, chatHisto .addHeader("Authorization", "Bearer $apiKey") .build() - client.newCall(request).execute().use { response -> + val keysForCoordinator = availableApiKeys.filter { it.isNotBlank() }.distinct().ifEmpty { listOf(apiKey) } + val coordinated = MistralRequestCoordinator.execute(apiKeys = keysForCoordinator, maxAttempts = maxOf(4, keysForCoordinator.size * 3)) { key -> + client.newCall( + request.newBuilder() + .header("Authorization", "Bearer $key") + .build() + ).execute() + } + + coordinated.response.use { response -> val responseBody = response.body?.string() if (!response.isSuccessful) { Log.e("ScreenCaptureService", "Mistral API Error ($response.code): $responseBody") diff --git a/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureService.kt b/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureService.kt index 4551070..17145ec 100644 --- a/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureService.kt +++ b/app/src/main/kotlin/com/google/ai/sample/ScreenCaptureService.kt @@ -297,7 +297,16 @@ class ScreenCaptureService : Service() { if (apiProvider == ApiProvider.VERCEL) { responseText = callVercelApi(applicationContext, modelName, apiKey, chatHistoryDtos, inputContentDto) } else if (apiProvider == ApiProvider.MISTRAL) { - val result = callMistralApi(modelName, apiKey, chatHistory, inputContent) + val availableMistralKeys = ApiKeyManager.getInstance(applicationContext) + .getApiKeys(ApiProvider.MISTRAL) + .filter { it.isNotBlank() } + val result = callMistralApi( + modelName = modelName, + apiKey = apiKey, + chatHistory = chatHistory, + inputContent = inputContent, + availableApiKeys = availableMistralKeys + ) responseText = result.first errorMessage = result.second } else if (apiProvider == ApiProvider.PUTER) { diff --git a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt 
b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt index 102e5a7..d8ae9ea 100644 --- a/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt +++ b/app/src/main/kotlin/com/google/ai/sample/feature/multimodal/PhotoReasoningViewModel.kt @@ -34,6 +34,7 @@ import com.google.ai.sample.feature.multimodal.ModelDownloadManager import com.google.ai.sample.ModelOption import com.google.ai.sample.GenerativeAiViewModelFactory import com.google.ai.sample.InferenceBackend +import com.google.ai.sample.network.MistralRequestCoordinator import com.google.ai.sample.feature.multimodal.dtos.toDto import com.google.ai.sample.feature.multimodal.dtos.TempFilePathCollector import kotlinx.coroutines.Dispatchers @@ -70,8 +71,6 @@ import kotlinx.serialization.modules.subclass import com.google.ai.sample.webrtc.WebRTCSender import com.google.ai.sample.webrtc.SignalingClient import org.webrtc.IceCandidate -import kotlin.math.max -import kotlin.math.roundToLong class PhotoReasoningViewModel( application: Application, @@ -184,11 +183,14 @@ class PhotoReasoningViewModel( // to avoid re-executing already-executed commands private var incrementalCommandCount = 0 - // Mistral rate limiting per API key (1.5 seconds between requests with same key) - private val mistralNextAllowedRequestAtMsByKey = mutableMapOf() - private var lastMistralTokenTimeMs = 0L - private var lastMistralTokenKey: String? = null - private val MISTRAL_MIN_INTERVAL_MS = 1500L + private data class QueuedMistralScreenshotRequest( + val bitmap: Bitmap, + val screenshotUri: String, + val screenInfo: String? + ) + private val mistralAutoScreenshotQueueLock = Any() + private var mistralAutoScreenshotInFlight = false + private var queuedMistralScreenshotRequest: QueuedMistralScreenshotRequest? 
= null // Accumulated full text during streaming for incremental command parsing private var streamingAccumulatedText = StringBuilder() @@ -1032,6 +1034,10 @@ class PhotoReasoningViewModel( screenInfoForPrompt: String? = null, imageUrisForChat: List? = null ) { + Log.d( + TAG, + "reasonWithMistral: start, images=${selectedImages.size}, screenInfo=${!screenInfoForPrompt.isNullOrBlank()}, chatSize=${_chatState.getAllMessages().size}" + ) _uiState.value = PhotoReasoningUiState.Loading _showStopNotificationFlow.value = true val context = appContext @@ -1060,6 +1066,7 @@ class PhotoReasoningViewModel( currentReasoningJob?.cancel() currentReasoningJob = viewModelScope.launch(Dispatchers.IO) { try { + Log.d(TAG, "reasonWithMistral: launched IO job") val currentModel = com.google.ai.sample.GenerativeAiViewModelFactory.getCurrentModel() val genSettings = com.google.ai.sample.util.GenerationSettingsPreferences.loadSettings(context, currentModel.modelName) @@ -1130,135 +1137,25 @@ class PhotoReasoningViewModel( val availableKeys = apiKeyManager.getApiKeys(ApiProvider.MISTRAL) .filter { it.isNotBlank() } .distinct() + Log.d(TAG, "reasonWithMistral: availableKeys=${availableKeys.size}") if (availableKeys.isEmpty()) { throw IOException("Mistral API key not found.") } // Validate that we have at least one key before proceeding require(availableKeys.isNotEmpty()) { "No valid Mistral API keys available after filtering" } - - fun markKeyCooldown(key: String, referenceTimeMs: Long) { - val nextAllowedAt = referenceTimeMs + MISTRAL_MIN_INTERVAL_MS - val existing = mistralNextAllowedRequestAtMsByKey[key] ?: 0L - mistralNextAllowedRequestAtMsByKey[key] = max(existing, nextAllowedAt) - } - - fun markKeyCooldown(key: String, referenceTimeMs: Long, extraDelayMs: Long) { - val normalizedExtraDelay = extraDelayMs.coerceAtLeast(0L) - val nextAllowedAt = referenceTimeMs + max(MISTRAL_MIN_INTERVAL_MS, normalizedExtraDelay) - val existing = mistralNextAllowedRequestAtMsByKey[key] ?: 0L - 
mistralNextAllowedRequestAtMsByKey[key] = max(existing, nextAllowedAt) - } - - fun remainingWaitForKeyMs(key: String, nowMs: Long): Long { - val nextAllowedAt = mistralNextAllowedRequestAtMsByKey[key] ?: 0L - return (nextAllowedAt - nowMs).coerceAtLeast(0L) - } - - fun parseRetryAfterMs(headerValue: String?): Long? { - if (headerValue.isNullOrBlank()) return null - val seconds = headerValue.trim().toDoubleOrNull() ?: return null - return (seconds * 1000.0).roundToLong().coerceAtLeast(0L) - } - - fun parseRateLimitResetDelayMs(response: okhttp3.Response, nowMs: Long): Long? { - val resetHeader = response.header("x-ratelimit-reset") ?: return null - val resetEpochSeconds = resetHeader.trim().toLongOrNull() ?: return null - val resetMs = resetEpochSeconds * 1000L - return (resetMs - nowMs).coerceAtLeast(0L) - } - - fun adaptiveRetryDelayMs(failureCount: Int): Long { - val cappedExponent = (failureCount - 1).coerceIn(0, 5) - return 1000L shl cappedExponent // 1s, 2s, 4s, 8s, 16s, 32s - } - - fun isRetryableMistralFailure(code: Int): Boolean { - return code == 429 || code >= 500 - } - - var response: okhttp3.Response? = null - var selectedKeyForResponse: String? 
= null - var consecutiveFailures = 0 - var blockedKeysThisRound = mutableSetOf() - val maxAttempts = availableKeys.size * 4 + 8 - while (response == null && consecutiveFailures < maxAttempts) { - if (stopExecutionFlag.get()) break - - val now = System.currentTimeMillis() - val keyPool = availableKeys.filter { it !in blockedKeysThisRound }.ifEmpty { - blockedKeysThisRound.clear() - availableKeys - } - - val keyWithLeastWait = keyPool.minByOrNull { remainingWaitForKeyMs(it, now) } ?: availableKeys.first() - val waitMs = remainingWaitForKeyMs(keyWithLeastWait, now) - if (waitMs > 0L) { - delay(waitMs) - } - - val selectedKey = keyWithLeastWait - selectedKeyForResponse = selectedKey - - try { - val attemptResponse = client.newCall(buildRequest(selectedKey)).execute() - val requestEndMs = System.currentTimeMillis() - val retryAfterMs = parseRetryAfterMs(attemptResponse.header("Retry-After")) - val resetDelayMs = parseRateLimitResetDelayMs(attemptResponse, requestEndMs) - val serverRequestedDelayMs = max(retryAfterMs ?: 0L, resetDelayMs ?: 0L) - markKeyCooldown(selectedKey, requestEndMs, serverRequestedDelayMs) - - if (attemptResponse.isSuccessful) { - response = attemptResponse - break - } - - val isRetryable = isRetryableMistralFailure(attemptResponse.code) - if (!isRetryable) { - val errBody = attemptResponse.body?.string() - attemptResponse.close() - throw IllegalStateException("Mistral Error ${attemptResponse.code}: $errBody") - } - - attemptResponse.close() - blockedKeysThisRound.add(selectedKey) - consecutiveFailures++ - val adaptiveDelay = adaptiveRetryDelayMs(consecutiveFailures) - markKeyCooldown( - selectedKey, - requestEndMs, - max(serverRequestedDelayMs, adaptiveDelay) - ) - withContext(Dispatchers.Main) { - replaceAiMessageText( - "Mistral temporär nicht verfügbar (Versuch $consecutiveFailures/$maxAttempts). 
Warte auf Server-Rate-Limit und wiederhole...", - isPending = true - ) - } - } catch (e: IOException) { - val requestEndMs = System.currentTimeMillis() - val adaptiveDelay = adaptiveRetryDelayMs(consecutiveFailures + 1) - markKeyCooldown(selectedKey, requestEndMs, adaptiveDelay) - blockedKeysThisRound.add(selectedKey) - consecutiveFailures++ - if (consecutiveFailures >= maxAttempts) { - throw IOException("Mistral request failed after $maxAttempts attempts: ${e.message}", e) - } - withContext(Dispatchers.Main) { - replaceAiMessageText( - "Mistral Netzwerkfehler (Versuch $consecutiveFailures/$maxAttempts). Wiederhole...", - isPending = true - ) - } + val coordinated = MistralRequestCoordinator.execute( + apiKeys = availableKeys, + maxAttempts = maxAttempts + ) { selectedKey -> + if (stopExecutionFlag.get()) { + throw IOException("Mistral request aborted.") } + client.newCall(buildRequest(selectedKey)).execute() } - - if (stopExecutionFlag.get()) { - throw IOException("Mistral request aborted.") - } - - val finalResponse = response ?: throw IOException("Mistral request failed after $maxAttempts attempts.") + val finalResponse = coordinated.response + Log.d(TAG, "reasonWithMistral: coordinated response code=${finalResponse.code}") if (!finalResponse.isSuccessful) { val errBody = finalResponse.body?.string() @@ -1268,27 +1165,13 @@ class PhotoReasoningViewModel( val body = finalResponse.body ?: throw IOException("Empty response body from Mistral") val aiResponseText = openAiStreamParser.parse(body) { accText -> - selectedKeyForResponse?.let { key -> - lastMistralTokenKey = key - lastMistralTokenTimeMs = System.currentTimeMillis() - markKeyCooldown(key, lastMistralTokenTimeMs) - } ?: run { - Log.w(TAG, "selectedKeyForResponse is null during streaming callback") - } withContext(Dispatchers.Main) { replaceAiMessageText(accText, isPending = true) processCommandsIncrementally(accText) } } + Log.d(TAG, "reasonWithMistral: stream parse finished, 
responseLength=${aiResponseText.length}") finalResponse.close() - selectedKeyForResponse?.let { key -> - val reference = if (lastMistralTokenKey == key && lastMistralTokenTimeMs > 0L) { - lastMistralTokenTimeMs - } else { - System.currentTimeMillis() - } - markKeyCooldown(key, reference) - } withContext(Dispatchers.Main) { _uiState.value = PhotoReasoningUiState.Success(aiResponseText) @@ -1306,11 +1189,13 @@ class PhotoReasoningViewModel( } } finally { withContext(Dispatchers.Main) { + Log.d(TAG, "reasonWithMistral: finally, draining queued auto-screenshot requests") + releaseAndDrainMistralAutoScreenshotQueue() refreshStopButtonState() } } } -} + } private fun reasonWithPuter( userInput: String, @@ -2202,7 +2087,6 @@ private fun processCommands(text: String) { } } } - private fun executeAccessibilityCommand(command: Command, shouldTrackCommand: Boolean) { ScreenOperatorAccessibilityService.executeCommand(command) if (shouldTrackCommand) { @@ -2404,16 +2288,22 @@ private fun processCommands(text: String) { _commandExecutionStatus.value = status } - // Create prompt with screen information if available - val genericAnalysisPrompt = createGenericScreenshotPrompt() - - // Re-send the query with only the latest screenshot - reason( - userInput = genericAnalysisPrompt, - selectedImages = listOf(bitmap), - screenInfoForPrompt = screenInfo, - imageUrisForChat = listOf(screenshotUri.toString()) // Add this argument - ) + val currentModel = GenerativeAiViewModelFactory.getCurrentModel() + if (currentModel.apiProvider == ApiProvider.MISTRAL) { + enqueueMistralAutoScreenshotRequest( + bitmap = bitmap, + screenshotUri = screenshotUri.toString(), + screenInfo = screenInfo + ) + } else { + // Re-send the query with only the latest screenshot + reason( + userInput = createGenericScreenshotPrompt(), + selectedImages = listOf(bitmap), + screenInfoForPrompt = screenInfo, + imageUrisForChat = listOf(screenshotUri.toString()) + ) + } 
PhotoReasoningScreenshotUiNotifier.showAddedToConversation(context) } else { @@ -2436,5 +2326,60 @@ private fun processCommands(text: String) { } } } + + private fun enqueueMistralAutoScreenshotRequest( + bitmap: Bitmap, + screenshotUri: String, + screenInfo: String? + ) { + val request = QueuedMistralScreenshotRequest( + bitmap = bitmap, + screenshotUri = screenshotUri, + screenInfo = screenInfo + ) + var shouldStartNow = false + synchronized(mistralAutoScreenshotQueueLock) { + if (mistralAutoScreenshotInFlight) { + queuedMistralScreenshotRequest = request + Log.d(TAG, "Mistral auto screenshot request queued (latest wins). uri=$screenshotUri") + } else { + mistralAutoScreenshotInFlight = true + Log.d(TAG, "Mistral auto screenshot request becomes in-flight. uri=$screenshotUri") + shouldStartNow = true + } + } + if (shouldStartNow) { + dispatchMistralAutoScreenshotRequest(request) + } + } + + private fun dispatchMistralAutoScreenshotRequest(request: QueuedMistralScreenshotRequest) { + Log.d(TAG, "Dispatching Mistral auto screenshot request. uri=${request.screenshotUri}") + reason( + userInput = createGenericScreenshotPrompt(), + selectedImages = listOf(request.bitmap), + screenInfoForPrompt = request.screenInfo, + imageUrisForChat = listOf(request.screenshotUri) + ) + } + + private fun releaseAndDrainMistralAutoScreenshotQueue() { + val nextRequest: QueuedMistralScreenshotRequest? = synchronized(mistralAutoScreenshotQueueLock) { + val queued = queuedMistralScreenshotRequest + if (queued == null) { + mistralAutoScreenshotInFlight = false + Log.d(TAG, "Mistral auto screenshot queue drained completely. 
inFlight=false") + null + } else { + queuedMistralScreenshotRequest = null + Log.d(TAG, "Mistral auto screenshot queue has pending request to drain.") + queued + } + } + if (nextRequest != null) { + Log.d(TAG, "Draining queued Mistral auto screenshot request.") + dispatchMistralAutoScreenshotRequest(nextRequest) + } + } }
diff --git a/app/src/main/kotlin/com/google/ai/sample/network/MistralRequestCoordinator.kt b/app/src/main/kotlin/com/google/ai/sample/network/MistralRequestCoordinator.kt
new file mode 100644
index 0000000..82bca65
--- /dev/null
+++ b/app/src/main/kotlin/com/google/ai/sample/network/MistralRequestCoordinator.kt
@@ -0,0 +1,220 @@
+package com.google.ai.sample.network
+
+import android.util.Log
+import kotlinx.coroutines.CancellationException
+import kotlinx.coroutines.delay
+import kotlinx.coroutines.sync.Mutex
+import kotlinx.coroutines.sync.withLock
+import okhttp3.Response
+import java.util.concurrent.atomic.AtomicLong
+import kotlin.math.max
+import kotlin.math.roundToLong
+
+/**
+ * Outcome of [MistralRequestCoordinator.execute]: the HTTP [response] that ended the retry
+ * loop and the [apiKey] that produced it. The caller owns [response] and must close it.
+ */
+internal data class MistralCoordinatedResponse(
+    val response: Response,
+    val apiKey: String
+)
+
+/**
+ * Shared gate for Mistral HTTP calls. It spaces out requests per API key, rotates between
+ * keys, and retries HTTP 429 / status >= 500 responses and thrown exceptions with growing cooldowns.
+ *
+ * The cooldown table is held by this singleton, so all callers share it.
+ */
+internal object MistralRequestCoordinator {
+    private const val TAG = "MistralCoordinator"
+
+    /** Minimum spacing between two requests that use the same key. */
+    private const val MIN_INTERVAL_MS = 1500L
+
+    /** Upper bound for any delay derived from a response header. */
+    private const val MAX_SERVER_DELAY_MS = 5_000L
+
+    /** Keys shorter than this are masked completely in log output. */
+    private const val MIN_FINGERPRINT_KEY_LENGTH = 16
+
+    private val cooldownMutex = Mutex()
+
+    /** Earliest time (epoch milliseconds) at which each key may be used again. */
+    private val nextAllowedRequestAtMsByKey = mutableMapOf<String, Long>()
+    private val requestId = AtomicLong(0L)
+
+    /** Log-safe label for [key]; the complete key is never returned. */
+    private fun keyFingerprint(key: String): String {
+        if (key.length < MIN_FINGERPRINT_KEY_LENGTH) return "****"
+        return "${key.take(4)}…${key.takeLast(4)}"
+    }
+
+    /**
+     * Moves the cooldown of [key] to [referenceTimeMs] plus the larger of [MIN_INTERVAL_MS]
+     * and [extraDelayMs]. A cooldown that already ends later is left untouched.
+     */
+    private suspend fun markKeyCooldown(
+        key: String,
+        referenceTimeMs: Long,
+        extraDelayMs: Long = 0L
+    ) {
+        val nextAllowedAt = referenceTimeMs + max(MIN_INTERVAL_MS, extraDelayMs.coerceAtLeast(0L))
+        cooldownMutex.withLock {
+            val existing = nextAllowedRequestAtMsByKey[key] ?: 0L
+            nextAllowedRequestAtMsByKey[key] = max(existing, nextAllowedAt)
+        }
+    }
+
+    /**
+     * Picks the key from [keyPool] with the shortest remaining cooldown (first one wins on
+     * ties) and reserves its next slot before the lock is released, so two concurrent callers
+     * cannot both start on the same key at the same moment.
+     *
+     * @return the chosen key and the time in milliseconds to wait before using it.
+     */
+    private suspend fun reserveKeyWithLeastWait(keyPool: List<String>, nowMs: Long): Pair<String, Long> =
+        cooldownMutex.withLock {
+            val selectedKey = keyPool.minByOrNull { candidate ->
+                ((nextAllowedRequestAtMsByKey[candidate] ?: 0L) - nowMs).coerceAtLeast(0L)
+            } ?: error("keyPool must not be empty")
+            val startAtMs = max(nowMs, nextAllowedRequestAtMsByKey[selectedKey] ?: 0L)
+            nextAllowedRequestAtMsByKey[selectedKey] = startAtMs + MIN_INTERVAL_MS
+            selectedKey to (startAtMs - nowMs)
+        }
+
+    /**
+     * Parses a numeric `Retry-After` value (seconds) into milliseconds, limited to
+     * [MAX_SERVER_DELAY_MS]. Returns null for missing, non-numeric, NaN or infinite values.
+     */
+    private fun parseRetryAfterMs(headerValue: String?): Long? {
+        val seconds = headerValue?.trim()?.toDoubleOrNull() ?: return null
+        // roundToLong() throws on NaN, and an infinite value would overflow the cooldown sum.
+        if (!seconds.isFinite()) return null
+        return (seconds * 1000.0).roundToLong().coerceIn(0L, MAX_SERVER_DELAY_MS)
+    }
+
+    /**
+     * Converts `x-ratelimit-reset` into a delay in milliseconds, limited to
+     * [MAX_SERVER_DELAY_MS]. The header format is guessed from the magnitude of the value.
+     */
+    private fun parseRateLimitResetDelayMs(response: Response, nowMs: Long): Long? {
+        val raw = response.header("x-ratelimit-reset")?.trim()?.toLongOrNull() ?: return null
+        val delayMs = when {
+            // likely unix epoch in milliseconds
+            raw >= 1_000_000_000_000L -> raw - nowMs
+            // likely unix epoch in seconds
+            raw >= 1_000_000_000L -> (raw * 1000L) - nowMs
+            // likely relative seconds
+            raw >= 0L -> raw * 1000L
+            else -> return null
+        }
+        return delayMs.coerceIn(0L, MAX_SERVER_DELAY_MS)
+    }
+
+    /** Exponential back-off: 1s, 2s, 4s, 8s, 16s, then 32s for every further failure. */
+    private fun adaptiveRetryDelayMs(failureCount: Int): Long {
+        val cappedExponent = (failureCount - 1).coerceIn(0, 5)
+        return 1000L shl cappedExponent
+    }
+
+    private fun isRetryableFailure(code: Int): Boolean = code == 429 || code >= 500
+
+    /**
+     * Runs [request] with one of [apiKeys] until it yields a response that is successful or
+     * not retryable, or until [maxAttempts] failures have accumulated.
+     *
+     * Any exception thrown by [request] except [CancellationException] counts as a failed
+     * attempt and is retried, so a caller that wants to stop early should cancel its
+     * coroutine instead of throwing from [request].
+     *
+     * @param apiKeys keys to rotate through; must not be empty.
+     * @param maxAttempts number of failed attempts after which the call gives up.
+     * @param request performs the HTTP call with the given key and returns the raw response.
+     * @return the final response (success or non-retryable error) and the key that produced it.
+     * @throws IllegalArgumentException if [apiKeys] is empty.
+     * @throws IllegalStateException if the attempts run out after a retryable HTTP status.
+     * @throws Exception the exception from [request] if it caused the last failed attempt.
+     */
+    suspend fun execute(
+        apiKeys: List<String>,
+        maxAttempts: Int = apiKeys.size * 4 + 8,
+        request: suspend (apiKey: String) -> Response
+    ): MistralCoordinatedResponse {
+        require(apiKeys.isNotEmpty()) { "No Mistral API keys provided." }
+        val rid = requestId.incrementAndGet()
+        Log.d(TAG, "[$rid] execute start: keys=${apiKeys.size}, maxAttempts=$maxAttempts")
+
+        var consecutiveFailures = 0
+        val blockedKeysThisRound = mutableSetOf<String>()
+
+        while (consecutiveFailures < maxAttempts) {
+            // Once every key has failed in this round, start a new round with all keys.
+            val keyPool = apiKeys.filter { it !in blockedKeysThisRound }.ifEmpty {
+                blockedKeysThisRound.clear()
+                apiKeys
+            }
+            val (selectedKey, waitMs) = reserveKeyWithLeastWait(keyPool, System.currentTimeMillis())
+            Log.d(
+                TAG,
+                "[$rid] attempt=${consecutiveFailures + 1}, selectedKey=${keyFingerprint(selectedKey)}, waitMs=$waitMs, blocked=${blockedKeysThisRound.size}"
+            )
+            if (waitMs > 0L) {
+                delay(waitMs)
+            }
+
+            val response = try {
+                request(selectedKey)
+            } catch (e: CancellationException) {
+                // Cancellation is not a failed attempt; it has to end the loop immediately.
+                throw e
+            } catch (e: Exception) {
+                val requestEndMs = System.currentTimeMillis()
+                blockedKeysThisRound.add(selectedKey)
+                consecutiveFailures++
+                Log.e(
+                    TAG,
+                    "[$rid] exception on key=${keyFingerprint(selectedKey)}, consecutiveFailures=$consecutiveFailures: ${e.message}",
+                    e
+                )
+                markKeyCooldown(selectedKey, requestEndMs, adaptiveRetryDelayMs(consecutiveFailures))
+                if (consecutiveFailures >= maxAttempts) throw e
+                continue
+            }
+
+            val requestEndMs = System.currentTimeMillis()
+            val serverRequestedDelayMs = try {
+                val retryAfterMs = parseRetryAfterMs(response.header("Retry-After"))
+                val resetDelayMs = parseRateLimitResetDelayMs(response, requestEndMs)
+                val appliedDelayMs = max(retryAfterMs ?: 0L, resetDelayMs ?: 0L)
+                Log.d(
+                    TAG,
+                    "[$rid] response code=${response.code}, retryAfterMs=${retryAfterMs ?: -1}, resetDelayMs=${resetDelayMs ?: -1}, appliedDelayMs=$appliedDelayMs"
+                )
+                markKeyCooldown(selectedKey, requestEndMs, appliedDelayMs)
+                appliedDelayMs
+            } catch (t: Throwable) {
+                // The response is not handed to the caller on this path, so release it here.
+                response.close()
+                throw t
+            }
+
+            if (response.isSuccessful || !isRetryableFailure(response.code)) {
+                Log.d(TAG, "[$rid] returning response code=${response.code} with key=${keyFingerprint(selectedKey)}")
+                return MistralCoordinatedResponse(response = response, apiKey = selectedKey)
+            }
+
+            response.close()
+            blockedKeysThisRound.add(selectedKey)
+            consecutiveFailures++
+            val adaptiveDelay = adaptiveRetryDelayMs(consecutiveFailures)
+            Log.w(
+                TAG,
+                "[$rid] retryable failure code=${response.code}, consecutiveFailures=$consecutiveFailures, adaptiveDelay=$adaptiveDelay"
+            )
+            markKeyCooldown(selectedKey, requestEndMs, max(serverRequestedDelayMs, adaptiveDelay))
+        }
+
+        Log.e(TAG, "[$rid] exhausted attempts ($maxAttempts) without success")
+        throw IllegalStateException("Mistral request failed after $maxAttempts attempts.")
+    }
+}