Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
82 changes: 80 additions & 2 deletions src/common/ai_models.c
Original file line number Diff line number Diff line change
Expand Up @@ -1249,6 +1249,49 @@ static gboolean _extract_zip(const char *zippath,
return success;
}

// peek at the first archive entry to recover the model_id (the top
// level directory name in the zip layout)
static char *_zip_top_dir(const char *zippath)
{
struct archive *a = archive_read_new();
archive_read_support_format_zip(a);
if(archive_read_open_filename(a, zippath, 10240) != ARCHIVE_OK)
{
archive_read_free(a);
return NULL;
}
struct archive_entry *entry;
char *result = NULL;
if(archive_read_next_header(a, &entry) == ARCHIVE_OK)
{
const char *path = archive_entry_pathname(entry);
const char *slash = path ? strchr(path, '/') : NULL;
if(slash) result = g_strndup(path, slash - path);
else if(path) result = g_strdup(path);
}
archive_read_close(a);
archive_read_free(a);
return result;
}

// activate `model_id` only when nothing is active for its task
static void _activate_if_unset(dt_ai_registry_t *registry,
const char *model_id)
{
if(!registry || !model_id) return;
gchar *task = NULL;
g_mutex_lock(&registry->lock);
const dt_ai_model_t *m = _find_model_unlocked(registry, model_id);
if(m && m->task) task = g_strdup(m->task);
g_mutex_unlock(&registry->lock);
if(!task) return;
char *current = dt_ai_models_get_active_for_task(task);
if(!current || !current[0])
dt_ai_models_set_active_for_task(task, model_id);
g_free(current);
g_free(task);
}

// install a local .dtmodel file (zip archive) into the models directory.
// returns error message (caller must free) or NULL on success.
char *dt_ai_models_install_local(dt_ai_registry_t *registry,
Expand All @@ -1260,17 +1303,43 @@ char *dt_ai_models_install_local(dt_ai_registry_t *registry,
if(!g_file_test(filepath, G_FILE_TEST_IS_REGULAR))
return g_strdup_printf(_("file not found: %s"), filepath);

char *installed_id = _zip_top_dir(filepath);

if(!_extract_zip(filepath, registry->models_dir))
{
g_free(installed_id);
return g_strdup(_("failed to extract model archive"));
}

// rescan models directory to pick up newly installed model
dt_ai_models_refresh_status(registry);

_activate_if_unset(registry, installed_id);

dt_print(DT_DEBUG_AI, "[ai_models] model installed from: %s", filepath);

g_free(installed_id);
return NULL; // success
}

// best installed model for `task`: default preferred, else first found.
// caller must hold registry->lock
static const char *_pick_fallback_active_unlocked(dt_ai_registry_t *registry,
const char *task)
{
if(!registry || !task) return NULL;
const char *first_installed = NULL;
for(GList *l = registry->models; l; l = g_list_next(l))
{
const dt_ai_model_t *m = (const dt_ai_model_t *)l->data;
if(!m->task || strcmp(m->task, task) != 0) continue;
if(m->status != DT_AI_MODEL_DOWNLOADED) continue;
if(m->is_default) return m->id;
if(!first_installed) first_installed = m->id;
}
return first_installed;
}

#ifdef HAVE_AI_DOWNLOAD
// synchronous download - returns error message or NULL on success
char *dt_ai_models_download_sync(dt_ai_registry_t *registry,
Expand Down Expand Up @@ -1537,6 +1606,8 @@ char *dt_ai_models_download_sync(dt_ai_registry_t *registry,
}
g_mutex_unlock(&registry->lock);

_activate_if_unset(registry, model_id);

dt_print(DT_DEBUG_AI, "[ai_models] download complete: %s", model_id);

// final callback
Expand Down Expand Up @@ -1687,12 +1758,19 @@ gboolean dt_ai_models_delete(dt_ai_registry_t *registry, const char *model_id)
}
g_mutex_unlock(&registry->lock);

// clear active status if this was the active model for its task
// if deleted model was active, pick a fallback (default preferred)
if(task_copy)
{
char *active = dt_ai_models_get_active_for_task(task_copy);
if(active && strcmp(active, model_id) == 0)
dt_ai_models_set_active_for_task(task_copy, NULL);
{
g_mutex_lock(&registry->lock);
const char *fallback = _pick_fallback_active_unlocked(registry, task_copy);
char *fallback_copy = fallback ? g_strdup(fallback) : NULL;
g_mutex_unlock(&registry->lock);
dt_ai_models_set_active_for_task(task_copy, fallback_copy);
g_free(fallback_copy);
}
g_free(active);
g_free(task_copy);
}
Expand Down
166 changes: 100 additions & 66 deletions src/gui/preferences_ai.c
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,6 @@ typedef struct dt_prefs_ai_data_t
#ifdef HAVE_AI_DOWNLOAD
GtkWidget *download_selected_btn;
GtkWidget *download_default_btn;
GtkWidget *download_all_btn;
#endif
GtkWidget *install_btn;
GtkWidget *delete_selected_btn;
Expand Down Expand Up @@ -172,6 +171,27 @@ static const char *_status_to_string(dt_ai_model_status_t status)
}
}

