diff --git a/firebase-ai/app/build.gradle.kts b/firebase-ai/app/build.gradle.kts index 2a5a4efc4..6b47e5086 100644 --- a/firebase-ai/app/build.gradle.kts +++ b/firebase-ai/app/build.gradle.kts @@ -12,7 +12,7 @@ android { defaultConfig { applicationId = "com.google.firebase.quickstart.ai" - minSdk = 23 + minSdk = 26 targetSdk = 36 versionCode = 1 versionName = "1.0" @@ -73,6 +73,7 @@ dependencies { // Firebase implementation(platform(libs.firebase.bom)) implementation(libs.firebase.ai) + implementation(libs.firebase.ai.ondevice) // Image loading implementation(libs.coil.compose) diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/MainActivity.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/MainActivity.kt index 06ee42f8d..51ed6ce39 100644 --- a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/MainActivity.kt +++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/MainActivity.kt @@ -27,6 +27,7 @@ import androidx.navigation.compose.NavHost import androidx.navigation.compose.composable import androidx.navigation.compose.rememberNavController import com.google.firebase.quickstart.ai.feature.live.BidiViewModel +import com.google.firebase.quickstart.ai.feature.hybrid.HybridInferenceViewModel import com.google.firebase.quickstart.ai.feature.media.imagen.ImagenViewModel import com.google.firebase.quickstart.ai.feature.text.ChatViewModel import com.google.firebase.quickstart.ai.feature.text.ServerPromptTemplateViewModel @@ -36,6 +37,7 @@ import com.google.firebase.quickstart.ai.ui.ImagenScreen import com.google.firebase.quickstart.ai.ui.ServerPromptScreen import com.google.firebase.quickstart.ai.ui.StreamRealtimeScreen import com.google.firebase.quickstart.ai.ui.StreamRealtimeVideoScreen +import com.google.firebase.quickstart.ai.ui.HybridInferenceScreen import com.google.firebase.quickstart.ai.ui.SvgScreen import com.google.firebase.quickstart.ai.ui.navigation.FIREBASE_AI_SAMPLES import com.google.firebase.quickstart.ai.ui.navigation.MainMenuScreen @@ -123,6 +125,10 @@ class MainActivity : ComponentActivity() { StreamRealtimeVideoScreen(it) } } + + ScreenType.HYBRID -> { + (vm as? HybridInferenceViewModel)?.let { HybridInferenceScreen(it) } + } } } } diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/hybrid/Expense.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/hybrid/Expense.kt new file mode 100644 index 000000000..60bf5cf32 --- /dev/null +++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/hybrid/Expense.kt @@ -0,0 +1,10 @@ +package com.google.firebase.quickstart.ai.feature.hybrid + +import kotlinx.serialization.Serializable + +@Serializable +data class Expense( + val name: String, + val price: Double, + val inferenceMode: String = "" +) diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/hybrid/HybridInferenceViewModel.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/hybrid/HybridInferenceViewModel.kt new file mode 100644 index 000000000..d0644026c --- /dev/null +++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/feature/hybrid/HybridInferenceViewModel.kt @@ -0,0 +1,153 @@ +package com.google.firebase.quickstart.ai.feature.hybrid + +import android.graphics.Bitmap +import android.util.Log +import androidx.lifecycle.ViewModel +import androidx.lifecycle.viewModelScope +import com.google.firebase.Firebase +import com.google.firebase.ai.InferenceMode +import com.google.firebase.ai.InferenceSource +import com.google.firebase.ai.OnDeviceConfig +import com.google.firebase.ai.ai +import com.google.firebase.ai.ondevice.DownloadStatus +import com.google.firebase.ai.ondevice.FirebaseAIOnDevice +import com.google.firebase.ai.ondevice.OnDeviceModelStatus +import com.google.firebase.ai.type.GenerativeBackend +import com.google.firebase.ai.type.PublicPreviewAPI +import com.google.firebase.ai.type.content +import com.google.firebase.quickstart.ai.ui.HybridInferenceUiState +import kotlinx.coroutines.flow.MutableStateFlow +import kotlinx.coroutines.flow.StateFlow +import kotlinx.coroutines.flow.asStateFlow +import kotlinx.coroutines.flow.update +import kotlinx.coroutines.launch +import kotlinx.serialization.Serializable +import kotlinx.serialization.json.Json +import java.util.UUID + +@Serializable +object HybridInferenceRoute + +@OptIn(PublicPreviewAPI::class) +class HybridInferenceViewModel : ViewModel() { + private val _uiState = MutableStateFlow( + HybridInferenceUiState( + expenses = listOf( + Expense("Lunch", 15.50, "Example data"), + Expense("Coffee", 4.75, "Example data") + ) + ) + ) + val uiState: StateFlow = _uiState.asStateFlow() + + private val model = Firebase.ai(backend = GenerativeBackend.googleAI()).generativeModel( + modelName = "gemini-3.1-flash-lite-preview", + onDeviceConfig = OnDeviceConfig(mode = InferenceMode.PREFER_ON_DEVICE) + ) + + init { + checkAndDownloadModel() + } + + private fun checkAndDownloadModel() { + viewModelScope.launch { + try { + val status = FirebaseAIOnDevice.checkStatus() + updateStatus(status) + + if (status == OnDeviceModelStatus.DOWNLOADABLE) { + FirebaseAIOnDevice.download().collect { downloadStatus -> + when (downloadStatus) { + is DownloadStatus.DownloadStarted -> { + _uiState.update { it.copy(modelStatus = "Downloading model...") } + } + + is DownloadStatus.DownloadInProgress -> { + val progress = downloadStatus.totalBytesDownloaded + _uiState.update { it.copy(modelStatus = "Downloading: $progress bytes downloaded") } + } + + is DownloadStatus.DownloadCompleted -> { + _uiState.update { it.copy(modelStatus = "Model ready") } + } + + is DownloadStatus.DownloadFailed -> { + _uiState.update { + it.copy( + modelStatus = "Download failed", errorMessage = "Model download failed" + ) + } + } + } + } + } + } catch (e: Exception) { + _uiState.update { it.copy(modelStatus = "Error checking status", errorMessage = e.message) } + } + } + } + + private fun updateStatus(status: OnDeviceModelStatus) { + val statusText = when (status) { + OnDeviceModelStatus.AVAILABLE -> "Model available" + OnDeviceModelStatus.DOWNLOADABLE -> "Model downloadable" + OnDeviceModelStatus.DOWNLOADING -> "Model downloading..." + OnDeviceModelStatus.UNAVAILABLE -> "On-device model unavailable" + else -> "Unknown" + } + _uiState.update { it.copy(modelStatus = statusText) } + } + + fun scanReceipt(bitmap: Bitmap) { + viewModelScope.launch { + _uiState.update { it.copy(isScanning = true, errorMessage = null) } + try { + val prompt = content { + image(bitmap) + text( + """ + Extract the store name and the total price from this receipt. + Output only in JSON format containg 2 fields '{name,price}'. + Do not include any currency signs or backticks or any text around it. + Use dots for decimals. + Examples: + - {"name": "FakeStore", "price": "2.0"} + - {"name": "SomeMarket", "price": "3.5"} + """.trimIndent() + ) + } + + val response = model.generateContent(prompt) + val text = response.text + val inferenceMode = if (response.inferenceSource == InferenceSource.ON_DEVICE) { + "On-device" + } else { + "Cloud" + } + Log.d("HybridVM", "$inferenceMode response: $text") + if (text != null) { + parseAndAddExpense(text, inferenceMode) + } else { + _uiState.update { it.copy(errorMessage = "Could not extract data") } + } + } catch (e: Exception) { + _uiState.update { it.copy(errorMessage = "Error: ${e.message}") } + } finally { + _uiState.update { it.copy(isScanning = false) } + } + } + } + + private fun parseAndAddExpense(text: String, inferenceMode: String) { + val json = text + // The on-device model sometimes outputs backticks, so we remove those + .replace("```json", "") + .replace("```", "") + try { + val newExpense = Json.decodeFromString(json).copy(inferenceMode = inferenceMode) + _uiState.update { it.copy(expenses = it.expenses + newExpense) } + } catch (e: Exception) { + _uiState.update { it.copy(errorMessage = e.localizedMessage) } + } + } +} diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/HybridInferenceScreen.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/HybridInferenceScreen.kt new file mode 100644 index 000000000..dd04de643 --- /dev/null +++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/HybridInferenceScreen.kt @@ -0,0 +1,206 @@ +package com.google.firebase.quickstart.ai.ui + +import android.Manifest +import android.content.pm.PackageManager +import androidx.activity.compose.rememberLauncherForActivityResult +import androidx.activity.result.contract.ActivityResultContracts +import androidx.core.content.ContextCompat +import androidx.compose.foundation.layout.Arrangement +import androidx.compose.foundation.layout.Box +import androidx.compose.foundation.layout.Column +import androidx.compose.foundation.layout.Row +import androidx.compose.foundation.layout.Spacer +import androidx.compose.foundation.layout.fillMaxSize +import androidx.compose.foundation.layout.fillMaxWidth +import androidx.compose.foundation.layout.height +import androidx.compose.foundation.layout.padding +import androidx.compose.foundation.layout.size +import androidx.compose.foundation.lazy.LazyColumn +import androidx.compose.foundation.lazy.items +import androidx.compose.material.icons.Icons +import androidx.compose.material.icons.filled.Add +import androidx.compose.material.icons.filled.CameraAlt +import androidx.compose.material.icons.filled.ReceiptLong +import androidx.compose.material3.Card +import androidx.compose.material3.CardDefaults +import androidx.compose.material3.CircularProgressIndicator +import androidx.compose.material3.FloatingActionButton +import androidx.compose.material3.HorizontalDivider +import androidx.compose.material3.Icon +import androidx.compose.material3.MaterialTheme +import androidx.compose.material3.Scaffold +import androidx.compose.material3.Text +import androidx.compose.runtime.Composable +import androidx.compose.runtime.collectAsState +import androidx.compose.runtime.getValue +import androidx.compose.ui.Alignment +import androidx.compose.ui.Modifier +import androidx.compose.ui.graphics.Color +import androidx.compose.ui.platform.LocalContext +import androidx.compose.ui.text.font.FontWeight +import androidx.compose.ui.unit.dp +import androidx.compose.ui.unit.sp +import androidx.lifecycle.viewmodel.compose.viewModel +import com.google.firebase.quickstart.ai.feature.hybrid.HybridInferenceViewModel + +@Composable +fun HybridInferenceScreen( + viewModel: HybridInferenceViewModel = viewModel() +) { + val uiState by viewModel.uiState.collectAsState() + val context = LocalContext.current + + val cameraLauncher = rememberLauncherForActivityResult( + contract = ActivityResultContracts.TakePicturePreview(), + onResult = { bitmap -> + bitmap?.let { viewModel.scanReceipt(it) } + } + ) + + val permissionLauncher = rememberLauncherForActivityResult( + ActivityResultContracts.RequestPermission() + ) { isGranted -> + if (isGranted) { + cameraLauncher.launch(null) + } + } + + Scaffold( + floatingActionButton = { + FloatingActionButton( + onClick = { + val permissionCheckResult = + ContextCompat.checkSelfPermission(context, Manifest.permission.CAMERA) + if (permissionCheckResult == PackageManager.PERMISSION_GRANTED) { + cameraLauncher.launch(null) + } else { + permissionLauncher.launch(Manifest.permission.CAMERA) + } + }, + containerColor = MaterialTheme.colorScheme.primary, + contentColor = MaterialTheme.colorScheme.onPrimary + ) { + if (uiState.isScanning) { + CircularProgressIndicator( + modifier = Modifier.size(24.dp), + color = MaterialTheme.colorScheme.onPrimary, + strokeWidth = 2.dp + ) + } else { + Icon(Icons.Default.CameraAlt, contentDescription = "Scan Receipt") + } + } + } + ) { padding -> + Column( + modifier = Modifier + .fillMaxSize() + .padding(padding) + .padding(16.dp) + ) { + // Model Status Card + Card( + modifier = Modifier.fillMaxWidth(), + colors = CardDefaults.cardColors( + containerColor = MaterialTheme.colorScheme.secondaryContainer + ) + ) { + Row( + modifier = Modifier.padding(12.dp), + verticalAlignment = Alignment.CenterVertically + ) { + Icon( + Icons.Default.ReceiptLong, + contentDescription = null, + tint = MaterialTheme.colorScheme.onSecondaryContainer + ) + Spacer(modifier = Modifier.size(12.dp)) + Column { + Text( + "Hybrid AI Status", + style = MaterialTheme.typography.labelMedium, + color = MaterialTheme.colorScheme.onSecondaryContainer + ) + Text( + uiState.modelStatus, + style = MaterialTheme.typography.bodyMedium, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.onSecondaryContainer + ) + } + } + } + + Spacer(modifier = Modifier.height(16.dp)) + + Text( + "Expenses", + style = MaterialTheme.typography.headlineSmall, + fontWeight = FontWeight.Bold + ) + + Spacer(modifier = Modifier.height(8.dp)) + + if (uiState.expenses.isEmpty()) { + Box(modifier = Modifier.fillMaxSize(), contentAlignment = Alignment.Center) { + Text("No expenses yet. Scan a receipt to add one.", color = Color.Gray) + } + } else { + LazyColumn( + modifier = Modifier.weight(1f), + verticalArrangement = Arrangement.spacedBy(8.dp) + ) { + items(uiState.expenses) { expense -> + ExpenseItem(expense.name, expense.price, expense.inferenceMode) + } + } + } + + if (uiState.errorMessage != null) { + Spacer(modifier = Modifier.height(8.dp)) + Text( + text = uiState.errorMessage!!, + color = MaterialTheme.colorScheme.error, + style = MaterialTheme.typography.bodySmall + ) + } + } + } +} + +@Composable +fun ExpenseItem(name: String, price: Double, inferenceMode: String) { + Card( + modifier = Modifier.fillMaxWidth(), + elevation = CardDefaults.cardElevation(defaultElevation = 2.dp) + ) { + Row( + modifier = Modifier + .padding(16.dp) + .fillMaxWidth(), + horizontalArrangement = Arrangement.SpaceBetween, + verticalAlignment = Alignment.CenterVertically + ) { + Column { + Text( + name, + style = MaterialTheme.typography.bodyLarge, + fontWeight = FontWeight.Medium + ) + if (inferenceMode.isNotEmpty()) { + Text( + inferenceMode, + style = MaterialTheme.typography.labelSmall, + color = Color.Gray + ) + } + } + Text( + "$${String.format("%.2f", price)}", + style = MaterialTheme.typography.bodyLarge, + fontWeight = FontWeight.Bold, + color = MaterialTheme.colorScheme.primary + ) + } + } +} diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/HybridInferenceUiState.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/HybridInferenceUiState.kt new file mode 100644 index 000000000..eb61e6e2a --- /dev/null +++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/HybridInferenceUiState.kt @@ -0,0 +1,10 @@ +package com.google.firebase.quickstart.ai.ui + +import com.google.firebase.quickstart.ai.feature.hybrid.Expense + +data class HybridInferenceUiState( + val expenses: List = emptyList(), + val isScanning: Boolean = false, + val modelStatus: String = "Checking model status...", + val errorMessage: String? = null +) diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/FirebaseAISamples.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/FirebaseAISamples.kt index 8bb2cb153..a72196be2 100644 --- a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/FirebaseAISamples.kt +++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/FirebaseAISamples.kt @@ -2,6 +2,8 @@ package com.google.firebase.quickstart.ai.ui.navigation import com.google.firebase.quickstart.ai.feature.live.StreamAudioViewModel import com.google.firebase.quickstart.ai.feature.live.StreamVideoViewModel +import com.google.firebase.quickstart.ai.feature.hybrid.HybridInferenceRoute +import com.google.firebase.quickstart.ai.feature.hybrid.HybridInferenceViewModel import com.google.firebase.quickstart.ai.feature.live.StreamRealtimeAudioRoute import com.google.firebase.quickstart.ai.feature.live.StreamRealtimeVideoRoute import com.google.firebase.quickstart.ai.feature.media.imagen.ImagenGenerationRoute @@ -239,5 +241,13 @@ val FIREBASE_AI_SAMPLES = listOf( screenType = ScreenType.SVG, viewModelClass = SvgViewModel::class, categories = listOf(Category.IMAGE, Category.TEXT) + ), + Sample( + title = "Hybrid Receipt Scanner", + description = "Use hybrid inference to scan receipts and extract expense data on-device whenever possible.", + route = HybridInferenceRoute, + screenType = ScreenType.HYBRID, + viewModelClass = HybridInferenceViewModel::class, + categories = listOf(Category.TEXT, Category.IMAGE, Category.HYBRID) ) ) diff --git a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/Sample.kt b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/Sample.kt index 76bb0c934..a51b56315 100644 --- a/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/Sample.kt +++ b/firebase-ai/app/src/main/java/com/google/firebase/quickstart/ai/ui/navigation/Sample.kt @@ -12,7 +12,8 @@ enum class Category( AUDIO("Audio"), DOCUMENT("Document"), FUNCTION_CALLING("Function calling"), - LIVE_API("Live API Streaming") + LIVE_API("Live API Streaming"), + HYBRID("Hybrid inference") } enum class ScreenType { @@ -21,7 +22,8 @@ enum class ScreenType { SVG, SERVER_PROMPT, BIDI, - BIDI_VIDEO + BIDI_VIDEO, + HYBRID } data class Sample( diff --git a/gradle/libs.versions.toml b/gradle/libs.versions.toml index 0afc49131..5fc9de341 100644 --- a/gradle/libs.versions.toml +++ b/gradle/libs.versions.toml @@ -7,7 +7,7 @@ composeBom = "2025.12.00" composeNavigation = "2.9.6" coreKtx = "1.17.0" espressoCore = "3.7.0" -firebaseBom = "34.7.0" +firebaseBom = "34.11.0" googleServices = "4.4.4" firebaseCrashlytics = "3.0.6" firebasePerf = "2.0.2" @@ -52,6 +52,7 @@ coil-network-okhttp = { module = "io.coil-kt.coil3:coil-network-okhttp", version coil-svg = { module = "io.coil-kt.coil3:coil-svg", version.ref = "coil3Compose" } compose-navigation = { group = "androidx.navigation", name = "navigation-compose", version.ref = "composeNavigation"} firebase-ai = { module = "com.google.firebase:firebase-ai" } +firebase-ai-ondevice = { module = "com.google.firebase:firebase-ai-ondevice", version = "16.0.0-beta01" } firebase-bom = { module = "com.google.firebase:firebase-bom", version.ref = "firebaseBom" } junit = { group = "junit", name = "junit", version.ref = "junit" } kotlinx-serialization-core = { module = "org.jetbrains.kotlinx:kotlinx-serialization-core", version.ref = "kotlinxSerializationCore" }