From cbc5bb80dd4ee60830ac1080f035137a09caf38f Mon Sep 17 00:00:00 2001 From: GitInno <86991526+gitnnolabs@users.noreply.github.com> Date: Thu, 2 Jul 2026 07:24:41 -0300 Subject: [PATCH] [scielo-tools-18] feature: Add the app IA to use in the future to XML markup --- config/settings/base.py | 41 ++- config/urls.py | 1 + core/wagtail_hooks.py | 7 - ia/__init__.py | 1 + ia/apps.py | 8 + ia/db_router.py | 33 +++ ia/exceptions.py | 10 + ia/forms.py | 15 ++ ia/migrations/0001_initial.py | 269 +++++++++++++++++++ ia/migrations/__init__.py | 1 + ia/models.py | 186 +++++++++++++ ia/payload.py | 221 +++++++++++++++ ia/prompts/__init__.py | 1 + ia/prompts/back.py | 26 ++ ia/prompts/text.py | 113 ++++++++ ia/prompts/vision.py | 62 +++++ ia/providers/__init__.py | 8 + ia/providers/gemini.py | 62 +++++ ia/providers/local.py | 103 +++++++ ia/providers/ollama.py | 134 +++++++++ ia/references.py | 38 +++ ia/service.py | 68 +++++ ia/tasks.py | 59 ++++ ia/templates/wagtailadmin/icons/ia-brain.svg | 11 + ia/tests/__init__.py | 1 + ia/tests/test_db_router.py | 59 ++++ ia/tests/test_forms.py | 57 ++++ ia/tests/test_models.py | 54 ++++ ia/tests/test_references.py | 49 ++++ ia/tests/test_service.py | 86 ++++++ ia/tests/test_urls.py | 102 +++++++ ia/urls.py | 10 + ia/utils/__init__.py | 1 + ia/utils/blocks.py | 63 +++++ ia/utils/json.py | 80 ++++++ ia/utils/normalizers.py | 51 ++++ ia/utils/text.py | 41 +++ ia/utils/vision.py | 91 +++++++ ia/wagtail_hooks.py | 263 ++++++++++++++++++ "instru\303\247\303\265es.txt" | 34 +++ pytest.ini | 3 + requirements/base.txt | 4 +- requirements/local.txt | 3 +- setup.cfg | 17 +- 44 files changed, 2518 insertions(+), 29 deletions(-) create mode 100644 ia/__init__.py create mode 100644 ia/apps.py create mode 100644 ia/db_router.py create mode 100644 ia/exceptions.py create mode 100644 ia/forms.py create mode 100644 ia/migrations/0001_initial.py create mode 100644 ia/migrations/__init__.py create mode 100644 ia/models.py create mode 100644 ia/payload.py create mode 100644 ia/prompts/__init__.py create mode 100644 ia/prompts/back.py create mode 100644 ia/prompts/text.py create mode 100644 ia/prompts/vision.py create mode 100644 ia/providers/__init__.py create mode 100644 ia/providers/gemini.py create mode 100644 ia/providers/local.py create mode 100644 ia/providers/ollama.py create mode 100644 ia/references.py create mode 100644 ia/service.py create mode 100644 ia/tasks.py create mode 100644 ia/templates/wagtailadmin/icons/ia-brain.svg create mode 100644 ia/tests/__init__.py create mode 100644 ia/tests/test_db_router.py create mode 100644 ia/tests/test_forms.py create mode 100644 ia/tests/test_models.py create mode 100644 ia/tests/test_references.py create mode 100644 ia/tests/test_service.py create mode 100644 ia/tests/test_urls.py create mode 100644 ia/urls.py create mode 100644 ia/utils/__init__.py create mode 100644 ia/utils/blocks.py create mode 100644 ia/utils/json.py create mode 100644 ia/utils/normalizers.py create mode 100644 ia/utils/text.py create mode 100644 ia/utils/vision.py create mode 100644 ia/wagtail_hooks.py create mode 100644 "instru\303\247\303\265es.txt" diff --git a/config/settings/base.py b/config/settings/base.py index 89469f0..60de446 100644 --- a/config/settings/base.py +++ b/config/settings/base.py @@ -23,6 +23,7 @@ ROOT_DIR = Path(__file__).resolve(strict=True).parent.parent.parent # core/ APPS_DIR = ROOT_DIR / "core" +LLAMA_MODEL_DIR = ROOT_DIR / "ia/download" env = environ.Env() READ_DOT_ENV_FILE = env.bool("DJANGO_READ_DOT_ENV_FILE", default=False) @@ -40,7 +41,7 @@ "core.home", "wagtail.contrib.forms", "wagtail.contrib.redirects", - 'wagtail.contrib.settings', + "wagtail.contrib.settings", "wagtail_modeladmin", "wagtail.embeds", "wagtail.sites", @@ -63,7 +64,7 @@ "django.contrib.sessions", "django.contrib.messages", "django.contrib.staticfiles", - "django_celery_results" + "django_celery_results", ] THIRD_PARTY_APPS = [ @@ -77,6 +78,7 @@ "core", "core_settings", "xml_manager", + "ia", ] INSTALLED_APPS = DJANGO_APPS + THIRD_PARTY_APPS + LOCAL_APPS + WAGTAIL @@ -85,7 +87,7 @@ "django.contrib.sessions.middleware.SessionMiddleware", "django.middleware.common.CommonMiddleware", "django.middleware.csrf.CsrfViewMiddleware", - 'django.middleware.locale.LocaleMiddleware', + "django.middleware.locale.LocaleMiddleware", "django.contrib.auth.middleware.AuthenticationMiddleware", "django.contrib.messages.middleware.MessageMiddleware", "django.middleware.clickjacking.XFrameOptionsMiddleware", @@ -106,7 +108,7 @@ "django.template.context_processors.request", "django.contrib.auth.context_processors.auth", "django.contrib.messages.context_processors.messages", - 'wagtail.contrib.settings.context_processors.settings', + "wagtail.contrib.settings.context_processors.settings", ], }, }, @@ -145,13 +147,13 @@ LANGUAGE_CODE = "en" LANGUAGES = [ - ('pt-br', 'Português (Brasil)'), - ('es', 'Español'), - ('en', 'English'), + ("pt-br", "Português (Brasil)"), + ("es", "Español"), + ("en", "English"), ] LOCALE_PATHS = [ - os.path.join(BASE_DIR, 'locale'), + os.path.join(BASE_DIR, "locale"), ] TIME_ZONE = "UTC" @@ -239,10 +241,22 @@ # This can be omitted to allow all files, but note that this may present a security risk # if untrusted users are allowed to upload files - # see https://docs.wagtail.org/en/stable/advanced_topics/deploying.html#user-uploaded-files -WAGTAILDOCS_EXTENSIONS = ['csv', 'docx', 'json', 'key', 'odt', 'pdf', 'pptx', 'rtf', 'txt', 'xlsx', 'zip'] +WAGTAILDOCS_EXTENSIONS = [ + "csv", + "docx", + "json", + "key", + "odt", + "pdf", + "pptx", + "rtf", + "txt", + "xlsx", + "zip", +] # https://docs.djangoproject.com/en/dev/ref/settings/#auth-user-model -AUTH_USER_MODEL = 'users.CustomUser' +AUTH_USER_MODEL = "users.CustomUser" # Celery # ------------------------------------------------------------------------------ @@ -269,13 +283,15 @@ CELERY_BEAT_SCHEDULER = "django_celery_beat.schedulers:DatabaseScheduler" # http://docs.celeryproject.org/en/latest/userguide/configuration.html DJANGO_CELERY_BEAT_TZ_AWARE = False -#CELERY PROMETHEUS DASHBOARD +# CELERY PROMETHEUS DASHBOARD # https://docs.celeryq.dev/en/stable/userguide/configuration.html#worker-send-task-events CELERY_WORKER_SEND_TASK_EVENTS = True # https://docs.celeryq.dev/en/stable/userguide/configuration.html#std-setting-task_send_sent_event CELERY_SEND_TASK_SENT_EVENT = True CELERYD_SEND_EVENTS = True -CE_BUCKETS=1,2.5,5,10,30,60,300,600,900,1800 +CE_BUCKETS = 1, 2.5, 5, 10, 30, 60, 300, 600, 900, 1800 + +LLAMA_ENABLED = env.bool("LLAMA_ENABLED", default=True) # Celery Results # ------------------------------------------------------------------------------ @@ -285,3 +301,4 @@ CELERY_RESULT_EXTENDED = True DATA_UPLOAD_MAX_NUMBER_FIELDS = 10000 +SILENCED_SYSTEM_CHECKS = ["treebeard.E001"] diff --git a/config/urls.py b/config/urls.py index 88372e2..e17e835 100644 --- a/config/urls.py +++ b/config/urls.py @@ -9,6 +9,7 @@ urlpatterns = [ path("django-admin/", admin.site.urls), + path("admin/ia/", include("ia.urls")), path("admin/", include(wagtailadmin_urls)), path("documents/", include(wagtaildocs_urls)), path("i18n/", include("django.conf.urls.i18n")), diff --git a/core/wagtail_hooks.py b/core/wagtail_hooks.py index 8bb4b16..00794ed 100644 --- a/core/wagtail_hooks.py +++ b/core/wagtail_hooks.py @@ -21,13 +21,6 @@ def ensure_image_title(sender, instance, **kwargs): pre_save.connect(ensure_image_title, sender=get_image_model()) -@hooks.register("construct_main_menu") -def keep_only_sps_validation_menu(request, menu_items): - menu_items[:] = [ - item for item in menu_items if item.name == "sps_package_validation" - ] - - @hooks.register("construct_help_menu") def replace_help_menu_items(request, help_menu_items): help_menu_items[:] = [ diff --git a/ia/__init__.py b/ia/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/ia/__init__.py @@ -0,0 +1 @@ + diff --git a/ia/apps.py b/ia/apps.py new file mode 100644 index 0000000..6a66426 --- /dev/null +++ b/ia/apps.py @@ -0,0 +1,8 @@ +from django.apps import AppConfig +from django.utils.translation import gettext_lazy as _ + + +class IAConfig(AppConfig): + default_auto_field = "django.db.models.BigAutoField" + name = "ia" + verbose_name = _("IA Model") diff --git a/ia/db_router.py b/ia/db_router.py new file mode 100644 index 0000000..8f68dbd --- /dev/null +++ b/ia/db_router.py @@ -0,0 +1,33 @@ +class IADatabaseRouter: + app_label = "ia" + db_alias = "ia_db" + + def _ia_db_configured(self): + from django.conf import settings + + return self.db_alias in settings.DATABASES + + def db_for_read(self, model, **hints): + if model._meta.app_label == self.app_label and self._ia_db_configured(): + return self.db_alias + return None + + def db_for_write(self, model, **hints): + if model._meta.app_label == self.app_label and self._ia_db_configured(): + return self.db_alias + return None + + def allow_relation(self, obj1, obj2, **hints): + if ( + obj1._meta.app_label == self.app_label + or obj2._meta.app_label == self.app_label + ): + return True + return None + + def allow_migrate(self, db, app_label, model_name=None, **hints): + if app_label != self.app_label: + return None + if self._ia_db_configured(): + return db == self.db_alias + return db == "default" diff --git a/ia/exceptions.py b/ia/exceptions.py new file mode 100644 index 0000000..5c1b12d --- /dev/null +++ b/ia/exceptions.py @@ -0,0 +1,10 @@ +class LlamaDisabledError(Exception): + pass + + +class LlamaModelNotFoundError(FileNotFoundError): + pass + + +class LlamaNotInstalledError(ImportError): + pass diff --git a/ia/forms.py b/ia/forms.py new file mode 100644 index 0000000..250a82d --- /dev/null +++ b/ia/forms.py @@ -0,0 +1,15 @@ +from wagtail.admin.forms import WagtailAdminModelForm + + +class IAAdminModelForm(WagtailAdminModelForm): + def save_all(self, user): + model_with_creator = super().save(commit=False) + + if self.instance.pk is not None: + model_with_creator.updated_by = user + else: + model_with_creator.creator = user + + self.save() + + return model_with_creator diff --git a/ia/migrations/0001_initial.py b/ia/migrations/0001_initial.py new file mode 100644 index 0000000..74b95b6 --- /dev/null +++ b/ia/migrations/0001_initial.py @@ -0,0 +1,269 @@ +# Generated by Django 6.0.5 on 2026-07-02 09:00 + +import django.db.models.deletion +from django.conf import settings +from django.db import migrations, models + + +class Migration(migrations.Migration): + initial = True + + dependencies = [ + migrations.swappable_dependency(settings.AUTH_USER_MODEL), + ] + + operations = [ + migrations.CreateModel( + name="GeminiModel", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "created", + models.DateTimeField( + auto_now_add=True, verbose_name="Creation date" + ), + ), + ( + "updated", + models.DateTimeField( + auto_now=True, verbose_name="Last update date" + ), + ), + ("api_key", models.CharField(max_length=255, verbose_name="API Key")), + ( + "is_active", + models.BooleanField(default=False, verbose_name="Active"), + ), + ( + "creator", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="%(class)s_creator", + to=settings.AUTH_USER_MODEL, + verbose_name="Creator", + ), + ), + ( + "updated_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="%(class)s_last_mod_user", + to=settings.AUTH_USER_MODEL, + verbose_name="Updater", + ), + ), + ], + options={ + "verbose_name": "Gemini model", + "verbose_name_plural": "Gemini models", + }, + ), + migrations.CreateModel( + name="HuggingFaceModel", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "created", + models.DateTimeField( + auto_now_add=True, verbose_name="Creation date" + ), + ), + ( + "updated", + models.DateTimeField( + auto_now=True, verbose_name="Last update date" + ), + ), + ( + "name_model", + models.CharField( + help_text="e.g. bartowski/Llama-3.2-3B-Instruct-GGUF", + max_length=255, + verbose_name="Model name", + ), + ), + ( + "name_file", + models.CharField( + help_text="e.g. Llama-3.2-3B-Instruct-Q4_K_M.gguf", + max_length=255, + verbose_name="Model file", + ), + ), + ( + "hf_token", + models.CharField( + blank=True, max_length=255, verbose_name="HuggingFace token" + ), + ), + ( + "download_status", + models.IntegerField( + blank=True, + choices=[ + (1, "No model"), + (2, "Downloading"), + (3, "Downloaded"), + (4, "Download error"), + ], + default=1, + verbose_name="Download status", + ), + ), + ( + "is_active", + models.BooleanField(default=False, verbose_name="Active"), + ), + ( + "creator", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="%(class)s_creator", + to=settings.AUTH_USER_MODEL, + verbose_name="Creator", + ), + ), + ( + "updated_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="%(class)s_last_mod_user", + to=settings.AUTH_USER_MODEL, + verbose_name="Updater", + ), + ), + ], + options={ + "verbose_name": "HuggingFace model", + "verbose_name_plural": "HuggingFace models", + }, + ), + migrations.CreateModel( + name="OllamaModel", + fields=[ + ( + "id", + models.BigAutoField( + auto_created=True, + primary_key=True, + serialize=False, + verbose_name="ID", + ), + ), + ( + "created", + models.DateTimeField( + auto_now_add=True, verbose_name="Creation date" + ), + ), + ( + "updated", + models.DateTimeField( + auto_now=True, verbose_name="Last update date" + ), + ), + ( + "url", + models.URLField( + help_text="e.g. http://host:11434", verbose_name="Ollama URL" + ), + ), + ( + "model", + models.CharField( + blank=True, + help_text="Select after fetching", + max_length=255, + verbose_name="Model", + ), + ), + ( + "is_vision", + models.BooleanField( + default=False, + help_text="Enable to use vision pipeline (page images). When disabled, text pipeline is used.", + verbose_name="Vision model", + ), + ), + ( + "docx_extractor", + models.CharField( + choices=[ + ("zipfile", "Zipfile (built-in)"), + ("docling", "Docling (OCR-capable, requires pytorch)"), + ], + default="zipfile", + help_text="Method to extract plain text from DOCX for LLM prompts.", + max_length=16, + verbose_name="DOCX text extractor", + ), + ), + ( + "context_limit", + models.PositiveIntegerField( + blank=True, + help_text="Max input tokens for this model. Leave empty for auto-detect via API.", + null=True, + verbose_name="Context tokens", + ), + ), + ( + "is_active", + models.BooleanField(default=False, verbose_name="Active"), + ), + ( + "creator", + models.ForeignKey( + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="%(class)s_creator", + to=settings.AUTH_USER_MODEL, + verbose_name="Creator", + ), + ), + ( + "updated_by", + models.ForeignKey( + blank=True, + editable=False, + null=True, + on_delete=django.db.models.deletion.SET_NULL, + related_name="%(class)s_last_mod_user", + to=settings.AUTH_USER_MODEL, + verbose_name="Updater", + ), + ), + ], + options={ + "verbose_name": "Ollama model", + "verbose_name_plural": "Ollama models", + }, + ), + ] diff --git a/ia/migrations/__init__.py b/ia/migrations/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/ia/migrations/__init__.py @@ -0,0 +1 @@ + diff --git a/ia/models.py b/ia/models.py new file mode 100644 index 0000000..98ac21b --- /dev/null +++ b/ia/models.py @@ -0,0 +1,186 @@ +from django import forms +from django.conf import settings +from django.core.exceptions import ValidationError +from django.db import models +from django.utils.translation import gettext_lazy as _ +from wagtail.admin.panels import FieldPanel + +from ia.forms import IAAdminModelForm + + +class IAMetadataModel(models.Model): + created = models.DateTimeField(verbose_name=_("Creation date"), auto_now_add=True) + updated = models.DateTimeField(verbose_name=_("Last update date"), auto_now=True) + creator = models.ForeignKey( + settings.AUTH_USER_MODEL, + verbose_name=_("Creator"), + related_name="%(class)s_creator", + editable=False, + on_delete=models.SET_NULL, + null=True, + ) + updated_by = models.ForeignKey( + settings.AUTH_USER_MODEL, + verbose_name=_("Updater"), + related_name="%(class)s_last_mod_user", + editable=False, + null=True, + blank=True, + on_delete=models.SET_NULL, + ) + + class Meta: + abstract = True + + +class MaskedPasswordWidget(forms.PasswordInput): + def __init__(self, attrs=None): + super().__init__(attrs=attrs, render_value=True) + + +class DownloadStatus(models.IntegerChoices): + NO_MODEL = 1, _("No model") + DOWNLOADING = 2, _("Downloading") + DOWNLOADED = 3, _("Downloaded") + ERROR = 4, _("Download error") + + +class HuggingFaceModel(IAMetadataModel): + name_model = models.CharField( + _("Model name"), + max_length=255, + help_text="e.g. bartowski/Llama-3.2-3B-Instruct-GGUF", + ) + name_file = models.CharField( + _("Model file"), + max_length=255, + help_text="e.g. Llama-3.2-3B-Instruct-Q4_K_M.gguf", + ) + hf_token = models.CharField(_("HuggingFace token"), max_length=255, blank=True) + download_status = models.IntegerField( + _("Download status"), + choices=DownloadStatus.choices, + default=DownloadStatus.NO_MODEL, + blank=True, + ) + is_active = models.BooleanField(_("Active"), default=False) + + panels = [ + FieldPanel("name_model"), + FieldPanel("name_file"), + FieldPanel("hf_token", widget=MaskedPasswordWidget()), + FieldPanel( + "download_status", + widget=forms.Select( + choices=DownloadStatus.choices, attrs={"disabled": True} + ), + ), + FieldPanel("is_active"), + ] + base_form_class = IAAdminModelForm + + class Meta: + verbose_name = _("HuggingFace model") + verbose_name_plural = _("HuggingFace models") + + def __str__(self): + return self.name_model or self.name_file or "HuggingFace" + + def clean(self): + if not self.name_model: + raise ValidationError({"name_model": _("Model name is required.")}) + if not self.name_file: + raise ValidationError({"name_file": _("Model file is required.")}) + + def save(self, *args, **kwargs): + if self.is_active: + HuggingFaceModel.objects.exclude(pk=self.pk).update(is_active=False) + super().save(*args, **kwargs) + + +class OllamaModel(IAMetadataModel): + url = models.URLField(_("Ollama URL"), help_text="e.g. http://host:11434") + model = models.CharField( + _("Model"), max_length=255, blank=True, help_text="Select after fetching" + ) + is_vision = models.BooleanField( + _("Vision model"), + default=False, + help_text="Enable to use vision pipeline (page images). When disabled, text pipeline is used.", + ) + + class DocxExtractor(models.TextChoices): + ZIPFILE = "zipfile", _("Zipfile (built-in)") + DOCLING = "docling", _("Docling (OCR-capable, requires pytorch)") + + docx_extractor = models.CharField( + _("DOCX text extractor"), + max_length=16, + choices=DocxExtractor.choices, + default=DocxExtractor.ZIPFILE, + help_text="Method to extract plain text from DOCX for LLM prompts.", + ) + context_limit = models.PositiveIntegerField( + _("Context tokens"), + blank=True, + null=True, + help_text="Max input tokens for this model. Leave empty for auto-detect via API.", + ) + is_active = models.BooleanField(_("Active"), default=False) + + panels = [ + FieldPanel("url"), + FieldPanel( + "model", widget=forms.Select(attrs={"data-ai": "ollama-model-select"}) + ), + FieldPanel("is_vision"), + FieldPanel("docx_extractor"), + FieldPanel("context_limit"), + FieldPanel("is_active"), + ] + base_form_class = IAAdminModelForm + + class Meta: + verbose_name = _("Ollama model") + verbose_name_plural = _("Ollama models") + + def __str__(self): + return f"{self.model or '?'} @ {self.url or '?'}" if self.model else "Ollama" + + def clean(self): + if not self.url: + raise ValidationError({"url": _("Ollama URL is required.")}) + if not self.model: + raise ValidationError({"model": _("Please select a model.")}) + + def save(self, *args, **kwargs): + if self.is_active: + OllamaModel.objects.exclude(pk=self.pk).update(is_active=False) + super().save(*args, **kwargs) + + +class GeminiModel(IAMetadataModel): + api_key = models.CharField(_("API Key"), max_length=255) + is_active = models.BooleanField(_("Active"), default=False) + + panels = [ + FieldPanel("api_key", widget=MaskedPasswordWidget()), + FieldPanel("is_active"), + ] + base_form_class = IAAdminModelForm + + class Meta: + verbose_name = _("Gemini model") + verbose_name_plural = _("Gemini models") + + def __str__(self): + return "Gemini" + + def clean(self): + if not self.api_key: + raise ValidationError({"api_key": _("API Key is required.")}) + + def save(self, *args, **kwargs): + if self.is_active: + GeminiModel.objects.exclude(pk=self.pk).update(is_active=False) + super().save(*args, **kwargs) diff --git a/ia/payload.py b/ia/payload.py new file mode 100644 index 0000000..90f1a93 --- /dev/null +++ b/ia/payload.py @@ -0,0 +1,221 @@ +import logging +import re + +from ia.utils.normalizers import ( + DOI_RE, + ORCID_RE, + stz_affiliation_id, + stz_country_code, + stz_date, + stz_first_number, + stz_language, + stz_norm, + stz_text, +) + +logger = logging.getLogger(__name__) + + +def payload_counts(payload): + return { + "doi": 1 if payload.get("doi") else 0, + "titles": len(payload.get("titles") or []), + "authors": len(payload.get("authors") or []), + "affiliations": len(payload.get("affiliations") or []), + "dates": len(payload.get("dates") or []), + "abstracts": len(payload.get("abstracts") or []), + "keywords": len(payload.get("keywords") or []), + } + + +def is_invalid_title(text): + clean = stz_norm(text) + heading_words = { + "abstract", + "resumo", + "resumen", + "resumén", + "sumário", + "sumario", + "keywords", + "palavras-chave", + "palabras clave", + "introduction", + "introdução", + "introducción", + "methodology", + "metodologia", + "conclusion", + "conclusão", + "conclusión", + "references", + "referências", + "referencias", + "bibliography", + "bibliografia", + } + if clean in heading_words: + return True + return ( + len(clean) > 320 + or bool(DOI_RE.search(clean)) + or bool(ORCID_RE.search(clean)) + or "@" in clean + or len( + re.findall( + r"\b(universidad|university|instituto|department|facultad|doctor|graduad|research|analysis|study|approach)", + clean, + ) + ) + > 2 + ) + + +def normalize_payload(payload, article): + if not isinstance(payload, dict): + return {}, ["LLM response was not a JSON object."] + + warnings = [ + str(item).strip() for item in payload.get("warnings") or [] if str(item).strip() + ] + normalized = { + "doi": stz_text(payload.get("doi")), + "titles": [], + "authors": [], + "affiliations": [], + "dates": [], + "abstracts": [], + "keywords": [], + } + + seen_titles = set() + for item in payload.get("titles") or []: + if isinstance(item, str): + text = stz_text(item) + language = stz_language("", article.language) + kind = "main" + else: + text = stz_text(item.get("text")) + language = stz_language(item.get("language"), article.language) + kind = (item.get("kind") or "translated").lower() + if not text or is_invalid_title(text): + if text: + warnings.append(f"Title discarded: {text[:120]}") + continue + key = stz_norm(text) + if key in seen_titles: + continue + seen_titles.add(key) + normalized["titles"].append({"text": text, "language": language, "kind": kind}) + + for index, item in enumerate(payload.get("authors") or [], 1): + if isinstance(item, str): + continue + given_names = stz_text(item.get("given_names")) + surname = stz_text(item.get("surname")) + display = stz_text(item.get("display")) or " ".join( + part for part in [given_names, surname] if part + ) + if not display: + continue + if not given_names and not surname: + parts = display.rsplit(" ", 1) + surname = parts[-1] if len(parts) > 1 else display + given_names = parts[0] if len(parts) > 1 else "" + normalized["authors"].append( + { + "given_names": given_names, + "surname": surname, + "display": display, + "orcid": stz_text(item.get("orcid")), + "affiliations": stz_affiliation_id(item.get("affiliations")) + or [str(index)], + "symbol": stz_text(item.get("symbol")), + } + ) + + for index, item in enumerate(payload.get("affiliations") or [], 1): + if isinstance(item, str): + continue + text = stz_text(item.get("text")) + if not text: + continue + normalized["affiliations"].append( + { + "id": stz_first_number(item.get("id")) or str(index), + "symbol": stz_text(item.get("symbol")), + "text": text, + "orgname": stz_text(item.get("orgname")), + "orgdiv1": stz_text(item.get("orgdiv1")), + "orgdiv2": stz_text(item.get("orgdiv2")), + "city": stz_text(item.get("city")), + "state": stz_text(item.get("state")), + "country": stz_text(item.get("country")), + "country_code": stz_country_code(item.get("country_code")), + } + ) + + for item in payload.get("dates") or []: + if isinstance(item, str): + continue + date_type = (item.get("type") or "other").lower() + normalized_date = stz_date(item.get("date") or item.get("raw")) + raw = stz_text(item.get("raw")) or stz_text(item.get("date")) + if date_type not in {"received", "accepted", "published", "ahp", "other"}: + date_type = "other" + if normalized_date or raw: + normalized["dates"].append( + {"type": date_type, "date": normalized_date, "raw": raw} + ) + + for item in payload.get("abstracts") or []: + if isinstance(item, str): + continue + text = stz_text(item.get("text")) + if not text: + continue + normalized["abstracts"].append( + { + "title": stz_text(item.get("title")), + "text": text, + "language": stz_language(item.get("language"), article.language), + } + ) + + for item in payload.get("keywords") or []: + if isinstance(item, str): + continue + terms = [stz_text(term) for term in item.get("terms") or []] + terms = [term for term in terms if term] + if not terms: + continue + normalized["keywords"].append( + { + "title": stz_text(item.get("title")), + "terms": terms, + "language": stz_language(item.get("language"), article.language), + } + ) + + if normalized["titles"] and not any( + title["kind"] == "main" for title in normalized["titles"] + ): + normalized["titles"][0]["kind"] = "main" + normalized["titles"].sort(key=lambda title: 0 if title["kind"] == "main" else 1) + + return normalized, warnings + + +def has_useful_payload(payload): + return any( + payload.get(key) + for key in ( + "doi", + "titles", + "authors", + "affiliations", + "dates", + "abstracts", + "keywords", + ) + ) diff --git a/ia/prompts/__init__.py b/ia/prompts/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/ia/prompts/__init__.py @@ -0,0 +1 @@ + diff --git a/ia/prompts/back.py b/ia/prompts/back.py new file mode 100644 index 0000000..3d95c21 --- /dev/null +++ b/ia/prompts/back.py @@ -0,0 +1,26 @@ +MESSAGES = [ + { + "role": "system", + "content": "You distinguish citation components and respond in JSON.", + }, + { + "role": "user", + "content": "Bachman S et al. 2011. Supporting Red List threat assessments. ZooKeys 150:117-126. DOI: https://doi.org/10.3897/zookeys.150.2109", + }, +] + +RESPONSE_FORMAT = { + "type": "json_object", + "schema": { + "type": "object", + "properties": { + "reftype": {"type": "string"}, + "authors": {"type": "array", "items": {"type": "object"}}, + "full_text": {"type": "string"}, + "date": {"type": "integer"}, + "title": {"type": "string"}, + "source": {"type": "string"}, + "doi": {"type": "string"}, + }, + }, +} diff --git a/ia/prompts/text.py b/ia/prompts/text.py new file mode 100644 index 0000000..03c5c46 --- /dev/null +++ b/ia/prompts/text.py @@ -0,0 +1,113 @@ +import json + +JSON_SCHEMAS = { + "titles": { + "type": "object", + "properties": { + "titles": { + "type": "array", + "items": { + "type": "object", + "properties": { + "text": {"type": "string"}, + "language": {"type": "string"}, + "kind": {"type": "string", "enum": ["main", "translated"]}, + }, + "required": ["text", "language", "kind"], + }, + }, + }, + "required": ["titles"], + }, + "authors_affs": { + "type": "object", + "properties": { + "authors": {"type": "array", "items": {"type": "object"}}, + "affiliations": {"type": "array", "items": {"type": "object"}}, + }, + }, + "abstracts": { + "type": "object", + "properties": {"abstracts": {"type": "array", "items": {"type": "object"}}}, + }, + "keywords": { + "type": "object", + "properties": {"keywords": {"type": "array", "items": {"type": "object"}}}, + }, + "dates": { + "type": "object", + "properties": {"dates": {"type": "array", "items": {"type": "object"}}}, + }, + "meta": {"type": "object", "properties": {}}, +} + + +def build_title_prompt(content_str, article_language): + return ( + f"Extract the article title and any translated titles from this academic paper. " + f"The title is the main heading at the top. Section headings such as Abstract, " + f"Introduction, Methodology, Results, Discussion, Conclusion, and References are NOT " + f"titles. There must be exactly one main title. " + f"The article language is {article_language}. " + f'Return a JSON object with a "titles" key containing objects with ' + f'"text", "language", and "kind" ("main" or "translated"). ' + f"Content:\n{content_str}" + ) + + +def build_auth_prompt(content_str): + return ( + "Extract all authors and their affiliations from this academic paper. " + 'Return a JSON object with "authors" and "affiliations" keys. ' + f"Content:\n{content_str}" + ) + + +def build_abstracts_prompt(content_str): + return ( + "Extract all abstracts from this academic paper and keep the original text. " + 'Return a JSON object with an "abstracts" key. ' + f"Content:\n{content_str}" + ) + + +def build_keywords_prompt(content_str): + return ( + "Extract all keyword groups from this academic paper. " + 'Return a JSON object with a "keywords" key. ' + f"Content:\n{content_str}" + ) + + +def build_dates_prompt(content_str): + return ( + "Extract all dates from this academic paper and normalize to YYYY-MM-DD when possible. " + 'Return a JSON object with a "dates" key. ' + f"Content:\n{content_str}" + ) + + +def build_meta_prompt(content_str): + return ( + "Extract DOI, journal ISSN and issue metadata from this academic paper. " + 'Return a JSON object with keys "doi", "journal" and "issue". ' + f"Content:\n{content_str}" + ) + + +def build_split_task_prompts(content, is_xml, article_language): + content_str = ( + content if is_xml else json.dumps(content, ensure_ascii=False, indent=2) + ) + return [ + ( + "titles", + build_title_prompt(content_str, article_language), + JSON_SCHEMAS["titles"], + ), + ("authors_affs", build_auth_prompt(content_str), JSON_SCHEMAS["authors_affs"]), + ("abstracts", build_abstracts_prompt(content_str), JSON_SCHEMAS["abstracts"]), + ("keywords", build_keywords_prompt(content_str), JSON_SCHEMAS["keywords"]), + ("dates", build_dates_prompt(content_str), JSON_SCHEMAS["dates"]), + ("meta", build_meta_prompt(content_str), JSON_SCHEMAS["meta"]), + ] diff --git a/ia/prompts/vision.py b/ia/prompts/vision.py new file mode 100644 index 0000000..66dc0ec --- /dev/null +++ b/ia/prompts/vision.py @@ -0,0 +1,62 @@ +VISION_TASKS = [ + ( + "titles", + "Analyze the pages of this article and extract the main title and any translated titles, " + 'with their languages. Return a JSON object with a "titles" key containing a list of objects ' + 'with "text", "language" (ISO 639-1: en, es, pt, fr), and "kind" ("main" or "translated"). ' + 'Example: {"titles":[{"text":"The Title","language":"en","kind":"main"},' + '{"text":"El Título","language":"es","kind":"translated"}]}. ' + "Do not include explanations, only the JSON.", + ), + ( + "authors", + "Analyze the pages of this article and extract all author names, their ORCIDs " + "(format 0000-0000-0000-0000), and affiliation symbols (*, **, etc). " + 'Return a JSON object with an "authors" key containing a list of objects with "given_names", ' + '"surname", "display", "orcid", and "affiliations" (list of IDs). ' + 'Use empty string "" for ORCID if not found. NEVER invent ORCIDs. ' + 'Also return "affiliations", a list of objects with "id" and "text". ' + 'Example: {"authors":[{"given_names":"John","surname":"Smith",' + '"display":"John Smith","orcid":"0000-0000-0000-0000","affiliations":["1"]}],' + '"affiliations":[{"id":"1","text":"University of Example"}]}. ' + "Do not include explanations, only the JSON.", + ), + ( + "abstracts", + "Analyze the pages of this article and extract all abstracts " + "(Abstract, Resumo, Resumen, Résumé, etc.), with their respective languages. " + 'Return a JSON object with an "abstracts" key containing a list of objects with ' + '"title" (exact label: "Abstract", "Resumo", etc.), ' + '"text" (only the abstract body text, do NOT include keywords), ' + 'and "language" (ISO 639-1). ' + 'Example: {"abstracts":[{"title":"Abstract","text":"The full abstract text...","language":"en"}]}. ' + "Do not include explanations, only the JSON.", + ), + ( + "keywords", + "Analyze the pages of this article and extract all keyword groups, " + 'with their respective languages. Return a JSON object with a "keywords" key containing a ' + 'list of objects with "title" (label: "Keywords", "Palavras-chave", etc.), ' + '"terms" (list of terms), and "language" (ISO 639-1). ' + 'Example: {"keywords":[{"title":"Keywords","terms":["term1","term2"],"language":"en"}]}. ' + "Do not include explanations, only the JSON.", + ), + ( + "dates", + "Analyze the pages of this article and extract the received, accepted, and publication dates. " + 'Return a JSON object with a "dates" key containing a list of objects with "type" ' + '("received", "accepted", "published", "ahp"), "date" (format YYYY-MM-DD), ' + 'and "raw" (original date text). ' + 'Example: {"dates":[{"type":"received","date":"2023-01-15","raw":"Received: 15.01.2023"}]}. ' + "Do not include explanations, only the JSON.", + ), + ( + "meta", + "Analyze the pages of this article and extract the DOI, journal name, ISSN, " + 'volume, number, and year. Return a JSON object with keys "doi", "journal" (with "title" and ' + '"issn"), and "issue" (with "volume", "number", "year"). ' + 'Example: {"doi":"10.1234/example.2023","journal":{"title":"Journal Name","issn":"1234-5678"},' + '"issue":{"volume":"15","number":"3","year":"2023"}}. ' + "Do not include explanations, only the JSON.", + ), +] diff --git a/ia/providers/__init__.py b/ia/providers/__init__.py new file mode 100644 index 0000000..45a56da --- /dev/null +++ b/ia/providers/__init__.py @@ -0,0 +1,8 @@ +JSON_INSTRUCTION = ( + "Output ONLY a JSON object, never text before or after. " + "Start with {, end with }. " + "CRITICAL: never invent names, ORCIDs, dates, or any data. " + 'Use empty string "" for missing fields. ' + "Use empty array [] for missing lists. " + "Only extract what is explicitly written in the content." +) diff --git a/ia/providers/gemini.py b/ia/providers/gemini.py new file mode 100644 index 0000000..1ad98b8 --- /dev/null +++ b/ia/providers/gemini.py @@ -0,0 +1,62 @@ +import logging +import time +import warnings +from importlib import import_module + +GEMINI_MODEL = "models/gemini-3.1-flash-lite-preview" +logger = logging.getLogger(__name__) + + +class GeminiProvider: + def __init__(self, model, response_format=None): + with warnings.catch_warnings(): + warnings.filterwarnings("ignore", category=FutureWarning) + genai = import_module("google.generativeai") + genai.configure(api_key=model.api_key) + self.genai = genai + self.model = model + self.response_format = response_format + + def chat(self, messages): + started = time.monotonic() + prompt_parts = [] + for msg in messages: + role = msg.get("role", "") + content = msg.get("content", "") + if role == "system": + prompt_parts.append(f"System instruction:\n{content}\n") + elif role == "user": + prompt_parts.append(f"User: {content}") + elif role == "assistant": + prompt_parts.append(f"Assistant: {content}") + prompt = "\n".join(prompt_parts) + + generation_config = {} + if self.response_format and self.response_format.get("type") == "json_object": + generation_config["response_mime_type"] = "application/json" + + model = self.genai.GenerativeModel(GEMINI_MODEL) + response_text = model.generate_content( + prompt, generation_config=generation_config + ).text + elapsed = time.monotonic() - started + logger.info("Gemini chat: %d chars in %.2fs", len(response_text or ""), elapsed) + time.sleep(15) + return {"choices": [{"message": {"content": response_text}}]} + + def prompt(self, user_input, response_format=None): + started = time.monotonic() + model = self.genai.GenerativeModel(GEMINI_MODEL) + generation_config = {"response_mime_type": "application/json"} + if response_format: + generation_config["response_schema"] = response_format + response_text = model.generate_content( + user_input, + generation_config=generation_config, + ).text + elapsed = time.monotonic() - started + logger.info( + "Gemini prompt: %d chars in %.2fs", len(response_text or ""), elapsed + ) + time.sleep(15) + return response_text diff --git a/ia/providers/local.py b/ia/providers/local.py new file mode 100644 index 0000000..c5a4f5e --- /dev/null +++ b/ia/providers/local.py @@ -0,0 +1,103 @@ +import logging +import os +import time + +from ia.exceptions import ( + LlamaDisabledError, + LlamaModelNotFoundError, + LlamaNotInstalledError, +) +from ia.providers import JSON_INSTRUCTION + +logger = logging.getLogger(__name__) + + +class LocalProvider: + _cached_llm = None + + def __init__( + self, + model, + response_format=None, + temperature=0.0, + top_p=0.1, + max_tokens=4000, + stop=None, + n_ctx=None, + nthreads=2, + ): + from django.conf import settings + + self.model = model + self.response_format = response_format + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + self.stop = stop + self.n_ctx = n_ctx or 32768 + llama_enabled = getattr(settings, "LLAMA_ENABLED", True) + llama_model_dir = getattr( + settings, "LLAMA_MODEL_DIR", settings.ROOT_DIR / "ia/download" + ) + + if not llama_enabled: + raise LlamaDisabledError("LLaMA is disabled.") + + if LocalProvider._cached_llm is None: + try: + from llama_cpp import Llama + except ImportError as exc: + raise LlamaNotInstalledError("llama-cpp-python not installed.") from exc + + model_path = os.path.join(str(llama_model_dir), self.model.name_file) + if not os.path.isfile(model_path): + raise LlamaModelNotFoundError(f"Model file not found at {model_path}.") + + logger.info("Loading local Llama: %s", model_path) + LocalProvider._cached_llm = Llama( + model_path=model_path, n_ctx=self.n_ctx, n_threads=nthreads + ) + logger.info("Local Llama loaded.") + + self.llm = LocalProvider._cached_llm + + def chat(self, messages): + started = time.monotonic() + logger.info("Local Llama chat. Preview: %r", messages[-1]["content"][:150]) + response = self.llm.create_chat_completion( + messages=messages, + response_format=self.response_format, + max_tokens=self.max_tokens, + temperature=self.temperature, + top_p=self.top_p, + ) + elapsed = time.monotonic() - started + try: + response_text = response["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError): + response_text = "" + logger.info("Local Llama chat: %d chars in %.2fs", len(response_text), elapsed) + return response + + def prompt(self, user_input, response_format=None): + started = time.monotonic() + logger.info("Local Llama prompt. Preview: %r", user_input[:150]) + messages = [ + {"role": "system", "content": JSON_INSTRUCTION}, + {"role": "user", "content": user_input}, + ] + response = self.llm.create_chat_completion( + messages=messages, + max_tokens=self.max_tokens, + temperature=self.temperature, + stop=self.stop, + ) + elapsed = time.monotonic() - started + try: + response_text = response["choices"][0]["message"]["content"] + except (KeyError, IndexError, TypeError): + response_text = "" + logger.info( + "Local Llama prompt: %d chars in %.2fs", len(response_text), elapsed + ) + return response_text diff --git a/ia/providers/ollama.py b/ia/providers/ollama.py new file mode 100644 index 0000000..e9bc138 --- /dev/null +++ b/ia/providers/ollama.py @@ -0,0 +1,134 @@ +import json +import logging +import time + +import requests + +from ia.providers import JSON_INSTRUCTION + +logger = logging.getLogger(__name__) + + +class OllamaProvider: + def __init__( + self, + model, + response_format=None, + temperature=0.0, + top_p=0.1, + max_tokens=4000, + stop=None, + ): + self.model = model + self.response_format = response_format + self.temperature = temperature + self.top_p = top_p + self.max_tokens = max_tokens + self.stop = stop + + def _url(self, path="api/chat"): + return f"{self.model.url.rstrip('/')}/{path}" + + def chat(self, messages): + started = time.monotonic() + logger.info("Ollama chat. Preview: %r", messages[-1]["content"][:150]) + + options = {"temperature": self.temperature, "top_p": self.top_p} + if self.max_tokens: + options["num_predict"] = self.max_tokens + + payload = { + "model": self.model.model, + "messages": messages, + "options": options, + "stream": False, + } + if self.response_format and self.response_format.get("type") == "json_object": + payload["format"] = "json" + + try: + resp = requests.post(self._url(), json=payload, timeout=300) + resp.raise_for_status() + response_text = resp.json()["message"]["content"] + except Exception as exc: + logger.error("Ollama chat error: %s", exc) + response_text = "" + + elapsed = time.monotonic() - started + logger.info("Ollama chat: %d chars in %.2fs", len(response_text), elapsed) + return {"choices": [{"message": {"content": response_text}}]} + + def prompt(self, user_input, response_format=None): + started = time.monotonic() + logger.info("Ollama prompt. Preview: %r", user_input[:150]) + + options = {"temperature": self.temperature, "enable_thinking": False} + if self.max_tokens: + options["num_predict"] = self.max_tokens + if self.stop: + options["stop"] = self.stop + + payload = { + "model": self.model.model, + "messages": [ + {"role": "system", "content": JSON_INSTRUCTION}, + {"role": "user", "content": user_input}, + ], + "stream": False, + "options": options, + } + if response_format: + payload["format"] = response_format + elif self.response_format and self.response_format.get("type") == "json_object": + payload["format"] = "json" + + try: + resp = requests.post(self._url(), json=payload, timeout=300) + resp.raise_for_status() + response_text = resp.json().get("message", {}).get("content") or "" + elapsed = time.monotonic() - started + logger.info("Ollama prompt: %d chars in %.2fs", len(response_text), elapsed) + return response_text + except Exception as exc: + logger.error("Ollama prompt error: %s", exc) + return "" + + def chat_with_images(self, images, prompt, model=None): + started = time.monotonic() + model_name = model or self.model.model + logger.info( + "Ollama vision. Model=%s images=%d prompt=%r", + model_name, + len(images), + prompt[:150], + ) + + payload = { + "model": model_name, + "messages": [{"role": "user", "content": prompt, "images": images}], + "stream": True, + "options": {"temperature": 0.0, "num_ctx": 16384, "num_predict": 16384}, + } + try: + resp = requests.post(self._url(), json=payload, timeout=300) + resp.raise_for_status() + except Exception as exc: + logger.error("Ollama vision error: %s", exc) + return "" + + parts = [] + for line in resp.iter_lines(decode_unicode=True): + if not line: + continue + try: + chunk = json.loads(line) + content = chunk.get("message", {}).get("content", "") + if content: + parts.append(content) + except json.JSONDecodeError: + continue + + response_text = "".join(parts) + elapsed = time.monotonic() - started + logger.info("Ollama vision: %d chars in %.2fs", len(response_text), elapsed) + return response_text diff --git a/ia/references.py b/ia/references.py new file mode 100644 index 0000000..1215881 --- /dev/null +++ b/ia/references.py @@ -0,0 +1,38 @@ +import logging + +from ia.exceptions import ( + LlamaDisabledError, + LlamaModelNotFoundError, + LlamaNotInstalledError, +) +from ia.prompts.back import MESSAGES, RESPONSE_FORMAT +from ia.service import LLMService + +logger = logging.getLogger(__name__) + + +def mark_reference(reference_text): + try: + reference_marker = LLMService(MESSAGES, RESPONSE_FORMAT) + output = reference_marker.run(reference_text) + for item in output.get("choices", []): + yield item.get("message", {}).get("content", "") + + except (LlamaDisabledError, LlamaNotInstalledError, LlamaModelNotFoundError) as exc: + logger.error("Error marking reference: %s — ref=%s", exc, reference_text) + if isinstance(exc, LlamaModelNotFoundError): + yield f"Llama model file not found: {str(exc)}" + else: + yield f"Llama model is not available: {str(exc)}" + + except Exception as exc: + logger.exception("Unexpected error marking reference: ref=%s", reference_text) + yield f"An unexpected error occurred: {str(exc)}" + + +def mark_references(reference_block): + for ref_row in reference_block.split("\n"): + ref_row = ref_row.strip() + if ref_row: + choices = mark_reference(ref_row) + yield {"references": ref_row, "choices": list(choices)} diff --git a/ia/service.py b/ia/service.py new file mode 100644 index 0000000..99d4270 --- /dev/null +++ b/ia/service.py @@ -0,0 +1,68 @@ +import logging + +from ia.exceptions import LlamaModelNotFoundError +from ia.models import GeminiModel, HuggingFaceModel, OllamaModel +from ia.providers.gemini import GeminiProvider +from ia.providers.local import LocalProvider +from ia.providers.ollama import OllamaProvider + +logger = logging.getLogger(__name__) + + +class LLMService: + def __init__( + self, + messages=None, + response_format=None, + max_tokens=4000, + temperature=0.0, + top_p=0.1, + mode="chat", + nthreads=2, + stop=None, + n_ctx=None, + ): + self.messages = messages or [] + self.response_format = response_format + self.mode = mode + + provider_kwargs = { + "response_format": response_format, + "temperature": temperature, + "top_p": top_p, + "max_tokens": max_tokens, + "stop": stop, + } + + gemini = GeminiModel.objects.filter(is_active=True).first() + if gemini: + logger.info("LLMService: using Gemini") + self.provider = GeminiProvider(gemini, response_format=response_format) + return + + ollama = OllamaModel.objects.filter(is_active=True).first() + if ollama: + logger.info("LLMService: using Ollama at %s", ollama.url) + self.provider = OllamaProvider(ollama, **provider_kwargs) + return + + hf = HuggingFaceModel.objects.filter(is_active=True).first() + if hf: + logger.info("LLMService: using local Llama") + self.provider = LocalProvider( + hf, n_ctx=n_ctx, nthreads=nthreads, **provider_kwargs + ) + return + + raise LlamaModelNotFoundError("No IA model configured.") + + def run(self, user_input, response_format=None): + if self.mode == "chat": + messages = self.messages.copy() + messages.append({"role": "user", "content": user_input}) + return self.provider.chat(messages) + if self.mode == "prompt": + return self.provider.prompt( + user_input, response_format or self.response_format + ) + return "" diff --git a/ia/tasks.py b/ia/tasks.py new file mode 100644 index 0000000..696295c --- /dev/null +++ b/ia/tasks.py @@ -0,0 +1,59 @@ +import logging +import os + +from huggingface_hub import hf_hub_download, login + +from config import celery_app +from ia.models import DownloadStatus, HuggingFaceModel + +logger = logging.getLogger(__name__) + + +def _download_hf_model(hf_token, model_name, model_file): + from django.conf import settings + + if hf_token: + login(token=hf_token) + hf_hub_download( + repo_id=model_name, + filename=model_file, + local_dir=getattr( + settings, "LLAMA_MODEL_DIR", settings.ROOT_DIR / "ia/download" + ), + ) + + +@celery_app.task() +def download_model(instance_id=None): + logger.info("Download task started. instance_id=%s", instance_id) + try: + if instance_id is None: + instance = HuggingFaceModel.objects.first() + if not instance: + logger.info("No HuggingFace model found. Creating default...") + hf_token = os.getenv("HF_TOKEN", "") + instance = HuggingFaceModel.objects.create( + name_model="hugging-quants/Llama-3.2-3B-Instruct-Q4_K_M-GGUF", + name_file="llama-3.2-3b-instruct-q4_k_m.gguf", + hf_token=hf_token, + download_status=DownloadStatus.DOWNLOADING, + ) + else: + instance.download_status = DownloadStatus.DOWNLOADING + instance.save() + else: + instance = HuggingFaceModel.objects.get(pk=instance_id) + instance.download_status = DownloadStatus.DOWNLOADING + instance.save() + + logger.info("Downloading %s / %s...", instance.name_model, instance.name_file) + _download_hf_model(instance.hf_token, instance.name_model, instance.name_file) + instance.download_status = DownloadStatus.DOWNLOADED + instance.save() + logger.info("Download complete.") + except Exception as exc: + logger.error("Download failed: %s", exc, exc_info=True) + instance = locals().get("instance") + if instance: + instance.download_status = DownloadStatus.ERROR + instance.save() diff --git a/ia/templates/wagtailadmin/icons/ia-brain.svg b/ia/templates/wagtailadmin/icons/ia-brain.svg new file mode 100644 index 0000000..bacd445 --- /dev/null +++ b/ia/templates/wagtailadmin/icons/ia-brain.svg @@ -0,0 +1,11 @@ + + + + + + + + + + + diff --git a/ia/tests/__init__.py b/ia/tests/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/ia/tests/__init__.py @@ -0,0 +1 @@ + diff --git a/ia/tests/test_db_router.py b/ia/tests/test_db_router.py new file mode 100644 index 0000000..3f39049 --- /dev/null +++ b/ia/tests/test_db_router.py @@ -0,0 +1,59 @@ +from ia.db_router import IADatabaseRouter + + +class MetaStub: + def __init__(self, app_label): + self.app_label = app_label + + +class ModelStub: + def __init__(self, app_label): + self._meta = MetaStub(app_label) + + +def test_router_uses_default_when_ia_db_not_configured(settings): + settings.DATABASES.clear() + settings.DATABASES["default"] = { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + } + router = IADatabaseRouter() + model = ModelStub("ia") + + assert router.db_for_read(model) is None + assert router.db_for_write(model) is None + assert router.allow_migrate("default", "ia") is True + assert router.allow_migrate("ia_db", "ia") is False + + +def test_router_uses_ia_db_when_configured(settings): + settings.DATABASES.clear() + settings.DATABASES["default"] = { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + } + settings.DATABASES["ia_db"] = { + "ENGINE": "django.db.backends.sqlite3", + "NAME": ":memory:", + } + router = IADatabaseRouter() + model = ModelStub("ia") + other_model = ModelStub("xml_manager") + + assert router.db_for_read(model) == "ia_db" + assert router.db_for_write(model) == "ia_db" + assert router.db_for_read(other_model) is None + assert router.allow_migrate("ia_db", "ia") is True + assert router.allow_migrate("default", "ia") is False + assert router.allow_migrate("default", "xml_manager") is None + + +def test_router_allows_relations_with_ia_models(): + router = IADatabaseRouter() + ia_obj = ModelStub("ia") + other_obj = ModelStub("xml_manager") + no_ia_obj_a = ModelStub("users") + no_ia_obj_b = ModelStub("core") + + assert router.allow_relation(ia_obj, other_obj) is True + assert router.allow_relation(no_ia_obj_a, no_ia_obj_b) is None diff --git a/ia/tests/test_forms.py b/ia/tests/test_forms.py new file mode 100644 index 0000000..39683f8 --- /dev/null +++ b/ia/tests/test_forms.py @@ -0,0 +1,57 @@ +import pytest +from django.contrib.auth import get_user_model + +from ia.forms import IAAdminModelForm +from ia.models import GeminiModel + +pytestmark = pytest.mark.django_db + + +class GeminiAdminModelForm(IAAdminModelForm): + formsets = {} + + class Meta: + model = GeminiModel + fields = ("api_key", "is_active") + + +def test_save_all_sets_creator_on_create(): + user = get_user_model().objects.create_user( + username="creator-user", + email="creator@example.com", + password="secret", + ) + form = GeminiAdminModelForm(data={"api_key": "key-a", "is_active": True}) + + assert form.is_valid() + instance = form.save_all(user) + + instance.refresh_from_db() + assert instance.creator_id == user.id + assert instance.updated_by_id is None + + +def test_save_all_sets_updated_by_on_update(): + User = get_user_model() + creator = User.objects.create_user( + username="first-user", + email="first@example.com", + password="secret", + ) + updater = User.objects.create_user( + username="updater-user", + email="updater@example.com", + password="secret", + ) + instance = GeminiModel.objects.create(api_key="key-a", creator=creator) + form = GeminiAdminModelForm( + data={"api_key": "key-b", "is_active": False}, instance=instance + ) + + assert form.is_valid() + saved = form.save_all(updater) + + saved.refresh_from_db() + assert saved.creator_id == creator.id + assert saved.updated_by_id == updater.id + assert saved.api_key == "key-b" diff --git a/ia/tests/test_models.py b/ia/tests/test_models.py new file mode 100644 index 0000000..692b23a --- /dev/null +++ b/ia/tests/test_models.py @@ -0,0 +1,54 @@ +import pytest + +from ia.models import GeminiModel, HuggingFaceModel, OllamaModel + +pytestmark = pytest.mark.django_db + + +def test_huggingface_active_is_unique(): + first = HuggingFaceModel.objects.create( + name_model="repo/a", + name_file="a.gguf", + is_active=True, + ) + second = HuggingFaceModel.objects.create( + name_model="repo/b", + name_file="b.gguf", + is_active=True, + ) + + first.refresh_from_db() + second.refresh_from_db() + + assert first.is_active is False + assert second.is_active is True + + +def test_ollama_active_is_unique(): + first = OllamaModel.objects.create( + url="http://localhost:11434", + model="llama3.2", + is_active=True, + ) + second = OllamaModel.objects.create( + url="http://localhost:11434", + model="qwen2.5", + is_active=True, + ) + + first.refresh_from_db() + second.refresh_from_db() + + assert first.is_active is False + assert second.is_active is True + + +def test_gemini_active_is_unique(): + first = GeminiModel.objects.create(api_key="key-1", is_active=True) + second = GeminiModel.objects.create(api_key="key-2", is_active=True) + + first.refresh_from_db() + second.refresh_from_db() + + assert first.is_active is False + assert second.is_active is True diff --git a/ia/tests/test_references.py b/ia/tests/test_references.py new file mode 100644 index 0000000..2c20088 --- /dev/null +++ b/ia/tests/test_references.py @@ -0,0 +1,49 @@ +from ia.references import mark_reference, mark_references + + +class ServiceStub: + def __init__(self, *_args, **_kwargs): + pass + + def run(self, _reference_text): + return { + "choices": [ + {"message": {"content": '{"reftype":"journal"}'}}, + ] + } + + +class ServiceFailureStub: + def __init__(self, *_args, **_kwargs): + pass + + def run(self, _reference_text): + raise ValueError("boom") + + +def test_mark_reference_returns_choices(monkeypatch): + monkeypatch.setattr("ia.references.LLMService", ServiceStub) + + result = list(mark_reference("A reference")) + + assert result == ['{"reftype":"journal"}'] + + +def test_mark_reference_returns_error_message_on_unexpected_exception(monkeypatch): + monkeypatch.setattr("ia.references.LLMService", ServiceFailureStub) + + result = list(mark_reference("A reference")) + + assert len(result) == 1 + assert result[0].startswith("An unexpected error occurred: ") + + +def test_mark_references_processes_non_empty_lines(monkeypatch): + monkeypatch.setattr("ia.references.LLMService", ServiceStub) + + result = list(mark_references("Ref A\n\nRef B")) + + assert len(result) == 2 + assert result[0]["references"] == "Ref A" + assert result[1]["references"] == "Ref B" + assert result[0]["choices"] == ['{"reftype":"journal"}'] diff --git a/ia/tests/test_service.py b/ia/tests/test_service.py new file mode 100644 index 0000000..d801e82 --- /dev/null +++ b/ia/tests/test_service.py @@ -0,0 +1,86 @@ +import pytest + +from ia.exceptions import LlamaModelNotFoundError +from ia.models import GeminiModel, HuggingFaceModel, OllamaModel +from ia.service import LLMService + +pytestmark = pytest.mark.django_db + + +class DummyProvider: + def __init__(self, *args, **kwargs): + self.args = args + self.kwargs = kwargs + + def chat(self, messages): + return {"messages": messages, "provider": "chat"} + + def prompt(self, user_input, response_format): + return { + "user_input": user_input, + "response_format": response_format, + "provider": "prompt", + } + + +def test_service_prefers_gemini(monkeypatch): + GeminiModel.objects.create(api_key="gemini-key", is_active=True) + OllamaModel.objects.create( + url="http://localhost:11434", + model="llama3.2", + is_active=True, + ) + HuggingFaceModel.objects.create( + name_model="repo/model", + name_file="model.gguf", + is_active=True, + ) + + monkeypatch.setattr("ia.service.GeminiProvider", DummyProvider) + monkeypatch.setattr("ia.service.OllamaProvider", DummyProvider) + monkeypatch.setattr("ia.service.LocalProvider", DummyProvider) + + service = LLMService(messages=[{"role": "system", "content": "s"}], mode="chat") + response = service.run("hello") + + assert response["provider"] == "chat" + assert response["messages"][-1] == {"role": "user", "content": "hello"} + assert isinstance(service.provider, DummyProvider) + assert isinstance(service.provider.args[0], GeminiModel) + + +def test_service_uses_ollama_when_no_gemini(monkeypatch): + OllamaModel.objects.create( + url="http://localhost:11434", + model="llama3.2", + is_active=True, + ) + + monkeypatch.setattr("ia.service.OllamaProvider", DummyProvider) + + service = LLMService(mode="prompt", response_format={"type": "json_object"}) + response = service.run("hello", response_format={"type": "json_object"}) + + assert response["provider"] == "prompt" + assert response["user_input"] == "hello" + assert isinstance(service.provider.args[0], OllamaModel) + + +def test_service_uses_local_when_only_hf(monkeypatch): + HuggingFaceModel.objects.create( + name_model="repo/model", + name_file="model.gguf", + is_active=True, + ) + + monkeypatch.setattr("ia.service.LocalProvider", DummyProvider) + + service = LLMService(mode="prompt") + service.run("hello") + + assert isinstance(service.provider.args[0], HuggingFaceModel) + + +def test_service_raises_when_no_model(): + with pytest.raises(LlamaModelNotFoundError): + LLMService() diff --git a/ia/tests/test_urls.py b/ia/tests/test_urls.py new file mode 100644 index 0000000..5eeec78 --- /dev/null +++ b/ia/tests/test_urls.py @@ -0,0 +1,102 @@ +import pytest +from django.urls import reverse + +pytestmark = pytest.mark.django_db + + +class ResponseStub: + def __init__(self, payload): + self._payload = payload + + def raise_for_status(self): + return None + + def json(self): + return self._payload + + +def test_ollama_tags_returns_models(client, monkeypatch): + def fake_get(url, timeout): + assert url == "http://localhost:11434/api/tags" + assert timeout == 10 + return ResponseStub({"models": [{"name": "llama3.2"}, {"name": "qwen2.5"}]}) + + monkeypatch.setattr("ia.wagtail_hooks.req.get", fake_get) + + response = client.get( + reverse("ia:ollama_tags"), + {"url": "http://localhost:11434"}, + ) + + assert response.status_code == 200 + assert response.json() == {"models": ["llama3.2", "qwen2.5"]} + + +def test_ollama_tags_requires_url(client): + response = client.get(reverse("ia:ollama_tags")) + + assert response.status_code == 400 + assert response.json() == {"error": "URL is required"} + + +def test_ollama_model_info_returns_context(client, monkeypatch): + payload = { + "details": {"parameter_size": "8B"}, + "model_info": {"llama.context_length": 8192}, + } + + def fake_post(url, json, timeout): + assert url == "http://localhost:11434/api/show" + assert json == {"name": "llama3.2"} + assert timeout == 15 + return ResponseStub(payload) + + monkeypatch.setattr("ia.wagtail_hooks.req.post", fake_post) + + response = client.get( + reverse("ia:ollama_model_info"), + {"url": "http://localhost:11434", "model": "llama3.2"}, + ) + + assert response.status_code == 200 + assert response.json() == {"context_length": 8192, "parameter_size": "8B"} + + +def test_ollama_model_info_requires_params(client): + response = client.get(reverse("ia:ollama_model_info"), {"url": "http://localhost"}) + + assert response.status_code == 400 + assert response.json() == {"error": "URL and model are required"} + + +def test_ollama_tags_returns_502_on_upstream_error(client, monkeypatch): + def fake_get(url, timeout): + raise RuntimeError("upstream unavailable") + + monkeypatch.setattr("ia.wagtail_hooks.req.get", fake_get) + response = client.get( + reverse("ia:ollama_tags"), + {"url": "http://localhost:11434"}, + ) + + assert response.status_code == 502 + assert "upstream unavailable" in response.json()["error"] + + +def test_ollama_model_info_without_context_length(client, monkeypatch): + payload = { + "details": {"parameter_size": "7B"}, + "model_info": {"other_key": 2048}, + } + + def fake_post(url, json, timeout): + return ResponseStub(payload) + + monkeypatch.setattr("ia.wagtail_hooks.req.post", fake_post) + response = client.get( + reverse("ia:ollama_model_info"), + {"url": "http://localhost:11434", "model": "llama3.2"}, + ) + + assert response.status_code == 200 + assert response.json() == {"context_length": None, "parameter_size": "7B"} diff --git a/ia/urls.py b/ia/urls.py new file mode 100644 index 0000000..0571b4d --- /dev/null +++ b/ia/urls.py @@ -0,0 +1,10 @@ +from django.urls import path + +from ia.wagtail_hooks import ollama_model_info, ollama_tags + +app_name = "ia" + +urlpatterns = [ + path("ollama-tags/", ollama_tags, name="ollama_tags"), + path("ollama-model-info/", ollama_model_info, name="ollama_model_info"), +] diff --git a/ia/utils/__init__.py b/ia/utils/__init__.py new file mode 100644 index 0000000..8b13789 --- /dev/null +++ b/ia/utils/__init__.py @@ -0,0 +1 @@ + diff --git a/ia/utils/blocks.py b/ia/utils/blocks.py new file mode 100644 index 0000000..567454b --- /dev/null +++ b/ia/utils/blocks.py @@ -0,0 +1,63 @@ +from ia.utils.normalizers import stz_language + + +def plain_paragraph_text(value): + return str(value or "") + + +def block_parts(item): + if hasattr(item, "block_type") and hasattr(item, "value"): + return item.block_type, dict(item.value) + if isinstance(item, dict): + return item.get("type", ""), item.get("value") or {} + return "", {} + + +def block_text(value): + text = ( + value.get("paragraph") + or value.get("text_aff") + or value.get("original") + or value.get("title") + or "" + ) + return plain_paragraph_text(str(text)).strip() + + +def make_block(label, text): + return {"type": "paragraph", "value": {"label": label, "paragraph": text or ""}} + + +def make_lang_block(label, text, language): + return { + "type": "paragraph_with_language", + "value": { + "label": label, + "language": stz_language(language, "en"), + "paragraph": text or "", + }, + } + + +def source_blocks(front, body, limit=60): + blocks = [] + body_items = body[:25] if isinstance(body, list) else body + for section, items in (("front", front), ("body", body_items)): + for index, item in enumerate(items or []): + block_type, value = block_parts(item) + text = block_text(value) + if not text: + continue + max_text = 6000 if limit >= 60 else 2000 + blocks.append( + { + "section": section, + "index": index, + "type": block_type, + "label": value.get("label", ""), + "text": text[:max_text], + } + ) + if len(blocks) >= limit: + return blocks + return blocks diff --git a/ia/utils/json.py b/ia/utils/json.py new file mode 100644 index 0000000..1fc2db8 --- /dev/null +++ b/ia/utils/json.py @@ -0,0 +1,80 @@ +import json +import re + + +def fix_unicode_escapes(text): + result = [] + index = 0 + while index < len(text): + if text[index : index + 2] == "\\u": + after = text[index + 2 : index + 6] + if len(after) == 4 and all(c in "0123456789abcdefABCDEF" for c in after): + result.append(text[index : index + 6]) + index += 6 + continue + result.append("\\\\u") + index += 2 + continue + result.append(text[index]) + index += 1 + return "".join(result) + + +def extract_balanced_json(text): + start = text.find("{") + while start != -1: + depth = 0 + in_string = False + escape = False + for index in range(start, len(text)): + char = text[index] + if escape: + escape = False + continue + if char == "\\": + escape = True + continue + if char == '"' and not escape: + in_string = not in_string + continue + if in_string: + continue + if char == "{": + depth += 1 + elif char == "}": + depth -= 1 + if depth == 0: + candidate = text[start : index + 1] + try: + return json.loads(candidate) + except json.JSONDecodeError: + break + start = text.find("{", start + 1) + return None + + +def try_parse_json(text): + text = str(text or "").strip() + if not text: + raise ValueError("empty_response") + + text = re.sub(r"^```(?:json)?\s*\n?", "", text) + text = re.sub(r"\n?\s*```$", "", text) + + try: + return json.loads(text) + except json.JSONDecodeError: + pass + + sanitized = fix_unicode_escapes(text) + try: + return json.loads(sanitized) + except json.JSONDecodeError: + pass + + for source in (text, sanitized): + result = extract_balanced_json(source) + if result is not None: + return result + + raise ValueError("invalid_response") diff --git a/ia/utils/normalizers.py b/ia/utils/normalizers.py new file mode 100644 index 0000000..d688ad4 --- /dev/null +++ b/ia/utils/normalizers.py @@ -0,0 +1,51 @@ +import re + +DOI_RE = re.compile(r"\b10\.\d{4,}(?:\.\d+)*\/[^\s\"]+", re.I) +ORCID_RE = re.compile(r"\d{4}-\d{4}-\d{4}-\d{3}[X\d]") + + +def stz_text(value): + return str(value or "").strip() + + +def stz_norm(value): + return re.sub(r"\s+", " ", str(value or "").strip().lower()) + + +def stz_language(code, fallback="en"): + code = (code or "").strip()[:2].lower() + return code if len(code) == 2 and code.isalpha() else fallback + + +def stz_country_code(value): + code = (value or "").strip().upper()[:2] + return code if len(code) == 2 and code.isalpha() else "" + + +def stz_affiliation_id(value): + if isinstance(value, (list, tuple)): + return [str(item) for item in value if item] + if value: + return [str(value)] + return [] + + +def stz_first_number(value): + match = re.search(r"\d+", str(value or "")) + return match.group(0) if match else "" + + +def stz_date(value): + text = stz_text(value) + if not text: + return "" + match = re.match(r"(\d{4})(?:-(\d{2})(?:-(\d{2}))?)?", text) + if not match: + return "" + year, month, day = match.group(1), match.group(2), match.group(3) + if month: + month = str(min(max(int(month), 1), 12)).zfill(2) + if day: + day = str(min(max(int(day), 1), 31)).zfill(2) + parts = [year, month, day] if month else [year] + return "-".join(part for part in parts if part) diff --git a/ia/utils/text.py b/ia/utils/text.py new file mode 100644 index 0000000..d30523a --- /dev/null +++ b/ia/utils/text.py @@ -0,0 +1,41 @@ +import logging +import os +import zipfile + +from lxml import etree + +logger = logging.getLogger(__name__) + + +def extract_text_from_docx(docx_path, limit_chars=30000): + logger.info("Text extractor: using zipfile for %s", os.path.basename(docx_path)) + with zipfile.ZipFile(docx_path) as archive: + xml_bytes = archive.read("word/document.xml") + + root = etree.fromstring(xml_bytes) + nsmap = {"w": "http://schemas.openxmlformats.org/wordprocessingml/2006/main"} + + paragraphs = [] + for paragraph in root.xpath("//w:p | //w:tc//w:p", namespaces=nsmap): + text = "".join( + t.text or "" for t in paragraph.xpath(".//w:t", namespaces=nsmap) + ) + text = text.strip() + if text: + paragraphs.append(text) + + result = "\n".join(paragraphs)[:limit_chars] + logger.info("Text extractor: zipfile produced %d chars", len(result)) + return result + + +def extract_text_from_docx_via_docling(docx_path, limit_chars=30000): + from docling.document_converter import DocumentConverter + + logger.info("Text extractor: using docling for %s", os.path.basename(docx_path)) + converter = DocumentConverter() + result = converter.convert(docx_path) + text = result.document.export_to_text() + result_text = text[:limit_chars] + logger.info("Text extractor: docling produced %d chars", len(result_text)) + return result_text diff --git a/ia/utils/vision.py b/ia/utils/vision.py new file mode 100644 index 0000000..2cd20d9 --- /dev/null +++ b/ia/utils/vision.py @@ -0,0 +1,91 @@ +import base64 +import logging +import os +import subprocess +import tempfile + +logger = logging.getLogger(__name__) + + +def _convert_to_pdf(docx_path, pdf_path): + try: + subprocess.run( + [ + "soffice", + "--headless", + "--convert-to", + "pdf", + "--outdir", + os.path.dirname(pdf_path), + docx_path, + ], + capture_output=True, + timeout=60, + ) + generated = os.path.join( + os.path.dirname(pdf_path), + os.path.splitext(os.path.basename(docx_path))[0] + ".pdf", + ) + if os.path.isfile(generated) and generated != pdf_path: + os.rename(generated, pdf_path) + except Exception as exc: + logger.warning("Failed to convert DOCX to PDF: %s", exc) + + +def _pdf_page_count(pdf_path): + try: + result = subprocess.run( + ["pdfinfo", pdf_path], capture_output=True, text=True, timeout=15 + ) + for line in result.stdout.splitlines(): + if line.startswith("Pages:"): + return int(line.split(":")[1].strip()) + except Exception: + pass + return 0 + + +def _pdf_page_to_png(pdf_path, page_num, output_path): + try: + subprocess.run( + [ + "pdftoppm", + "-f", + str(page_num), + "-l", + str(page_num), + "-r", + "120", + "-png", + "-singlefile", + pdf_path, + os.path.splitext(output_path)[0], + ], + capture_output=True, + timeout=30, + ) + except Exception as exc: + logger.warning("Failed to convert PDF page %d to PNG: %s", page_num, exc) + + +def extract_page_images(docx_path, max_pages=3): + with tempfile.TemporaryDirectory() as tmpdir: + pdf_path = os.path.join(tmpdir, "document.pdf") + _convert_to_pdf(docx_path, pdf_path) + if not os.path.isfile(pdf_path): + return [] + + total_pages = _pdf_page_count(pdf_path) + num_pages = min(max_pages, total_pages) + if num_pages == 0: + return [] + + images = [] + for page_num in range(1, num_pages + 1): + png_path = os.path.join(tmpdir, f"page-{page_num}.png") + _pdf_page_to_png(pdf_path, page_num, png_path) + if os.path.isfile(png_path): + with open(png_path, "rb") as file: + b64 = base64.b64encode(file.read()).decode() + images.append(b64) + return images diff --git a/ia/wagtail_hooks.py b/ia/wagtail_hooks.py new file mode 100644 index 0000000..d7c2cef --- /dev/null +++ b/ia/wagtail_hooks.py @@ -0,0 +1,263 @@ +import logging + +import requests as req +from django.contrib import messages +from django.http import HttpResponseRedirect, JsonResponse +from django.utils.translation import gettext_lazy as _ +from wagtail import hooks +from wagtail.snippets.models import register_snippet +from wagtail.snippets.views.snippets import ( + CreateView, + EditView, + SnippetViewSet, + SnippetViewSetGroup, +) + +from config.menu import get_menu_order +from ia.models import DownloadStatus, GeminiModel, HuggingFaceModel, OllamaModel + +logger = logging.getLogger(__name__) + + +def ollama_tags(request): + url = request.GET.get("url", "").strip().rstrip("/") + if not url: + return JsonResponse({"error": "URL is required"}, status=400) + try: + logger.info("Fetching Ollama tags from %s/api/tags", url) + resp = req.get(f"{url}/api/tags", timeout=10) + resp.raise_for_status() + data = resp.json() + models = data.get("models", []) + tags = [item["name"] for item in models] + logger.info("Ollama tags: %d found", len(tags)) + return JsonResponse({"models": tags}) + except Exception as exc: + logger.error("Ollama tags error: %s", exc) + return JsonResponse({"error": str(exc)}, status=502) + + +def ollama_model_info(request): + url = request.GET.get("url", "").strip().rstrip("/") + model = request.GET.get("model", "").strip() + if not url or not model: + return JsonResponse({"error": "URL and model are required"}, status=400) + try: + logger.info("Fetching Ollama model info for %s from %s", model, url) + resp = req.post(f"{url}/api/show", json={"name": model}, timeout=15) + resp.raise_for_status() + data = resp.json() + info = { + "context_length": None, + "parameter_size": data.get("details", {}).get("parameter_size"), + } + model_info = data.get("model_info", {}) + for key in model_info: + if "context_length" in key or "num_ctx" in key: + info["context_length"] = model_info[key] + break + logger.info("Ollama model info: %s", info) + return JsonResponse(info) + except Exception as exc: + logger.error("Ollama model info error: %s", exc) + return JsonResponse({"error": str(exc)}, status=502) + + +class HFModelCreateView(CreateView): + def form_invalid(self, form): + self.produced_error_message = True + return super().form_invalid(form) + + def form_valid(self, form): + self.object = form.save_all(self.request.user) + if self.object.hf_token: + self.object.download_status = DownloadStatus.DOWNLOADING + self.object.save() + from ia.tasks import download_model + + download_model.delay(self.object.pk) + messages.success(self.request, _("Model created, download started.")) + else: + messages.success(self.request, _("Model created. Add a token to download.")) + return HttpResponseRedirect(self.get_success_url()) + + +class HFModelEditView(EditView): + def form_invalid(self, form): + self.produced_error_message = True + return super().form_invalid(form) + + def form_valid(self, form): + form.instance.updated_by = self.request.user + form.instance.save() + if ( + form.instance.hf_token + and form.instance.download_status != DownloadStatus.DOWNLOADING + ): + form.instance.download_status = DownloadStatus.DOWNLOADING + form.instance.save() + from ia.tasks import download_model + + download_model.delay(form.instance.pk) + messages.success(self.request, _("Download started.")) + else: + messages.success(self.request, _("Model updated.")) + return HttpResponseRedirect(self.get_success_url()) + + +class HuggingFaceViewSet(SnippetViewSet): + model = HuggingFaceModel + add_view_class = HFModelCreateView + edit_view_class = HFModelEditView + menu_label = _("HuggingFace") + menu_icon = "download" + list_display = ("__str__", "get_download_status_display", "is_active") + search_fields = ("name_model", "name_file") + + +class OllamaModelCreateView(CreateView): + def form_invalid(self, form): + self.produced_error_message = True + return super().form_invalid(form) + + def form_valid(self, form): + self.object = form.save_all(self.request.user) + messages.success(self.request, _("Ollama model created.")) + return HttpResponseRedirect(self.get_success_url()) + + +class OllamaModelEditView(EditView): + def form_invalid(self, form): + self.produced_error_message = True + return super().form_invalid(form) + + def form_valid(self, form): + form.instance.updated_by = self.request.user + form.instance.save() + messages.success(self.request, _("Ollama model updated.")) + return HttpResponseRedirect(self.get_success_url()) + + +class OllamaViewSet(SnippetViewSet): + model = OllamaModel + add_view_class = OllamaModelCreateView + edit_view_class = OllamaModelEditView + menu_label = _("Ollama") + menu_icon = "link-external" + list_display = ("__str__", "url", "is_active") + + +class GeminiCreateView(CreateView): + def form_invalid(self, form): + self.produced_error_message = True + return super().form_invalid(form) + + def form_valid(self, form): + self.object = form.save_all(self.request.user) + messages.success(self.request, _("Gemini model created.")) + return HttpResponseRedirect(self.get_success_url()) + + +class GeminiEditView(EditView): + def form_invalid(self, form): + self.produced_error_message = True + return super().form_invalid(form) + + def form_valid(self, form): + form.instance.updated_by = self.request.user + form.instance.save() + messages.success(self.request, _("Gemini model updated.")) + return HttpResponseRedirect(self.get_success_url()) + + +class GeminiViewSet(SnippetViewSet): + model = GeminiModel + add_view_class = GeminiCreateView + edit_view_class = GeminiEditView + menu_label = _("Gemini") + menu_icon = "key" + list_display = ("__str__", "is_active") + + +class IAModelGroup(SnippetViewSetGroup): + menu_name = "ia" + menu_label = _("IA Models") + menu_icon = "ia-brain" + menu_order = get_menu_order("ia") + add_to_admin_menu = True + add_to_settings_menu = True + items = (HuggingFaceViewSet, OllamaViewSet, GeminiViewSet) + + +register_snippet(IAModelGroup) + + +@hooks.register("insert_global_admin_js") +def ia_model_admin_js(): + return """""" + + +@hooks.register("register_icons") +def register_ia_icons(icons): + return icons + ["wagtailadmin/icons/ia-brain.svg"] diff --git "a/instru\303\247\303\265es.txt" "b/instru\303\247\303\265es.txt" new file mode 100644 index 0000000..42bbe6e --- /dev/null +++ "b/instru\303\247\303\265es.txt" @@ -0,0 +1,34 @@ +Preciso criar uma sequencia de issue no github como tarefa. + + +IA: + + Crie todas os nomes das aplicações no singular. + + + Todas as aplicações mencionadas aqui estão em ../scielo-tools-initial + + Preciso trazer a aplicação de ia para esse projeto. + + * preciso garanti que seja uma aplicação django auto suficiente. + * que tenha banco de dados próprio (migrations própria). + * que tenha suas próprias telas no wagtail. + * que tenha suas próprias APIs. + * que possa ser reutilizada em outro projeto. + + Faça somente o isso dessa atividade por agora + + + +refenreces: + + Dentro de manuscripts existe um fluxo/código que realizar a marcação das referencias. + + Preciso ter a marcação das referencias de forma isolada e reutilizada em outras aplicações + + Preciso ter uma API REST que possa utilizar dessa marcação. + + Suspeito fortemente que estamos utilizando diversos modelos de IA para realizar essa marcação. + + + Faça somente o isso dessa atividade por agora diff --git a/pytest.ini b/pytest.ini index b8c192d..d19bf9d 100644 --- a/pytest.ini +++ b/pytest.ini @@ -3,3 +3,6 @@ DJANGO_SETTINGS_MODULE = config.settings.test python_files = tests.py test_*.py *_tests.py pythonpath = . addopts = --reuse-db +filterwarnings = + ignore:pkg_resources is deprecated as an API:UserWarning:packtools.catalogs + ignore:Deprecated call to `pkg_resources.declare_namespace\('google'\)`:DeprecationWarning:pkg_resources diff --git a/requirements/base.txt b/requirements/base.txt index 6977f3c..e644f05 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,4 +1,4 @@ -setuptools>=68.2.2,<82 +setuptools>=68.2.2,<81 whitenoise==6.12.0 # https://github.com/evansd/whitenoise redis==7.4.0 # https://github.com/redis/redis-py celery==5.3.6 # pyup: < 6.0 # https://github.com/celery/celery @@ -24,6 +24,8 @@ git+https://git@github.com/scieloorg/packtools@4.12.6#egg=packtools tenacity==8.2.3 langdetect~=1.0.9 requests>=2.31.0 +google-generativeai +huggingface_hub # Kombu # ------------------------------------------------------------------------------ diff --git a/requirements/local.txt b/requirements/local.txt index fed4d03..df6a0b3 100644 --- a/requirements/local.txt +++ b/requirements/local.txt @@ -10,5 +10,4 @@ django-debug-toolbar # https://github.com/jazzband/django-debug-toolbar pytest==9.0.3 pytest-django==4.11.1 pytest-cov==7.1.0 -coverage==7.10.6 -django-coverage-plugin==3.1.0 \ No newline at end of file +coverage==7.10.6 \ No newline at end of file diff --git a/setup.cfg b/setup.cfg index e38b3c6..c56cd12 100644 --- a/setup.cfg +++ b/setup.cfg @@ -34,21 +34,24 @@ django_settings_module = config.settings.test ignore_errors = True [coverage:run] -include = - core/* - core_settings/* - users/* +source = + core + core_settings + users + xml_manager + ia + config omit = *migrations* *tests* */tests/* */templates/* -plugins = - django_coverage_plugin [coverage:report] -fail_under = 100 include = core/* core_settings/* users/* + xml_manager/* + ia/* + config/*