#ifdef HAVE_AI_DOWNLOAD
// enable "download / update selected" only when at least one row is ticked
static void _update_download_selected_sensitivity(dt_prefs_ai_data_t *data)
{
if(!data->download_selected_btn) return;
gboolean any = FALSE;
GtkTreeIter iter;
gboolean valid
= gtk_tree_model_get_iter_first(GTK_TREE_MODEL(data->model_store), &iter);
while(valid)
{
gboolean sel = FALSE;
gtk_tree_model_get(GTK_TREE_MODEL(data->model_store), &iter,
COL_SELECTED, &sel, -1);
if(sel) { any = TRUE; break; }
valid = gtk_tree_model_iter_next(GTK_TREE_MODEL(data->model_store), &iter);
}
gtk_widget_set_sensitive(data->download_selected_btn, any);
}
#endif

static void _refresh_model_list(dt_prefs_ai_data_t *data)
{
if(!darktable.ai_registry)
Expand Down Expand Up @@ -246,6 +266,10 @@ static void _refresh_model_list(dt_prefs_ai_data_t *data)
// reset select-all toggle
if(data->select_all_toggle)
gtk_toggle_button_set_active(GTK_TOGGLE_BUTTON(data->select_all_toggle), FALSE);

#ifdef HAVE_AI_DOWNLOAD
_update_download_selected_sensitivity(data);
#endif
}

static void _update_controls_sensitivity(dt_prefs_ai_data_t *data, gboolean enabled)
Expand Down Expand Up @@ -626,6 +650,10 @@ static void _on_model_selection_toggled(GtkCellRendererToggle *cell,

// toggle the value
gtk_list_store_set(data->model_store, &iter, COL_SELECTED, !selected, -1);

#ifdef HAVE_AI_DOWNLOAD
_update_download_selected_sensitivity(data);
#endif
}

static void _on_enabled_toggled(GtkCellRendererToggle *cell,
Expand Down Expand Up @@ -684,6 +712,10 @@ static void _on_select_all_toggled(GtkToggleButton *toggle, gpointer user_data)
gtk_list_store_set(data->model_store, &iter, COL_SELECTED, select_all, -1);
valid = gtk_tree_model_iter_next(GTK_TREE_MODEL(data->model_store), &iter);
}

#ifdef HAVE_AI_DOWNLOAD
_update_download_selected_sensitivity(data);
#endif
}

static void _on_select_all_header_clicked(GtkWidget *button, gpointer user_data)
Expand Down Expand Up @@ -955,46 +987,20 @@ static void _on_download_default(GtkButton *button, gpointer user_data)
_refresh_model_list(data);
}

static void _on_download_all(GtkButton *button, gpointer user_data)
{
dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data;

// download all models that need downloading
const int count = dt_ai_models_get_count(darktable.ai_registry);
for(int i = 0; i < count; i++)
{
dt_ai_model_t *model = dt_ai_models_get_by_index(darktable.ai_registry, i);
if(!model)
continue;
gboolean need_download = (model->status == DT_AI_MODEL_NOT_DOWNLOADED
|| model->status == DT_AI_MODEL_UPDATE_AVAILABLE
|| model->status == DT_AI_MODEL_UPDATE_REQUIRED);
char *id = need_download ? g_strdup(model->id) : NULL;
dt_ai_model_free(model);
if(need_download)
{
if(!_download_model_with_dialog(data, id))
{
g_free(id);
break; // stop on error or cancel
}
g_free(id);
}
}
_refresh_model_list(data);
}
#endif // HAVE_AI_DOWNLOAD

static void _on_install_model(GtkButton *button, gpointer user_data)
{
dt_prefs_ai_data_t *data = (dt_prefs_ai_data_t *)user_data;

GtkFileChooserNative *filechooser = gtk_file_chooser_native_new(
_("install AI model"),
_("install AI models"),
GTK_WINDOW(data->parent_dialog),
GTK_FILE_CHOOSER_ACTION_OPEN,
_("_open"), _("_cancel"));

gtk_file_chooser_set_select_multiple(GTK_FILE_CHOOSER(filechooser), TRUE);

GtkFileFilter *filter = gtk_file_filter_new();
gtk_file_filter_set_name(filter, _("AI model packages (*.dtmodel)"));
gtk_file_filter_add_pattern(filter, "*.dtmodel");
Expand All @@ -1003,32 +1009,53 @@ static void _on_install_model(GtkButton *button, gpointer user_data)
gtk_file_chooser_set_filter(GTK_FILE_CHOOSER(filechooser), filter);

if(gtk_native_dialog_run(GTK_NATIVE_DIALOG(filechooser))
== GTK_RESPONSE_ACCEPT)
!= GTK_RESPONSE_ACCEPT)
{
char *filepath
= gtk_file_chooser_get_filename(GTK_FILE_CHOOSER(filechooser));
g_object_unref(filechooser);
return;
}

GSList *files = gtk_file_chooser_get_filenames(GTK_FILE_CHOOSER(filechooser));
g_object_unref(filechooser);

