Skip to content

Commit 6e135a2

Browse files
authored
PTE+PTD support (#193)
- Add ModelConfiguration data class for individual PTE model configs - Add LoRA mode toggle in settings to enable multi-model selection - Add multi-model management UI with add/remove model flows - Add in-chat model switcher button and dialog for switching between loaded models - Support caching loaded LlmModule instances for instant model switching - Add shared data path (PTD) support for LoRA adapters - Maintain backward compatibility with legacy single-model mode
1 parent d624ad7 commit 6e135a2

8 files changed

Lines changed: 1296 additions & 66 deletions

File tree

llm/android/LlamaDemo/app/build.gradle.kts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -268,14 +268,14 @@ dependencies {
268268
if (useLocalAar == true) {
269269
implementation(files("libs/executorch.aar"))
270270
} else {
271-
implementation("org.pytorch:executorch-android:1.0.1")
271+
implementation("org.pytorch:executorch-android:1.1.0")
272272
// https://mvnrepository.com/artifact/org.pytorch/executorch-android-qnn
273273
// Uncomment this to enable QNN
274-
// implementation("org.pytorch:executorch-android-qnn:1.0.1")
274+
// implementation("org.pytorch:executorch-android-qnn:1.1.0")
275275

276276
// https://mvnrepository.com/artifact/org.pytorch/executorch-android-vulkan
277277
// uncomment to enable vulkan
278-
// implementation("org.pytorch:executorch-android-vulkan:1.0.1")
278+
// implementation("org.pytorch:executorch-android-vulkan:1.1.0")
279279
}
280280
implementation("com.google.android.material:material:1.12.0")
281281
implementation("androidx.activity:activity:1.9.0")
Lines changed: 65 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package com.example.executorchllamademo
10+
11+
/**
12+
* Represents a single PTE model configuration.
13+
* Multiple ModelConfigurations can share the same PTD (data path) file for LoRA support.
14+
*/
15+
data class ModelConfiguration(
16+
val id: String = "",
17+
val modelFilePath: String = "",
18+
val tokenizerFilePath: String = "",
19+
val modelType: ModelType = ModelType.LLAMA_3,
20+
val backendType: BackendType = BackendType.XNNPACK,
21+
val temperature: Double = ModuleSettings.DEFAULT_TEMPERATURE,
22+
val displayName: String = "",
23+
val adapterFilePaths: List<String> = emptyList()
24+
) {
25+
companion object {
26+
fun create(
27+
modelFilePath: String,
28+
tokenizerFilePath: String,
29+
modelType: ModelType,
30+
backendType: BackendType,
31+
temperature: Double
32+
): ModelConfiguration {
33+
return ModelConfiguration(
34+
id = generateId(modelFilePath),
35+
modelFilePath = modelFilePath,
36+
tokenizerFilePath = tokenizerFilePath,
37+
modelType = modelType,
38+
backendType = backendType,
39+
temperature = temperature,
40+
displayName = extractDisplayName(modelFilePath)
41+
)
42+
}
43+
44+
private fun generateId(modelFilePath: String): String {
45+
return modelFilePath.hashCode().toString()
46+
}
47+
48+
private fun extractDisplayName(filePath: String): String {
49+
if (filePath.isEmpty()) return ""
50+
return filePath.substringAfterLast('/')
51+
}
52+
}
53+
54+
fun isValid(): Boolean {
55+
return modelFilePath.isNotEmpty() && tokenizerFilePath.isNotEmpty()
56+
}
57+
58+
fun withModelFilePath(path: String): ModelConfiguration {
59+
return copy(
60+
modelFilePath = path,
61+
id = generateId(path),
62+
displayName = extractDisplayName(path)
63+
)
64+
}
65+
}

llm/android/LlamaDemo/app/src/main/java/com/example/executorchllamademo/ModuleSettings.kt

Lines changed: 118 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,8 +10,10 @@ package com.example.executorchllamademo
1010

1111
/**
1212
* Holds module-specific settings for the current model/tokenizer configuration.
13+
* Supports both legacy single-model and multi-model configurations for LoRA support.
1314
*/
1415
data class ModuleSettings(
16+
// Legacy single-model fields (kept for backward compatibility)
1517
val modelFilePath: String = "",
1618
val tokenizerFilePath: String = "",
1719
val dataPath: String = "",
@@ -21,22 +23,135 @@ data class ModuleSettings(
2123
val modelType: ModelType = DEFAULT_MODEL,
2224
val backendType: BackendType = DEFAULT_BACKEND,
2325
val isClearChatHistory: Boolean = false,
24-
val isLoadModel: Boolean = false
26+
val isLoadModel: Boolean = false,
27+
28+
// LoRA mode toggle - when enabled, allows multiple model selection
29+
val isLoraMode: Boolean = false,
30+
31+
// Foundation PTD path - shared base weights for all LoRA models
32+
val foundationDataPath: String = "",
33+
34+
// Multi-model support fields (used when isLoraMode is true)
35+
val models: List<ModelConfiguration> = emptyList(),
36+
val activeModelId: String = "",
37+
val sharedDataPath: String = "",
38+
val foundationModelType: ModelType = ModelType.LLAMA_3
2539
) {
40+
/**
41+
* Gets the effective model type, considering multi-model configuration.
42+
*/
43+
fun getEffectiveModelType(): ModelType {
44+
val activeModel = getActiveModel()
45+
return activeModel?.modelType ?: modelType
46+
}
47+
2648
fun getFormattedSystemPrompt(): String {
27-
return PromptFormat.getSystemPromptTemplate(modelType)
49+
return PromptFormat.getSystemPromptTemplate(getEffectiveModelType())
2850
.replace(PromptFormat.SYSTEM_PLACEHOLDER, systemPrompt)
2951
}
3052

3153
fun getFormattedUserPrompt(prompt: String, thinkingMode: Boolean): String {
54+
val effectiveType = getEffectiveModelType()
3255
return userPrompt
3356
.replace(PromptFormat.USER_PLACEHOLDER, prompt)
3457
.replace(
3558
PromptFormat.THINKING_MODE_PLACEHOLDER,
36-
PromptFormat.getThinkingModeToken(modelType, thinkingMode)
59+
PromptFormat.getThinkingModeToken(effectiveType, thinkingMode)
3760
)
3861
}
3962

63+
/**
64+
* Gets the active model configuration if using multi-model mode.
65+
*/
66+
fun getActiveModel(): ModelConfiguration? {
67+
if (models.isEmpty() || activeModelId.isEmpty()) return null
68+
return models.find { it.id == activeModelId }
69+
}
70+
71+
/**
72+
* Gets a model configuration by ID.
73+
*/
74+
fun getModelById(modelId: String): ModelConfiguration? {
75+
return models.find { it.id == modelId }
76+
}
77+
78+
/**
79+
* Gets the effective foundation data path for LoRA mode.
80+
* Priority: foundationDataPath > sharedDataPath > dataPath
81+
*/
82+
fun getEffectiveDataPath(): String {
83+
return foundationDataPath.ifEmpty { sharedDataPath.ifEmpty { dataPath } }
84+
}
85+
86+
/**
87+
* Checks if there are multiple models configured.
88+
*/
89+
fun hasMultipleModels(): Boolean = models.size > 1
90+
91+
/**
92+
* Checks if any models are configured.
93+
*/
94+
fun hasModels(): Boolean = models.isNotEmpty()
95+
96+
/**
97+
* Adds a model to the list. If a model with the same ID exists, it's replaced.
98+
* If this is the first model, it becomes the active model.
99+
*/
100+
fun addModel(model: ModelConfiguration): ModuleSettings {
101+
val existingIndex = models.indexOfFirst { it.id == model.id }
102+
val newModels = if (existingIndex >= 0) {
103+
models.toMutableList().apply { this[existingIndex] = model }
104+
} else {
105+
models + model
106+
}
107+
val newActiveId = if (models.isEmpty()) model.id else activeModelId
108+
return copy(models = newModels, activeModelId = newActiveId)
109+
}
110+
111+
/**
112+
* Removes a model by ID. If the active model is removed, selects another.
113+
*/
114+
fun removeModel(modelId: String): ModuleSettings {
115+
val newModels = models.filter { it.id != modelId }
116+
val newActiveId = if (activeModelId == modelId) {
117+
newModels.firstOrNull()?.id ?: ""
118+
} else {
119+
activeModelId
120+
}
121+
return copy(models = newModels, activeModelId = newActiveId)
122+
}
123+
124+
/**
125+
* Sets the active model by ID.
126+
*/
127+
fun setActiveModel(modelId: String): ModuleSettings {
128+
return copy(activeModelId = modelId)
129+
}
130+
131+
/**
132+
* Migrates legacy single-model settings to multi-model format.
133+
*/
134+
fun migrateToMultiModel(): ModuleSettings {
135+
if (models.isNotEmpty()) return this
136+
137+
// Only migrate if there's a valid legacy model configuration
138+
if (modelFilePath.isEmpty() || tokenizerFilePath.isEmpty()) return this
139+
140+
val legacyModel = ModelConfiguration.create(
141+
modelFilePath = modelFilePath,
142+
tokenizerFilePath = tokenizerFilePath,
143+
modelType = modelType,
144+
backendType = backendType,
145+
temperature = temperature
146+
)
147+
148+
return copy(
149+
models = listOf(legacyModel),
150+
activeModelId = legacyModel.id,
151+
sharedDataPath = dataPath
152+
)
153+
}
154+
40155
companion object {
41156
const val DEFAULT_TEMPERATURE = 0.0
42157
val DEFAULT_MODEL = ModelType.LLAMA_3
Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,105 @@
1+
/*
2+
* Copyright (c) Meta Platforms, Inc. and affiliates.
3+
* All rights reserved.
4+
*
5+
* This source code is licensed under the BSD-style license found in the
6+
* LICENSE file in the root directory of this source tree.
7+
*/
8+
9+
package com.example.executorchllamademo.ui.components
10+
11+
import androidx.compose.foundation.clickable
12+
import androidx.compose.foundation.layout.Column
13+
import androidx.compose.foundation.layout.Row
14+
import androidx.compose.foundation.layout.Spacer
15+
import androidx.compose.foundation.layout.fillMaxWidth
16+
import androidx.compose.foundation.layout.padding
17+
import androidx.compose.foundation.layout.size
18+
import androidx.compose.foundation.layout.width
19+
import androidx.compose.material.icons.Icons
20+
import androidx.compose.material.icons.filled.Close
21+
import androidx.compose.material3.Icon
22+
import androidx.compose.material3.IconButton
23+
import androidx.compose.material3.RadioButton
24+
import androidx.compose.material3.RadioButtonDefaults
25+
import androidx.compose.material3.Text
26+
import androidx.compose.runtime.Composable
27+
import androidx.compose.ui.Alignment
28+
import androidx.compose.ui.Modifier
29+
import androidx.compose.ui.graphics.Color
30+
import androidx.compose.ui.text.font.FontWeight
31+
import androidx.compose.ui.text.style.TextOverflow
32+
import androidx.compose.ui.unit.dp
33+
import androidx.compose.ui.unit.sp
34+
import com.example.executorchllamademo.ModelConfiguration
35+
import com.example.executorchllamademo.ui.theme.LocalAppColors
36+
37+
/**
38+
* A composable that displays a single model configuration in a list.
39+
* Shows the model name, type, backend, and tokenizer, with a radio button
40+
* for selection and a remove button.
41+
*/
42+
@Composable
43+
fun ModelListItem(
44+
model: ModelConfiguration,
45+
isActive: Boolean,
46+
onSelect: () -> Unit,
47+
onRemove: () -> Unit,
48+
modifier: Modifier = Modifier
49+
) {
50+
val appColors = LocalAppColors.current
51+
52+
Row(
53+
modifier = modifier
54+
.fillMaxWidth()
55+
.clickable(onClick = onSelect)
56+
.padding(horizontal = 12.dp, vertical = 8.dp),
57+
verticalAlignment = Alignment.CenterVertically
58+
) {
59+
RadioButton(
60+
selected = isActive,
61+
onClick = onSelect,
62+
colors = RadioButtonDefaults.colors(
63+
selectedColor = appColors.settingsText,
64+
unselectedColor = appColors.settingsSecondaryText
65+
)
66+
)
67+
68+
Spacer(modifier = Modifier.width(8.dp))
69+
70+
Column(
71+
modifier = Modifier.weight(1f)
72+
) {
73+
// Model name (display name from file path)
74+
Text(
75+
text = model.displayName.ifEmpty { "Unknown Model" },
76+
color = appColors.settingsText,
77+
fontSize = 16.sp,
78+
fontWeight = FontWeight.Bold,
79+
maxLines = 1,
80+
overflow = TextOverflow.Ellipsis
81+
)
82+
83+
// Tokenizer name
84+
val tokenizerName = model.tokenizerFilePath.substringAfterLast('/').ifEmpty { "No tokenizer" }
85+
Text(
86+
text = tokenizerName,
87+
color = appColors.settingsSecondaryText.copy(alpha = 0.7f),
88+
fontSize = 11.sp,
89+
maxLines = 1,
90+
overflow = TextOverflow.Ellipsis
91+
)
92+
}
93+
94+
IconButton(
95+
onClick = onRemove,
96+
modifier = Modifier.size(40.dp)
97+
) {
98+
Icon(
99+
imageVector = Icons.Filled.Close,
100+
contentDescription = "Remove model",
101+
tint = Color(0xFFFF6666)
102+
)
103+
}
104+
}
105+
}

0 commit comments

Comments
 (0)