Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion firebase-ai/app/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ android {

defaultConfig {
applicationId = "com.google.firebase.quickstart.ai"
minSdk = 23
minSdk = 26
targetSdk = 36
versionCode = 1
versionName = "1.0"
Expand Down Expand Up @@ -73,6 +73,7 @@ dependencies {
// Firebase
implementation(platform(libs.firebase.bom))
implementation(libs.firebase.ai)
implementation(libs.firebase.ai.ondevice)

// Image loading
implementation(libs.coil.compose)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import androidx.navigation.compose.NavHost
import androidx.navigation.compose.composable
import androidx.navigation.compose.rememberNavController
import com.google.firebase.quickstart.ai.feature.live.BidiViewModel
import com.google.firebase.quickstart.ai.feature.hybrid.HybridInferenceViewModel
import com.google.firebase.quickstart.ai.feature.media.imagen.ImagenViewModel
import com.google.firebase.quickstart.ai.feature.text.ChatViewModel
import com.google.firebase.quickstart.ai.feature.text.ServerPromptTemplateViewModel
Expand All @@ -36,6 +37,7 @@ import com.google.firebase.quickstart.ai.ui.ImagenScreen
import com.google.firebase.quickstart.ai.ui.ServerPromptScreen
import com.google.firebase.quickstart.ai.ui.StreamRealtimeScreen
import com.google.firebase.quickstart.ai.ui.StreamRealtimeVideoScreen
import com.google.firebase.quickstart.ai.ui.HybridInferenceScreen
import com.google.firebase.quickstart.ai.ui.SvgScreen
import com.google.firebase.quickstart.ai.ui.navigation.FIREBASE_AI_SAMPLES
import com.google.firebase.quickstart.ai.ui.navigation.MainMenuScreen
Expand Down Expand Up @@ -123,6 +125,10 @@ class MainActivity : ComponentActivity() {
StreamRealtimeVideoScreen(it)
}
}

ScreenType.HYBRID -> {
(vm as? HybridInferenceViewModel)?.let { HybridInferenceScreen(it) }
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,10 @@
package com.google.firebase.quickstart.ai.feature.hybrid

import kotlinx.serialization.Serializable

/**
 * One expense row displayed by the hybrid-inference sample.
 *
 * Instances are decoded from the model's JSON output, which carries only
 * `name` and `price`; [inferenceMode] is filled in afterwards via `copy()`
 * to record which backend produced the result.
 */
@Serializable
data class Expense(
    // Store name extracted from the receipt.
    val name: String,
    // Total price; the prompt asks the model for a dot-decimal number.
    val price: Double,
    // Human-readable source of the value (e.g. "On-device" / "Cloud");
    // defaults to "" because it is not part of the model's JSON payload.
    val inferenceMode: String = ""
)
Original file line number Diff line number Diff line change
@@ -0,0 +1,153 @@
package com.google.firebase.quickstart.ai.feature.hybrid

import android.graphics.Bitmap
import android.util.Log
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.google.firebase.Firebase
import com.google.firebase.ai.InferenceMode
import com.google.firebase.ai.InferenceSource
import com.google.firebase.ai.OnDeviceConfig
import com.google.firebase.ai.ai
import com.google.firebase.ai.ondevice.DownloadStatus
import com.google.firebase.ai.ondevice.FirebaseAIOnDevice
import com.google.firebase.ai.ondevice.OnDeviceModelStatus
import com.google.firebase.ai.type.GenerativeBackend
import com.google.firebase.ai.type.PublicPreviewAPI
import com.google.firebase.ai.type.content
import com.google.firebase.quickstart.ai.ui.HybridInferenceUiState
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.flow.MutableStateFlow
import kotlinx.coroutines.flow.StateFlow
import kotlinx.coroutines.flow.asStateFlow
import kotlinx.coroutines.flow.update
import kotlinx.coroutines.launch
import kotlinx.serialization.Serializable
import kotlinx.serialization.json.Json
import java.util.UUID

// Serializable route marker for the hybrid-inference destination
// (type-safe Navigation Compose route; carries no arguments).
@Serializable
object HybridInferenceRoute

@OptIn(PublicPreviewAPI::class)
/**
 * ViewModel for the hybrid (on-device / cloud) inference sample.
 *
 * On creation it checks whether the on-device model is available and, if
 * needed, downloads it, reporting progress through [uiState]. [scanReceipt]
 * sends a receipt image to a model configured with
 * [InferenceMode.PREFER_ON_DEVICE], parses the JSON reply into an [Expense],
 * and appends it to the list, tagging each row with the backend that served it.
 */
class HybridInferenceViewModel : ViewModel() {
    private val _uiState = MutableStateFlow(
        HybridInferenceUiState(
            expenses = listOf(
                Expense("Lunch", 15.50, "Example data"),
                Expense("Coffee", 4.75, "Example data")
            )
        )
    )
    // Read-only view of the mutable state above.
    val uiState: StateFlow<HybridInferenceUiState> = _uiState.asStateFlow()

    // Lenient: the prompt's examples show the price as a quoted number
    // ({"price": "2.0"}), which strict Json would reject when decoding into
    // Expense.price: Double. Unknown keys are ignored in case the model adds
    // fields beyond {name, price}.
    private val json = Json {
        isLenient = true
        ignoreUnknownKeys = true
    }

    // Hybrid model: prefers the on-device model, falling back to the cloud backend.
    private val model = Firebase.ai(backend = GenerativeBackend.googleAI()).generativeModel(
        modelName = "gemini-3.1-flash-lite-preview",
        onDeviceConfig = OnDeviceConfig(mode = InferenceMode.PREFER_ON_DEVICE)
    )

    init {
        checkAndDownloadModel()
    }

    /**
     * Queries the on-device model status and, when the model is merely
     * downloadable, starts the download, mirroring every [DownloadStatus]
     * transition into [uiState]. Errors are surfaced as an error message;
     * coroutine cancellation is always propagated.
     */
    private fun checkAndDownloadModel() {
        viewModelScope.launch {
            try {
                val status = FirebaseAIOnDevice.checkStatus()
                updateStatus(status)

                if (status == OnDeviceModelStatus.DOWNLOADABLE) {
                    FirebaseAIOnDevice.download().collect { downloadStatus ->
                        when (downloadStatus) {
                            is DownloadStatus.DownloadStarted ->
                                _uiState.update { it.copy(modelStatus = "Downloading model...") }

                            is DownloadStatus.DownloadInProgress -> {
                                val progress = downloadStatus.totalBytesDownloaded
                                _uiState.update { it.copy(modelStatus = "Downloading: $progress bytes downloaded") }
                            }

                            is DownloadStatus.DownloadCompleted ->
                                _uiState.update { it.copy(modelStatus = "Model ready") }

                            is DownloadStatus.DownloadFailed ->
                                _uiState.update {
                                    it.copy(
                                        modelStatus = "Download failed", errorMessage = "Model download failed"
                                    )
                                }
                        }
                    }
                }
            } catch (e: CancellationException) {
                throw e // never swallow coroutine cancellation
            } catch (e: Exception) {
                _uiState.update { it.copy(modelStatus = "Error checking status", errorMessage = e.message) }
            }
        }
    }

    // Maps an OnDeviceModelStatus to the display string shown in the UI.
    private fun updateStatus(status: OnDeviceModelStatus) {
        val statusText = when (status) {
            OnDeviceModelStatus.AVAILABLE -> "Model available"
            OnDeviceModelStatus.DOWNLOADABLE -> "Model downloadable"
            OnDeviceModelStatus.DOWNLOADING -> "Model downloading..."
            OnDeviceModelStatus.UNAVAILABLE -> "On-device model unavailable"
            else -> "Unknown" // future-proof against new status values
        }
        _uiState.update { it.copy(modelStatus = statusText) }
    }

    /**
     * Sends [bitmap] (a photo of a receipt) to the model, asks for a
     * `{name, price}` JSON object, and on success appends the parsed
     * [Expense] to [uiState], tagged with whether inference ran on-device
     * or in the cloud. `isScanning` is set for the duration of the call.
     */
    fun scanReceipt(bitmap: Bitmap) {
        viewModelScope.launch {
            _uiState.update { it.copy(isScanning = true, errorMessage = null) }
            try {
                val prompt = content {
                    image(bitmap)
                    text(
                        """
                        Extract the store name and the total price from this receipt.
                        Output only in JSON format containing 2 fields '{name,price}'.
                        Do not include any currency signs or backticks or any text around it.
                        Use dots for decimals.
                        Examples:
                        - {"name": "FakeStore", "price": "2.0"}
                        - {"name": "SomeMarket", "price": "3.5"}
                        """.trimIndent()
                    )
                }

                val response = model.generateContent(prompt)
                val text = response.text
                val inferenceMode = if (response.inferenceSource == InferenceSource.ON_DEVICE) {
                    "On-device"
                } else {
                    "Cloud"
                }
                Log.d("HybridVM", "$inferenceMode response: $text")
                if (text != null) {
                    parseAndAddExpense(text, inferenceMode)
                } else {
                    _uiState.update { it.copy(errorMessage = "Could not extract data") }
                }
            } catch (e: CancellationException) {
                throw e // never swallow coroutine cancellation
            } catch (e: Exception) {
                _uiState.update { it.copy(errorMessage = "Error: ${e.message}") }
            } finally {
                _uiState.update { it.copy(isScanning = false) }
            }
        }
    }

    /**
     * Strips markdown code fences from [text], decodes the remainder into an
     * [Expense] tagged with [inferenceMode], and appends it to the expense
     * list. Decode failures are reported through `errorMessage`.
     */
    private fun parseAndAddExpense(text: String, inferenceMode: String) {
        // The on-device model sometimes wraps its output in ``` fences; remove them.
        val cleaned = text
            .replace("```json", "")
            .replace("```", "")
            .trim()
        try {
            val newExpense = json.decodeFromString<Expense>(cleaned).copy(inferenceMode = inferenceMode)
            _uiState.update { it.copy(expenses = it.expenses + newExpense) }
        } catch (e: Exception) {
            _uiState.update { it.copy(errorMessage = e.localizedMessage) }
        }
    }
}
Loading
Loading