int ok = 0;
GString *errors = g_string_new(NULL);
for(GSList *l = files; l; l = l->next)
{
const char *filepath = (const char *)l->data;
char *error = dt_ai_models_install_local(darktable.ai_registry, filepath);
if(error)
{
GtkWidget *err_dialog = gtk_message_dialog_new(
GTK_WINDOW(data->parent_dialog),
GTK_DIALOG_MODAL,
GTK_MESSAGE_ERROR,
GTK_BUTTONS_OK,
"%s", error);
gtk_dialog_run(GTK_DIALOG(err_dialog));
gtk_widget_destroy(err_dialog);
gchar *base = g_path_get_basename(filepath);
g_string_append_printf(errors, "%s: %s\n", base, error);
g_free(base);
g_free(error);
}
else
{
DT_CONTROL_SIGNAL_RAISE(DT_SIGNAL_AI_MODELS_CHANGED);
_refresh_model_list(data);
ok++;
}
g_free(filepath);
}
g_object_unref(filechooser);
g_slist_free_full(files, g_free);

if(ok)
{
DT_CONTROL_SIGNAL_RAISE(DT_SIGNAL_AI_MODELS_CHANGED);
_refresh_model_list(data);
}

if(errors->len)
{
GtkWidget *err_dialog = gtk_message_dialog_new(
GTK_WINDOW(data->parent_dialog),
GTK_DIALOG_MODAL,
GTK_MESSAGE_ERROR,
GTK_BUTTONS_OK,
"%s", errors->str);
gtk_dialog_run(GTK_DIALOG(err_dialog));
gtk_widget_destroy(err_dialog);
}
g_string_free(errors, TRUE);
}

static void _on_delete_selected(GtkButton *button, gpointer user_data)
Expand Down Expand Up @@ -1846,19 +1873,9 @@ void init_tab_ai(GtkWidget *dialog, GtkWidget *stack)
gtk_grid_attach(GTK_GRID(models_grid), button_box, 0, row++, 1, 1);

#ifdef HAVE_AI_DOWNLOAD
// download selected button
data->download_selected_btn = gtk_button_new_with_label(_("download selected"));
gtk_widget_set_tooltip_text(data->download_selected_btn,
_("download or update the selected models"));
g_signal_connect(
data->download_selected_btn,
"clicked",
G_CALLBACK(_on_download_selected),
data);
dt_gui_box_add(button_box, data->download_selected_btn);

// download default button
data->download_default_btn = gtk_button_new_with_label(_("download default"));
// download / update default button
data->download_default_btn
= gtk_button_new_with_label(_("download / update default"));
gtk_widget_set_tooltip_text(data->download_default_btn,
_("download or update all default models"));
g_signal_connect(
Expand All @@ -1868,21 +1885,38 @@ void init_tab_ai(GtkWidget *dialog, GtkWidget *stack)
data);
dt_gui_box_add(button_box, data->download_default_btn);

// download all button
data->download_all_btn = gtk_button_new_with_label(_("download all"));
gtk_widget_set_tooltip_text(data->download_all_btn,
_("download or update all available models"));
g_signal_connect(data->download_all_btn, "clicked", G_CALLBACK(_on_download_all), data);
dt_gui_box_add(button_box, data->download_all_btn);
// download / update selected button
data->download_selected_btn
= gtk_button_new_with_label(_("download / update selected"));
gtk_widget_set_tooltip_text(data->download_selected_btn,
_("download or update the selected models"));
g_signal_connect(
data->download_selected_btn,
"clicked",
G_CALLBACK(_on_download_selected),
data);
dt_gui_box_add(button_box, data->download_selected_btn);

// gap before import
GtkWidget *sep1 = gtk_separator_new(GTK_ORIENTATION_VERTICAL);
gtk_widget_set_margin_start(sep1, DT_PIXEL_APPLY_DPI(8));
gtk_widget_set_margin_end(sep1, DT_PIXEL_APPLY_DPI(8));
dt_gui_box_add(button_box, sep1);
#endif // HAVE_AI_DOWNLOAD

// install model button
data->install_btn = gtk_button_new_with_label(_("install model"));
// import from file button
data->install_btn = gtk_button_new_with_label(_("import from file…"));
gtk_widget_set_tooltip_text(data->install_btn,
_("install a model from a local .dtmodel file"));
g_signal_connect(data->install_btn, "clicked", G_CALLBACK(_on_install_model), data);
dt_gui_box_add(button_box, data->install_btn);

// gap before delete
GtkWidget *sep2 = gtk_separator_new(GTK_ORIENTATION_VERTICAL);
gtk_widget_set_margin_start(sep2, DT_PIXEL_APPLY_DPI(8));
gtk_widget_set_margin_end(sep2, DT_PIXEL_APPLY_DPI(8));
dt_gui_box_add(button_box, sep2);

// delete selected button
data->delete_selected_btn = gtk_button_new_with_label(_("delete selected"));
gtk_widget_set_tooltip_text(data->delete_selected_btn,
Expand Down
Loading