-
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathminimal_dpg_ui.py
More file actions
executable file
·381 lines (324 loc) · 13.6 KB
/
minimal_dpg_ui.py
File metadata and controls
executable file
·381 lines (324 loc) · 13.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
#!/usr/bin/env python3
"""
Minimal OneTrainer DPG UI that avoids problematic components
and focuses on core functionality.
"""
import os
import sys
import time
from typing import Dict, List, Callable, Optional, Union

try:
    import dearpygui.dearpygui as dpg
except ImportError:
    # Fatal diagnostics go to stderr so they are not mixed into normal output.
    print("DearPyGui not found. Please install it with: pip install dearpygui", file=sys.stderr)
    sys.exit(1)

# Put this script's own directory at the front of sys.path so the OneTrainer
# `modules` package imported below is resolved relative to the script's
# location rather than the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
if current_dir not in sys.path:
    sys.path.insert(0, current_dir)

# Try to import OneTrainer modules; exit with a hint if they cannot be found.
try:
    from modules.util.enum.ModelType import ModelType, PeftType
    from modules.util.enum.TrainingMethod import TrainingMethod
    from modules.util.config.TrainConfig import TrainConfig
except ImportError as e:
    print(f"Error importing OneTrainer modules: {e}", file=sys.stderr)
    # Lookup is relative to the script's directory (see sys.path above), so
    # the hint points at where the script lives, not at the working directory.
    print("Make sure this script is located in the OneTrainer root directory.", file=sys.stderr)
    sys.exit(1)
