Skip to content
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import kotlinx.serialization.json.JsonClassDiscriminator
import kotlinx.serialization.modules.SerializersModule
import kotlinx.serialization.modules.polymorphic
import kotlinx.serialization.modules.subclass
import com.google.ai.sample.network.MistralRequestCoordinator
import okhttp3.MediaType.Companion.toMediaType
import okhttp3.OkHttpClient
import okhttp3.Request
Expand Down Expand Up @@ -129,7 +130,15 @@ internal suspend fun callMistralApi(modelName: String, apiKey: String, chatHisto
.addHeader("Authorization", "Bearer $apiKey")
.build()

client.newCall(request).execute().use { response ->
val coordinated = MistralRequestCoordinator.execute(apiKeys = listOf(apiKey), maxAttempts = 4) { key ->
client.newCall(
request.newBuilder()
.header("Authorization", "Bearer $key")
.build()
).execute()
}

coordinated.response.use { response ->
val responseBody = response.body?.string()
if (!response.isSuccessful) {
// Braces are required: "$response.code" would interpolate the whole Response object
// and then append the literal text ".code" instead of the HTTP status code.
Log.e("ScreenCaptureService", "Mistral API Error (${response.code}): $responseBody")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ import com.google.ai.sample.feature.multimodal.ModelDownloadManager
import com.google.ai.sample.ModelOption
import com.google.ai.sample.GenerativeAiViewModelFactory
import com.google.ai.sample.InferenceBackend
import com.google.ai.sample.network.MistralRequestCoordinator
import com.google.ai.sample.feature.multimodal.dtos.toDto
import com.google.ai.sample.feature.multimodal.dtos.TempFilePathCollector
import kotlinx.coroutines.Dispatchers
Expand Down Expand Up @@ -70,8 +71,6 @@ import kotlinx.serialization.modules.subclass
import com.google.ai.sample.webrtc.WebRTCSender
import com.google.ai.sample.webrtc.SignalingClient
import org.webrtc.IceCandidate
import kotlin.math.max
import kotlin.math.roundToLong

class PhotoReasoningViewModel(
application: Application,
Expand Down Expand Up @@ -184,11 +183,14 @@ class PhotoReasoningViewModel(
// to avoid re-executing already-executed commands
private var incrementalCommandCount = 0

// Mistral rate limiting per API key (1.5 seconds between requests with same key)
private val mistralNextAllowedRequestAtMsByKey = mutableMapOf<String, Long>()
private var lastMistralTokenTimeMs = 0L
private var lastMistralTokenKey: String? = null
private val MISTRAL_MIN_INTERVAL_MS = 1500L
/**
 * One pending automatic screenshot follow-up for the Mistral provider.
 *
 * @property bitmap the captured screenshot that is sent as the only image.
 * @property screenshotUri string form of the screenshot URI, forwarded to the chat history.
 * @property screenInfo optional screen description forwarded with the prompt; may be null.
 */
private data class QueuedMistralScreenshotRequest(
    val bitmap: Bitmap,
    val screenshotUri: String,
    val screenInfo: String?,
)
// Monitor guarding the two fields below; both enqueueMistralAutoScreenshotRequest and
// releaseAndDrainMistralAutoScreenshotQueue read/write them only inside synchronized(...).
private val mistralAutoScreenshotQueueLock = Any()
// Set to true when an auto-screenshot request is started with nothing in flight;
// reset to false on release when no request is waiting.
private var mistralAutoScreenshotInFlight = false
// Single-slot queue: a newer request overwrites an older waiting one ("latest wins").
private var queuedMistralScreenshotRequest: QueuedMistralScreenshotRequest? = null

// Accumulated full text during streaming for incremental command parsing
private var streamingAccumulatedText = StringBuilder()
Expand Down Expand Up @@ -1136,129 +1138,17 @@ class PhotoReasoningViewModel(

// Validate that we have at least one key before proceeding
require(availableKeys.isNotEmpty()) { "No valid Mistral API keys available after filtering" }

/**
 * Pushes back the earliest time at which [key] may be used for another request.
 *
 * The cooldown is measured from [referenceTimeMs] and lasts at least
 * MISTRAL_MIN_INTERVAL_MS; a larger [extraDelayMs] extends it. An already
 * recorded later deadline for the same key is never shortened.
 *
 * Replaces the former pair of overloads: calling without [extraDelayMs] behaves
 * exactly like the old two-argument version.
 *
 * @param key API key whose cooldown is updated.
 * @param referenceTimeMs timestamp in milliseconds from which the cooldown is counted.
 * @param extraDelayMs additional requested delay in milliseconds; negative values count as 0.
 */
fun markKeyCooldown(key: String, referenceTimeMs: Long, extraDelayMs: Long = 0L) {
    val cooldownMs = max(MISTRAL_MIN_INTERVAL_MS, extraDelayMs.coerceAtLeast(0L))
    val nextAllowedAt = referenceTimeMs + cooldownMs
    val existing = mistralNextAllowedRequestAtMsByKey[key] ?: 0L
    // Keep the later of the two deadlines so a short cooldown never overrides a long one.
    mistralNextAllowedRequestAtMsByKey[key] = max(existing, nextAllowedAt)
}

/**
 * Returns how many milliseconds remain, relative to [nowMs], until [key] may be
 * used again. Keys without a recorded deadline, or whose deadline has passed, yield 0.
 */
fun remainingWaitForKeyMs(key: String, nowMs: Long): Long {
    val deadlineMs = mistralNextAllowedRequestAtMsByKey[key] ?: 0L
    val remainingMs = deadlineMs - nowMs
    return remainingMs.coerceAtLeast(0L)
}

/**
 * Converts a `Retry-After` header given as a number of seconds into milliseconds.
 *
 * @param headerValue raw header value; may be null or blank.
 * @return the delay in milliseconds (never negative), or null when the header is
 *   absent, blank, not numeric, or not a finite number. The HTTP-date form of the
 *   header is not numeric and therefore also yields null.
 */
fun parseRetryAfterMs(headerValue: String?): Long? {
    val seconds = headerValue?.trim()?.toDoubleOrNull() ?: return null
    // "NaN" and "Infinity" parse as doubles: roundToLong() throws on NaN, and an
    // infinite value would become Long.MAX_VALUE and overflow once added to a timestamp.
    if (!seconds.isFinite()) return null
    return (seconds * 1000.0).roundToLong().coerceAtLeast(0L)
}

/**
 * Reads the `x-ratelimit-reset` header of [response], interprets it as epoch
 * seconds, and returns the non-negative distance in milliseconds from [nowMs]
 * to that instant. Returns null when the header is missing or not an integer.
 */
fun parseRateLimitResetDelayMs(response: okhttp3.Response, nowMs: Long): Long? {
    val resetEpochSeconds = response.header("x-ratelimit-reset")
        ?.trim()
        ?.toLongOrNull()
    return resetEpochSeconds?.let { seconds ->
        (seconds * 1000L - nowMs).coerceAtLeast(0L)
    }
}

/**
 * Exponential back-off in milliseconds for the given number of failures:
 * 1s for the first failure, doubling per further failure, capped at 32s.
 * Counts below 1 are treated like the first failure.
 */
fun adaptiveRetryDelayMs(failureCount: Int): Long {
    val doublings = minOf(maxOf(failureCount - 1, 0), 5)
    return 1000L * (1L shl doublings)
}

/** True for HTTP 429 and for every status code of 500 or above. */
fun isRetryableMistralFailure(code: Int): Boolean =
    when {
        code >= 500 -> true
        else -> code == 429
    }

var response: okhttp3.Response? = null
var selectedKeyForResponse: String? = null
var consecutiveFailures = 0
var blockedKeysThisRound = mutableSetOf<String>()

val maxAttempts = availableKeys.size * 4 + 8
while (response == null && consecutiveFailures < maxAttempts) {
if (stopExecutionFlag.get()) break

val now = System.currentTimeMillis()
val keyPool = availableKeys.filter { it !in blockedKeysThisRound }.ifEmpty {
blockedKeysThisRound.clear()
availableKeys
}

val keyWithLeastWait = keyPool.minByOrNull { remainingWaitForKeyMs(it, now) } ?: availableKeys.first()
val waitMs = remainingWaitForKeyMs(keyWithLeastWait, now)
if (waitMs > 0L) {
delay(waitMs)
}

val selectedKey = keyWithLeastWait
selectedKeyForResponse = selectedKey

try {
val attemptResponse = client.newCall(buildRequest(selectedKey)).execute()
val requestEndMs = System.currentTimeMillis()
val retryAfterMs = parseRetryAfterMs(attemptResponse.header("Retry-After"))
val resetDelayMs = parseRateLimitResetDelayMs(attemptResponse, requestEndMs)
val serverRequestedDelayMs = max(retryAfterMs ?: 0L, resetDelayMs ?: 0L)
markKeyCooldown(selectedKey, requestEndMs, serverRequestedDelayMs)

if (attemptResponse.isSuccessful) {
response = attemptResponse
break
}

val isRetryable = isRetryableMistralFailure(attemptResponse.code)
if (!isRetryable) {
val errBody = attemptResponse.body?.string()
attemptResponse.close()
throw IllegalStateException("Mistral Error ${attemptResponse.code}: $errBody")
}

attemptResponse.close()
blockedKeysThisRound.add(selectedKey)
consecutiveFailures++
val adaptiveDelay = adaptiveRetryDelayMs(consecutiveFailures)
markKeyCooldown(
selectedKey,
requestEndMs,
max(serverRequestedDelayMs, adaptiveDelay)
)
withContext(Dispatchers.Main) {
replaceAiMessageText(
"Mistral temporär nicht verfügbar (Versuch $consecutiveFailures/$maxAttempts). Warte auf Server-Rate-Limit und wiederhole...",
isPending = true
)
}
} catch (e: IOException) {
val requestEndMs = System.currentTimeMillis()
val adaptiveDelay = adaptiveRetryDelayMs(consecutiveFailures + 1)
markKeyCooldown(selectedKey, requestEndMs, adaptiveDelay)
blockedKeysThisRound.add(selectedKey)
consecutiveFailures++
if (consecutiveFailures >= maxAttempts) {
throw IOException("Mistral request failed after $maxAttempts attempts: ${e.message}", e)
}
withContext(Dispatchers.Main) {
replaceAiMessageText(
"Mistral Netzwerkfehler (Versuch $consecutiveFailures/$maxAttempts). Wiederhole...",
isPending = true
)
}
val coordinated = MistralRequestCoordinator.execute(
apiKeys = availableKeys,
maxAttempts = maxAttempts
) { selectedKey ->
if (stopExecutionFlag.get()) {
throw IOException("Mistral request aborted.")
}
client.newCall(buildRequest(selectedKey)).execute()
}

if (stopExecutionFlag.get()) {
throw IOException("Mistral request aborted.")
}

val finalResponse = response ?: throw IOException("Mistral request failed after $maxAttempts attempts.")
val finalResponse = coordinated.response

if (!finalResponse.isSuccessful) {
val errBody = finalResponse.body?.string()
Expand All @@ -1268,27 +1158,12 @@ class PhotoReasoningViewModel(

val body = finalResponse.body ?: throw IOException("Empty response body from Mistral")
val aiResponseText = openAiStreamParser.parse(body) { accText ->
selectedKeyForResponse?.let { key ->
lastMistralTokenKey = key
lastMistralTokenTimeMs = System.currentTimeMillis()
markKeyCooldown(key, lastMistralTokenTimeMs)
} ?: run {
Log.w(TAG, "selectedKeyForResponse is null during streaming callback")
}
withContext(Dispatchers.Main) {
replaceAiMessageText(accText, isPending = true)
processCommandsIncrementally(accText)
}
}
finalResponse.close()
selectedKeyForResponse?.let { key ->
val reference = if (lastMistralTokenKey == key && lastMistralTokenTimeMs > 0L) {
lastMistralTokenTimeMs
} else {
System.currentTimeMillis()
}
markKeyCooldown(key, reference)
}

withContext(Dispatchers.Main) {
_uiState.value = PhotoReasoningUiState.Success(aiResponseText)
Expand All @@ -1306,11 +1181,11 @@ class PhotoReasoningViewModel(
}
} finally {
withContext(Dispatchers.Main) {
releaseAndDrainMistralAutoScreenshotQueue()
refreshStopButtonState()
}
}
}
}

private fun reasonWithPuter(
userInput: String,
Expand Down Expand Up @@ -2404,16 +2279,22 @@ private fun processCommands(text: String) {
_commandExecutionStatus.value = status
}

// Create prompt with screen information if available
val genericAnalysisPrompt = createGenericScreenshotPrompt()

// Re-send the query with only the latest screenshot
reason(
userInput = genericAnalysisPrompt,
selectedImages = listOf(bitmap),
screenInfoForPrompt = screenInfo,
imageUrisForChat = listOf(screenshotUri.toString()) // Add this argument
)
val currentModel = GenerativeAiViewModelFactory.getCurrentModel()
if (currentModel.apiProvider == ApiProvider.MISTRAL) {
enqueueMistralAutoScreenshotRequest(
bitmap = bitmap,
screenshotUri = screenshotUri.toString(),
screenInfo = screenInfo
)
} else {
// Re-send the query with only the latest screenshot
reason(
userInput = createGenericScreenshotPrompt(),
selectedImages = listOf(bitmap),
screenInfoForPrompt = screenInfo,
imageUrisForChat = listOf(screenshotUri.toString())
)
}

PhotoReasoningScreenshotUiNotifier.showAddedToConversation(context)
} else {
Expand All @@ -2436,5 +2317,56 @@ private fun processCommands(text: String) {
}
}
}

/**
 * Starts an automatic screenshot follow-up request, or parks it if one is
 * already running.
 *
 * When no request is in flight, the in-flight flag is set and the request is
 * dispatched immediately (outside the lock). Otherwise the request replaces any
 * previously parked one, so only the most recent screenshot is kept.
 *
 * NOTE(review): the in-flight flag is only cleared by
 * releaseAndDrainMistralAutoScreenshotQueue; confirm every path started by the
 * dispatch eventually calls it.
 */
private fun enqueueMistralAutoScreenshotRequest(
    bitmap: Bitmap,
    screenshotUri: String,
    screenInfo: String?
) {
    val pending = QueuedMistralScreenshotRequest(
        bitmap = bitmap,
        screenshotUri = screenshotUri,
        screenInfo = screenInfo
    )
    val startImmediately = synchronized(mistralAutoScreenshotQueueLock) {
        if (mistralAutoScreenshotInFlight) {
            queuedMistralScreenshotRequest = pending
            Log.d(TAG, "Mistral auto screenshot request queued (latest wins).")
            false
        } else {
            mistralAutoScreenshotInFlight = true
            true
        }
    }
    if (!startImmediately) return
    dispatchMistralAutoScreenshotRequest(pending)
}

/**
 * Sends [request] through reason(): the generic screenshot prompt as user
 * input, the request's bitmap as the only image, and its screen info and URI
 * passed along unchanged.
 */
private fun dispatchMistralAutoScreenshotRequest(request: QueuedMistralScreenshotRequest) {
    val (screenshot, uri, info) = request
    val prompt = createGenericScreenshotPrompt()
    reason(
        userInput = prompt,
        selectedImages = listOf(screenshot),
        screenInfoForPrompt = info,
        imageUrisForChat = listOf(uri)
    )
}

/**
 * Called when a request finishes: if a screenshot request is parked, takes it
 * out of the slot and dispatches it (the in-flight flag stays set); otherwise
 * clears the in-flight flag. The dispatch itself happens outside the lock.
 */
private fun releaseAndDrainMistralAutoScreenshotQueue() {
    val pending = synchronized(mistralAutoScreenshotQueueLock) {
        queuedMistralScreenshotRequest.also { parked ->
            if (parked != null) {
                queuedMistralScreenshotRequest = null
            } else {
                mistralAutoScreenshotInFlight = false
            }
        }
    } ?: return
    Log.d(TAG, "Draining queued Mistral auto screenshot request.")
    dispatchMistralAutoScreenshotRequest(pending)
}

}
Loading
Loading