class MinimalOneTrainerUI:
    """Minimal OneTrainer UI with DPG that focuses on essential functionality.

    Layout: a sidebar of tab buttons on the left and a content pane on the
    right. Each "tab" is a DPG group parented to the content pane; switching
    tabs hides every group and shows the selected one. Combo boxes and the
    numeric inputs write their values back into ``self.train_config``.
    """

    # DPG tag of each tab group, keyed by the short name kept in `current_tab`.
    _TAB_TAGS = {
        "model": "model_tab",
        "lora": "lora_tab",
        "training": "training_tab",
        "concept": "concept_tab",
        "sampling": "sampling_tab",
        "cloud": "cloud_tab",
    }

    # Tag of the About dialog, used to close it and to avoid opening duplicates.
    _ABOUT_TAG = "about_window"

    def __init__(self):
        """Create the DPG context, viewport, main window and all tab content."""
        dpg.create_context()
        dpg.create_viewport(
            title="OneTrainer Minimal DPG UI",
            width=1100,
            height=800,
            min_width=800,
            min_height=600,
        )

        # Application state
        self.train_config = TrainConfig.default_values()
        self.current_tab = None

        with dpg.window(tag="main_window", label="OneTrainer"):
            with dpg.menu_bar():
                with dpg.menu(label="File"):
                    dpg.add_menu_item(label="Exit", callback=self.on_exit)
                with dpg.menu(label="Help"):
                    dpg.add_menu_item(label="About", callback=self.show_about)

            with dpg.group(horizontal=True):
                # Left sidebar: one button per tab
                with dpg.child_window(width=150, height=-1, tag="tabs_sidebar"):
                    self.add_tab_button("Model", self.show_model_tab)
                    self.add_tab_button("LoRA / LyCORIS", self.show_lora_tab)
                    self.add_tab_button("Training", self.show_training_tab)
                    self.add_tab_button("Concept", self.show_concept_tab)
                    self.add_tab_button("Sampling", self.show_sampling_tab)
                    self.add_tab_button("Cloud", self.show_cloud_tab)

                # Right content pane that hosts every tab's group
                self.tab_content = dpg.add_child_window(width=-1, height=-1, tag="tab_content")

        # Each creator passes parent=self.tab_content explicitly, so it works
        # outside the window's context manager.
        self.create_model_tab()
        self.create_lora_tab()
        self.create_training_tab()
        self.create_concept_tab()
        self.create_sampling_tab()
        self.create_cloud_tab()

        # Start with only the Model tab visible
        self.hide_all_tabs()
        self.show_model_tab()

        dpg.setup_dearpygui()
        dpg.set_primary_window("main_window", True)

    def run(self):
        """Show the viewport and render frames until DPG stops running.

        The DPG context is destroyed on exit, including when rendering raises.
        """
        try:
            dpg.show_viewport()
            while dpg.is_dearpygui_running():
                dpg.render_dearpygui_frame()
                time.sleep(0.01)
        finally:
            dpg.destroy_context()

    def add_tab_button(self, label, callback):
        """Add a full-width sidebar button and return its DPG item id."""
        button_id = dpg.add_button(
            label=label,
            width=-1,
            height=30,
            callback=callback,
        )
        return button_id

    def hide_all_tabs(self):
        """Hide every item in slot 1 of the content pane (the tab groups)."""
        for child in dpg.get_item_children(self.tab_content, slot=1) or []:
            dpg.hide_item(child)

    def on_exit(self):
        """Handle File > Exit by stopping the DPG render loop."""
        dpg.stop_dearpygui()

    def show_about(self):
        """Open the modal About dialog, replacing any copy already open."""
        tag = self._ABOUT_TAG
        if dpg.does_item_exist(tag):
            dpg.delete_item(tag)
        with dpg.window(label="About OneTrainer", modal=True,
                        width=400, height=200, pos=(200, 200), tag=tag):
            dpg.add_text("OneTrainer Minimal DPG UI")
            dpg.add_text("A lightweight implementation of OneTrainer with DearPyGui")
            dpg.add_separator()
            # Close the dialog by its tag.
            dpg.add_button(label="Close", callback=lambda: dpg.delete_item(tag))

    # ------------------------------------------------------------------
    # Helpers
    # ------------------------------------------------------------------
    @staticmethod
    def _enum_default_name(current, names, fallback):
        """Choose the initial selection for a combo backed by an enum field.

        Args:
            current: The config's current value (an enum member, ``None``, ...).
            names: The combo's items (enum member names).
            fallback: Preferred selection when ``current`` is not usable.

        Returns:
            ``current.name`` if it is one of ``names``; otherwise ``fallback``
            if it is one of ``names`` (or ``names`` is empty); otherwise the
            first item, so the combo never starts on a non-existent choice.
        """
        name = getattr(current, "name", None)
        if name in names:
            return name
        if fallback in names or not names:
            return fallback
        return names[0]

    def _config_value(self, attr, default):
        """Return ``train_config.<attr>``, or *default* when it is missing or None."""
        value = getattr(self.train_config, attr, None)
        return default if value is None else value

    def _config_setter(self, attr):
        """Build a widget callback that copies its value into ``train_config.<attr>``.

        The attribute is only written if the config already has it, so the UI
        never invents fields the config object does not define.
        """
        def _callback(sender, value):
            if hasattr(self.train_config, attr):
                setattr(self.train_config, attr, value)
        return _callback

    def _set_enum_field(self, enum_cls, attr, value, label):
        """Store ``enum_cls[value]`` on ``train_config.<attr>``.

        Returns:
            The enum member, or ``None`` if *value* is not a member name or the
            assignment fails (the error is printed).
        """
        try:
            member = enum_cls[value]
            setattr(self.train_config, attr, member)
        except (KeyError, AttributeError) as e:
            print(f"Error setting {label}: {e}")
            return None
        return member

    @staticmethod
    def _update_lokr_warning(peft_type):
        """Show the LoKr/HiDream warning only while LoKr is the selected PEFT type."""
        # Compare by name so this works even if PeftType has no LOKR member.
        if getattr(peft_type, "name", None) == "LOKR":
            dpg.show_item("lokr_warning")
        else:
            dpg.hide_item("lokr_warning")

    def _create_placeholder_tab(self, tag, title, message):
        """Create a tab group holding only a heading, a separator and a notice."""
        with dpg.group(parent=self.tab_content, tag=tag):
            dpg.add_text(title, color=(255, 255, 0))
            dpg.add_separator()
            dpg.add_text(message)

    # ------------------------------------------------------------------
    # Tab content creators
    # ------------------------------------------------------------------
    def create_model_tab(self):
        """Create the Model tab: model type, training method and base model."""
        with dpg.group(parent=self.tab_content, tag="model_tab"):
            dpg.add_text("Model Settings", color=(255, 255, 0))
            dpg.add_separator()

            # Model type
            with dpg.group(horizontal=True):
                dpg.add_text("Model Type:", indent=20)
                model_types = [model_type.name for model_type in ModelType]
                dpg.add_combo(
                    items=model_types,
                    default_value=self._enum_default_name(
                        getattr(self.train_config, "model_type", None), model_types, "SD"),
                    width=200,
                    callback=self.on_model_type_change,
                )

            # Training method
            with dpg.group(horizontal=True):
                dpg.add_text("Training Method:", indent=20)
                training_methods = [method.name for method in TrainingMethod]
                dpg.add_combo(
                    items=training_methods,
                    default_value=self._enum_default_name(
                        getattr(self.train_config, "training_method", None), training_methods, "LORA"),
                    width=200,
                    callback=self.on_training_method_change,
                )

            # Base model
            # TODO: the path field is not written to train_config and Browse has no callback yet.
            with dpg.group(horizontal=True):
                dpg.add_text("Base Model:", indent=20)
                dpg.add_input_text(width=300)
                dpg.add_button(label="Browse", width=80)

    def create_lora_tab(self):
        """Create the LoRA / LyCORIS tab: PEFT type, rank, alpha and LoKr warning."""
        with dpg.group(parent=self.tab_content, tag="lora_tab"):
            dpg.add_text("LoRA / LyCORIS Settings", color=(255, 255, 0))
            dpg.add_separator()

            # PEFT type
            with dpg.group(horizontal=True):
                dpg.add_text("PEFT Type:", indent=20)
                peft_types = [peft_type.name for peft_type in PeftType]
                dpg.add_combo(
                    items=peft_types,
                    default_value=self._enum_default_name(
                        getattr(self.train_config, "peft_type", None), peft_types, "LORA"),
                    width=200,
                    callback=self.on_peft_type_change,
                )

            # LoRA rank (min/max are only enforced when *_clamped is set)
            with dpg.group(horizontal=True):
                dpg.add_text("LoRA Rank:", indent=20)
                dpg.add_input_int(
                    default_value=self._config_value("lora_rank", 32),
                    min_value=1,
                    max_value=128,
                    min_clamped=True,
                    max_clamped=True,
                    width=100,
                    callback=self._config_setter("lora_rank"),
                )

            # LoRA alpha
            with dpg.group(horizontal=True):
                dpg.add_text("LoRA Alpha:", indent=20)
                dpg.add_input_float(
                    default_value=self._config_value("lora_alpha", 32.0),
                    min_value=1.0,
                    max_value=128.0,
                    min_clamped=True,
                    max_clamped=True,
                    width=100,
                    format="%.1f",
                    callback=self._config_setter("lora_alpha"),
                )

            # Warning for LoKr
            with dpg.group(tag="lokr_warning"):
                with dpg.drawlist(width=500, height=80):
                    dpg.draw_rectangle(
                        (0, 0), (500, 80),
                        fill=(255, 230, 230, 200),
                        color=(255, 0, 0, 255),
                        thickness=2,
                    )
                    with dpg.draw_layer():
                        # NOTE(review): DPG's default font may lack a glyph for the
                        # warning emoji; confirm it renders or load a font that has it.
                        dpg.draw_text(
                            (10, 10),
                            "⚠️ WARNING FOR HIDREAM MODELS ⚠️\n"
                            "LoKr is known to cause sampling issues with HiDream models.\n"
                            "It may lead to infinite sampling loops, freezing, or crashes.\n"
                            "For HiDream models, use standard LoRA, LoHa, or other PEFT types.",
                            color=(180, 0, 0, 255),
                            size=15,
                        )

            # Visible only if the config already starts out with LoKr selected
            self._update_lokr_warning(getattr(self.train_config, "peft_type", None))

    def create_training_tab(self):
        """Create the Training tab: batch size, epochs and learning rate."""
        with dpg.group(parent=self.tab_content, tag="training_tab"):
            dpg.add_text("Training Settings", color=(255, 255, 0))
            dpg.add_separator()

            dpg.add_text("Basic Settings", color=(220, 220, 220))

            # Batch size
            with dpg.group(horizontal=True):
                dpg.add_text("Batch Size:", indent=20)
                dpg.add_input_int(
                    default_value=self._config_value("batch_size", 1),
                    min_value=1,
                    max_value=64,
                    min_clamped=True,
                    max_clamped=True,
                    width=100,
                    callback=self._config_setter("batch_size"),
                )

            # Epochs
            with dpg.group(horizontal=True):
                dpg.add_text("Epochs:", indent=20)
                dpg.add_input_int(
                    default_value=self._config_value("epochs", 10),
                    min_value=1,
                    max_value=1000,
                    min_clamped=True,
                    max_clamped=True,
                    width=100,
                    callback=self._config_setter("epochs"),
                )

            # Learning rate
            with dpg.group(horizontal=True):
                dpg.add_text("Learning Rate:", indent=20)
                dpg.add_input_float(
                    default_value=self._config_value("learning_rate", 1e-5),
                    min_value=1e-8,
                    max_value=1.0,
                    min_clamped=True,
                    max_clamped=True,
                    width=100,
                    format="%.6f",
                    callback=self._config_setter("learning_rate"),
                )

    def create_concept_tab(self):
        """Create the placeholder Concept tab."""
        self._create_placeholder_tab(
            "concept_tab",
            "Concept Settings",
            "This is a simplified concept tab. Full functionality available in upcoming updates.",
        )

    def create_sampling_tab(self):
        """Create the placeholder Sampling tab."""
        self._create_placeholder_tab(
            "sampling_tab",
            "Sampling Settings",
            "This is a simplified sampling tab. Full functionality available in upcoming updates.",
        )

    def create_cloud_tab(self):
        """Create the placeholder Cloud tab."""
        self._create_placeholder_tab(
            "cloud_tab",
            "Cloud Training Settings",
            "This is a simplified cloud tab. Full functionality available in upcoming updates.",
        )

    # ------------------------------------------------------------------
    # Tab switching
    # ------------------------------------------------------------------
    def _show_tab(self, key):
        """Hide all tab groups, show the one registered under *key*, record it."""
        self.hide_all_tabs()
        dpg.show_item(self._TAB_TAGS[key])
        self.current_tab = key

    def show_model_tab(self):
        """Show the Model tab."""
        self._show_tab("model")

    def show_lora_tab(self):
        """Show the LoRA / LyCORIS tab."""
        self._show_tab("lora")

    def show_training_tab(self):
        """Show the Training tab."""
        self._show_tab("training")

    def show_concept_tab(self):
        """Show the Concept tab."""
        self._show_tab("concept")

    def show_sampling_tab(self):
        """Show the Sampling tab."""
        self._show_tab("sampling")

    def show_cloud_tab(self):
        """Show the Cloud tab."""
        self._show_tab("cloud")

    # ------------------------------------------------------------------
    # Combo callbacks (DPG passes sender and the selected item text)
    # ------------------------------------------------------------------
    def on_model_type_change(self, sender, value):
        """Store the selected ModelType on the config."""
        self._set_enum_field(ModelType, "model_type", value, "model type")

    def on_training_method_change(self, sender, value):
        """Store the selected TrainingMethod on the config."""
        self._set_enum_field(TrainingMethod, "training_method", value, "training method")

    def on_peft_type_change(self, sender, value):
        """Store the selected PeftType and toggle the LoKr warning accordingly."""
        peft_type = self._set_enum_field(PeftType, "peft_type", value, "PEFT type")
        if peft_type is not None:
            self._update_lokr_warning(peft_type)
def main():
    """Build and run the UI.

    Returns:
        Process exit status: 0 on a normal exit, 1 if construction or the
        render loop raised. The error message and its traceback are both
        written to stderr so they stay together when output is redirected.
    """
    try:
        app = MinimalOneTrainerUI()
        app.run()
    except Exception as e:
        # Top-level boundary: report any failure and turn it into exit code 1.
        import traceback
        print(f"Error running application: {e}", file=sys.stderr)
        traceback.print_exc()
        return 1
    return 0


if __name__ == "__main__":
    sys.exit(main())