diff --git a/.gitignore b/.gitignore index 1fc8657..6c4c62f 100644 --- a/.gitignore +++ b/.gitignore @@ -19,7 +19,6 @@ __pycache__ /config.yml #generate during install step compose.db -compose.yml /learn2rag/pipeline/data/loaded_documents_kcenter.json /.env /.local.envrc diff --git a/learn2rag/compose/__init__.py b/learn2rag/compose/__init__.py index b735789..688c969 100644 --- a/learn2rag/compose/__init__.py +++ b/learn2rag/compose/__init__.py @@ -10,6 +10,7 @@ import urllib.request from typing import Any, Optional +import jinja2 import psutil import yaml @@ -109,9 +110,18 @@ class Project(): content: dict[str, Any] @staticmethod - def create(project_file: str | Path, name: str) -> 'Project | None': - with open(project_file) as f: - content = yaml.safe_load(f) + def create( + compose_file: str | Path, + name: str, + *, + template: bool = False, + template_context: dict[str, Any] = {}, + ) -> 'Project | None': + if template: + content = yaml.safe_load(jinja2.Template(Path(compose_file).read_text()).render(template_context)) + else: + with open(compose_file) as f: + content = yaml.safe_load(f) assert len(content['services']) > 0 cur = con.cursor() cur.execute('BEGIN EXCLUSIVE') diff --git a/learn2rag/pipeline/llm.py b/learn2rag/pipeline/llm.py index 18cc061..0060049 100644 --- a/learn2rag/pipeline/llm.py +++ b/learn2rag/pipeline/llm.py @@ -2,22 +2,33 @@ import os from pydantic import SecretStr from langchain_core.language_models.chat_models import BaseChatModel +from langchain_core.messages import AIMessage, BaseMessage, SystemMessage +from langchain_core.outputs import ChatGeneration, ChatResult from langchain_ollama import ChatOllama from langchain_openai import ChatOpenAI +from typing import Any, ClassVar logger = logging.getLogger(__name__) class LLMClient(): - # ID is used as a key to store in user data, should not be changed ID: str - # LABEL is a display label for user interface - LABEL: str + '''A key stored in user data, must not be changed''' + + LABEL: str | None + ''' + A display label for the interface. + If None, the option would be excluded from the interface. + ''' + chat_model: BaseChatModel llms = {} +'''A dict holding supported LLM client classes''' + + def llm_client(cls: type[LLMClient]) -> type[LLMClient]: llms[cls.ID] = cls; return cls @@ -25,6 +36,7 @@ def llm_client(cls: type[LLMClient]) -> type[LLMClient]: # First @llm_client would be the default in UI when adding an external model @llm_client class OpenAIClient(LLMClient): + '''A LLM client based on OpenAI API''' ID = 'ChatOpenAI' LABEL = 'OpenAI' @@ -39,6 +51,7 @@ def __init__(self, *, url: str, token: SecretStr, model: str, proxy: str | None) @llm_client class OllamaClient(LLMClient): + '''A LLM client based on Ollama API''' ID = 'ChatOllama' LABEL = 'Ollama' @@ -54,7 +67,44 @@ def __init__(self, *, url: str, token: str | None, model: str, proxy: str | None ) +class TestFakeChatModel(BaseChatModel): + ''' + A mock BaseChatModel implementation. + Responds with the full content of the system prompt. + ''' + hint: ClassVar[str] = 'This is an internal model used for testing only.' + + @property + def _llm_type(self) -> str: return 'test_fake_chat_model' + + def _generate( + self, + messages: list[BaseMessage], + stop: list[str] | None = None, + run_manager: Any = None, + **kwargs: Any + ) -> ChatResult: + assert isinstance(messages[0], SystemMessage) + content = f'{self.hint} {messages[0].content}' + return ChatResult( + generations=[ + ChatGeneration(message=AIMessage(content=content)), + ], + ) + + +@llm_client +class FakeClient(LLMClient): + '''A mock LLM client to use only in tests''' + ID = 'ChatFake' + LABEL = None + + def __init__(self, *, url: str, token: str | None, model: str, proxy: str | None) -> None: + self.chat_model = TestFakeChatModel() + + def chat_model_from_env() -> BaseChatModel: + '''Returns an instance of LLM client based on the environment variables''' default_llm = OpenAIClient llm_id = os.environ.get('LLM_API_TYPE', default_llm.ID) logger.debug('Using LLM: %s', llm_id) diff --git a/learn2rag/pipeline/qdrant.py b/learn2rag/pipeline/qdrant.py index 4939d4a..510507e 100644 --- a/learn2rag/pipeline/qdrant.py +++ b/learn2rag/pipeline/qdrant.py @@ -6,13 +6,16 @@ from .config import user_config +api_key = os.environ.get('QDRANT__SERVICE__API_KEY') +path = os.environ.get('QDRANT_PATH') or None +location = None if path else os.environ.get('QDRANT_LOCATION', 'http://localhost:6336') + class Qdrant: client = QdrantClient( - host="localhost", - port=int(os.environ.get('QDRANT__SERVICE__HTTP_PORT', 6336)), - api_key=os.environ.get('QDRANT__SERVICE__API_KEY'), - https=False, + location=location, + api_key=api_key, + path=path, ) def __init__(self, collection_name: str, opt_config: dict[str, Any]) -> None: diff --git a/learn2rag/tests/__init__.py b/learn2rag/tests/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/learn2rag/tests/data/rabbits.txt b/learn2rag/tests/data/rabbits.txt new file mode 100644 index 0000000..3d814e5 --- /dev/null +++ b/learn2rag/tests/data/rabbits.txt @@ -0,0 +1,7 @@ +Rabbits or bunnies are small mammals in the family Leporidae (which also includes the hares), which is in the order Lagomorpha (which also includes pikas). They are familiar throughout the world as a small herbivore, a prey animal, a domesticated form of livestock, and a pet, having a widespread effect on ecologies and cultures. The most widespread rabbit genera are Oryctolagus and Sylvilagus. The former, Oryctolagus, includes the European rabbit, Oryctolagus cuniculus, which is the ancestor of the hundreds of breeds of domestic rabbit and has been introduced on every continent except Antarctica. The latter, Sylvilagus, includes over 13 wild rabbit species, among them the cottontails and tapetis. Wild rabbits not included in Oryctolagus and Sylvilagus include several species of limited distribution, including the pygmy rabbit, volcano rabbit, and Sumatran striped rabbit. + +Rabbits are a paraphyletic grouping, and do not constitute a clade, as hares (belonging to the genus Lepus) are nested within the Leporidae clade and are not described as rabbits. Although once considered rodents, lagomorphs diverged earlier and have a number of traits rodents lack, including two extra incisors. Similarities between rabbits and rodents were once attributed to convergent evolution, but studies in molecular biology have found a common ancestor between lagomorphs and rodents and place them in the clade Glires. + +Rabbit physiology is suited to escaping predators and surviving in various habitats, living either alone or in groups in nests or burrows. As prey animals, rabbits are constantly aware of their surroundings, having a wide field of vision and ears with high surface area to detect potential predators. The ears of a rabbit are essential for thermoregulation and contain a high density of blood vessels. The bone structure of a rabbit's hind legs, which is longer than that of the fore legs, allows for quick hopping, which is beneficial for escaping predators and can provide powerful kicks if captured. Rabbits are typically nocturnal and often sleep with their eyes open. They reproduce quickly, having short pregnancies, large litters of four to twelve kits, and no particular mating season; however, the mortality rate of rabbit embryos is high, and there exist several widespread diseases that affect rabbits, such as rabbit hemorrhagic disease and myxomatosis. In some regions, especially Australia, rabbits have caused ecological problems and are regarded as a pest. + +Humans have used rabbits as livestock since at least the first century BC in ancient Rome, raising them for their meat, fur and wool. The various breeds of the European rabbit have been developed to suit each of these products; the practice of raising and breeding rabbits as livestock is known as cuniculture. Rabbits are seen in human culture globally, appearing as a symbol of fertility, cunning, and innocence in major religions, historical and contemporary art. diff --git a/learn2rag/tests/test_learn2rag.py b/learn2rag/tests/test_learn2rag.py new file mode 100644 index 0000000..095de51 --- /dev/null +++ b/learn2rag/tests/test_learn2rag.py @@ -0,0 +1,102 @@ +import shutil +from pathlib import Path +from unittest import TestCase +from typing import Any + +from ..compose import Project +from ..utils import is_windows, save_data_path, waitUntil + +from openai import APIConnectionError, OpenAI + +template_dir = Path(__file__).resolve().parent.parent / 'ui' / 'templates' / 'compose' / 'pipelines' +data_dir = Path(__file__).resolve().parent / 'data' + + +class Learn2RAGTestCase(TestCase): + openai_client: Any + project_name: str + rag_port: int + storage_path: Path + + def setUp(self) -> None: + self.project_name = 'test' + self.rag_port = 5002 + self.storage_path = Path(save_data_path('Learn2RAG', 'tests')) + self.storage_path.mkdir(parents=True, exist_ok=True) + self.openai_client = OpenAI( + api_key='', + base_url=f'http://localhost:{self.rag_port}', + max_retries=0, + ) + if project := Project.get(self.project_name): + if project.running: + project.stop() + project.remove() + + def tearDown(self) -> None: + if self.storage_path is not None: + shutil.rmtree(self.storage_path, ignore_errors=True) + if project := Project.get(self.project_name): + if project.running: + project.stop() + project.remove() + + def test_learn2rag(self) -> None: + template_context = { + 'is_windows': is_windows(), + 'learn2rag_path': Path('.').absolute(), + 'storage_path': self.storage_path, + 'ports': { + 'pipeline': self.rag_port, + }, + 'qdrant_api_key': '', + 'language_model': {'api': 'ChatFake'}, + 'pipeline': { + 'qdrant_path': self.storage_path / 'qdrant_persistence', + }, + 'import_config': { + 'loaders': [ + { + 'loader_id': 'local_test', + 'loader_type': 'DirectoryLoader', + 'recursive': 'True', + 'path': str(data_dir), + }, + ], + }, + } + + project = Project.create(template_dir / 'import.yml', self.project_name, template=True, template_context=template_context) + assert project is not None, 'project should not be None' + project.start() + assert project.running + + def check_import() -> None: + project = Project.get(self.project_name) + assert project is not None + assert not project.running + waitUntil(check_import, timeout=1 * 60 * 1000) + + project.remove() + + project = Project.create(template_dir / 'pipeline.yml', self.project_name, template=True, template_context=template_context) + assert project is not None, 'project should not be None' + project.start() + assert project.running + + def check_rag() -> None: + try: + completion = self.openai_client.chat.completions.create( + model='learn2rag', + messages=[ + {'role': 'user', 'content': f'What are rabbits?'}, + ], + ) + content = completion.choices[-1].message.content + assert 'for testing only' in content, 'contains test marker' + assert "Information:\\n" in content, 'contains the prompt' + assert not content.endswith("Information:\\n"), 'contains any document chunks in the prompt' + assert 'Lagomorpha' in content, 'specific text from a test file' + except APIConnectionError: + assert False + waitUntil(check_rag, timeout=1 * 60 * 1000) diff --git a/learn2rag/ui/__init__.py b/learn2rag/ui/__init__.py index bf830a2..b235a00 100644 --- a/learn2rag/ui/__init__.py +++ b/learn2rag/ui/__init__.py @@ -4,13 +4,10 @@ import logging import math import os -import platform -import xdg.BaseDirectory import secrets import shutil import signal import socket -import subprocess import threading import time from typing import Any @@ -20,7 +17,6 @@ from flask import Flask, flash, redirect as flask_redirect, render_template, request, make_response, url_for from flask_babel import Babel, gettext, ngettext, pgettext # type: ignore[import-untyped] import flask.logging -import jinja2 import ollama import uvicorn import yaml @@ -29,6 +25,12 @@ from learn2rag.compose import Project import learn2rag.data import learn2rag.pipeline.llm +from ..utils import ( + is_windows, + normalize_path, + open_web_browser, + save_data_path, +) from datetime import datetime # <-- ADD THIS @@ -37,10 +39,6 @@ logging.getLogger().setLevel(logging.DEBUG) -def expand_path(path: Path) -> Path: - return Path(path).expanduser().absolute() - - import werkzeug def redirect(url: str) -> 'werkzeug.wrappers.response.Response': if 'HX-Boosted' in request.headers: @@ -53,22 +51,19 @@ def redirect(url: str) -> 'werkzeug.wrappers.response.Response': def start_project(name: str, template_file: Path, storage_path: Path, render_context: dict[str, Any]={}) -> Project: logging.debug('UI starting project: %s', name) - storage_path = expand_path(storage_path) logging.debug('Storage path: %s', storage_path) - storage_path.mkdir(parents=True, exist_ok=True) - project_file = storage_path / 'compose.yml' - - template = jinja2.Template(template_file.read_text()) - project_file.write_text(template.render(render_context | { - 'is_windows': platform.system() == 'Windows', - 'learn2rag_path': Path('.').absolute(), - 'storage_path': storage_path, - })) project = None if project := Project.get(name): assert not project.running project.remove() - project = Project.create(project_file, name) + + storage_path.mkdir(parents=True, exist_ok=True) + + project = Project.create(template_file, name, template=True, template_context=render_context | { + 'is_windows': is_windows(), + 'learn2rag_path': Path('.').absolute(), + 'storage_path': storage_path, + }) assert project is not None, 'project should not be None' project.start() return project @@ -126,14 +121,9 @@ def merge(source: dict[str, Any], destination: dict[str, Any]) -> dict[str, Any] def create_app(config: dict[str, Any]={}) -> Flask: # create and configure the app - if platform.system() == 'Windows': - windows_app_data = os.getenv('LOCALAPPDATA') - assert windows_app_data is not None - default_instance_path = windows_app_data + '/Learn2RAG/instance' - else: - default_instance_path = xdg.BaseDirectory.save_data_path('Learn2RAG/instance') + default_instance_path = save_data_path('Learn2RAG', 'instance') - example_local_path = r'C:\Users\User\Documents' if platform.system() == 'Windows' else '/home/user/Documents' + example_local_path = r'C:\Users\User\Documents' if is_windows() else '/home/user/Documents' app = Flask( __name__, instance_path=config.get('flask', {}).get('instance_path', default_instance_path), @@ -220,7 +210,7 @@ def inject_current_year() -> dict[str, Any]: def remove_pipeline_storage_directory(storage_path: Path) -> bool: try: - storage_path = expand_path(storage_path) + storage_path = normalize_path(storage_path) shutil.rmtree(storage_path) flash(pgettext('flash', 'Directory removed: %(path)s', path=storage_path)) except FileNotFoundError: @@ -281,7 +271,7 @@ def model_create() -> 'str | werkzeug.wrappers.response.Response': if request.form.get('ollama') == 'pull': if model.find(':') == -1: model += ':latest' - start_project('ollama_download', components_template_path / 'ollama-download.yml', Path(), {'model': model}) + start_project('ollama_download', components_template_path / 'ollama-download.yml', Path(app.instance_path) / 'ollama_download', {'model': model}) return flask_redirect(url_for('model_pulling', model=model)) elif api == learn2rag.pipeline.llm.OpenAIClient.ID: url = request.form['url'] @@ -423,7 +413,7 @@ def start_pipeline(name: str, pipeline: dict[str, Any], template_name: str) -> N sources = learn2rag.data.get_entries(app.instance_path, 'sources', pipeline['sources']) for path_name, source in sources.items(): if 'path' in source: - source['path'] = str(expand_path(source['path'])) + source['path'] = str(normalize_path(source['path'])) # Fetch the language model configuration first let see if it works language_model = learn2rag.data.get_entry(app.instance_path, 'models', pipeline['language_model']) @@ -475,7 +465,7 @@ def start_pipeline(name: str, pipeline: dict[str, Any], template_name: str) -> N ports = find_free_ports(len(port_names), configured_ports=configured_ports, preferred_ports=app.config.get('PREFERRED_PORTS', range(9001, 9011))) render_context['ports'] = dict(zip(port_names, ports)) - storage_path = Path(pipeline['storage_path']) + storage_path = normalize_path(pipeline['storage_path']) try: project = start_project(name, template_file, storage_path, render_context) @@ -520,7 +510,7 @@ def pipeline_logs(name: str, file: str) -> 'str | werkzeug.wrappers.response.Res if pipeline is None: flash(pgettext('flash', 'The requested pipeline is not found'), 'error') elif file in ['debug.log', 'error.log']: - storage_path = expand_path(pipeline['storage_path']) + storage_path = normalize_path(pipeline['storage_path']) log_file = storage_path / 'logs' / file try: content = log_file.read_text() @@ -578,18 +568,6 @@ def shutdown() -> None: os.kill(os.getpid(), signal.SIGTERM) -def webbrowser_open(url: str) -> None: - try: - if platform.system() == 'Windows': - subprocess.Popen(['explorer', url]) - else: - subprocess.Popen(['xdg-open', url]) - except FileNotFoundError: - pass - except Exception as e: - print(e) - - def main(config: dict[str, Any]) -> None: app = create_app(config=config) @@ -616,7 +594,7 @@ def main(config: dict[str, Any]) -> None: protocol = 'https' if use_https else 'http' url = f"{protocol}://localhost:{port}" - webbrowser_open(url) + open_web_browser(url) logging.info('*' * 40) logging.info('Learn2RAG: ' + url) logging.info('*' * 40) diff --git a/learn2rag/ui/templates/compose/pipelines/continuous.yml b/learn2rag/ui/templates/compose/pipelines/continuous.yml index 8ec4bbb..bd4254f 100644 --- a/learn2rag/ui/templates/compose/pipelines/continuous.yml +++ b/learn2rag/ui/templates/compose/pipelines/continuous.yml @@ -102,6 +102,7 @@ files: content: '' services: + #!!! {% if ports.ui %} open-webui: working_dir: '{{storage_path}}' command: @@ -161,12 +162,15 @@ services: healthcheck: # TODO: We only support ['CMD', 'curl', '-f', ...] test: ['CMD', 'curl', '-f', '{{learn2rag_scheme}}://localhost:{{ports.ui}}/health'] + #!!! {% endif %} + #!!! {% if not pipeline.qdrant_location and not pipeline.qdrant_path %} qdrant: working_dir: '{{storage_path}}' command: - '{{learn2rag_path}}/services/qdrant/qdrant{% if is_windows %}.exe{% endif %}' - '--config-path' - '{{storage_path}}/qdrant_config.yml' + #!!! {% endif %} import: working_dir: '{{storage_path}}' command: @@ -203,7 +207,7 @@ services: environment: LEARN2RAG_PATH: '{{learn2rag_path}}' LEARN2RAG_PIPELINE_PORT: '{{ports.pipeline}}' - QDRANT__SERVICE__HTTP_PORT: '{{ports.qdrant_http}}' + QDRANT_LOCATION: '{{ pipeline.qdrant_location or "http://localhost:" ~ ports.qdrant_http }}' QDRANT__SERVICE__API_KEY: '{{qdrant_api_key}}' PIPELINE_USER_CONFIG: '{{storage_path}}/basic_user_config.json' IMPORTER_CONFIG: '{{storage_path}}/importer_config.json' diff --git a/learn2rag/ui/templates/compose/pipelines/import.yml b/learn2rag/ui/templates/compose/pipelines/import.yml index 29acb44..fe9302b 100644 --- a/learn2rag/ui/templates/compose/pipelines/import.yml +++ b/learn2rag/ui/templates/compose/pipelines/import.yml @@ -99,12 +99,14 @@ files: content: '' services: + #!!! {% if not pipeline.qdrant_location and not pipeline.qdrant_path %} qdrant: working_dir: '{{storage_path}}' command: - '{{learn2rag_path}}/services/qdrant/qdrant{% if is_windows %}.exe{% endif %}' - '--config-path' - '{{storage_path}}/qdrant_config.yml' + #!!! {% endif %} main: working_dir: '{{storage_path}}' command: @@ -118,7 +120,8 @@ services: environment: LEARN2RAG_PATH: '{{learn2rag_path}}' STORAGE_PATH: '{{storage_path}}' - QDRANT__SERVICE__HTTP_PORT: '{{ports.qdrant_http}}' + QDRANT_LOCATION: '{{ pipeline.qdrant_location or "http://localhost:" ~ ports.qdrant_http }}' + QDRANT_PATH: '{{ pipeline.qdrant_path }}' QDRANT__SERVICE__API_KEY: '{{qdrant_api_key}}' PIPELINE_USER_CONFIG: '{{storage_path}}/basic_user_config.json' IMPORTER_CONFIG: '{{storage_path}}/importer_config.json' diff --git a/learn2rag/ui/templates/compose/pipelines/pipeline.yml b/learn2rag/ui/templates/compose/pipelines/pipeline.yml index 6460ead..1c77961 100644 --- a/learn2rag/ui/templates/compose/pipelines/pipeline.yml +++ b/learn2rag/ui/templates/compose/pipelines/pipeline.yml @@ -14,7 +14,7 @@ files: service: api_key: '{{qdrant_api_key}}' grpc_port: null - http_port: '{{ports.qdrant_http}}' + http_port: {{ports.qdrant_http}} host: '127.0.0.1' telemetry_disabled: true - path: '{{storage_path}}/basic_user_config.json' @@ -100,6 +100,7 @@ files: content: '' services: + #!!! {% if ports.ui %} open-webui: working_dir: '{{storage_path}}' command: @@ -159,12 +160,15 @@ services: healthcheck: # TODO: We only support ['CMD', 'curl', '-f', ...] test: ['CMD', 'curl', '-f', '{{learn2rag_scheme}}://localhost:{{ports.ui}}/health'] + #!!! {% endif %} + #!!! {% if not pipeline.qdrant_location and not pipeline.qdrant_path %} qdrant: working_dir: '{{storage_path}}' command: - '{{learn2rag_path}}/services/qdrant/qdrant{% if is_windows %}.exe{% endif %}' - '--config-path' - '{{storage_path}}/qdrant_config.yml' + #!!! {% endif %} main: working_dir: '{{storage_path}}' command: @@ -175,7 +179,8 @@ services: environment: LEARN2RAG_PATH: '{{learn2rag_path}}' LEARN2RAG_PIPELINE_PORT: '{{ports.pipeline}}' - QDRANT__SERVICE__HTTP_PORT: '{{ports.qdrant_http}}' + QDRANT_LOCATION: '{{ pipeline.qdrant_location or "http://localhost:" ~ ports.qdrant_http }}' + QDRANT_PATH: '{{ pipeline.qdrant_path }}' QDRANT__SERVICE__API_KEY: '{{qdrant_api_key}}' PIPELINE_USER_CONFIG: '{{storage_path}}/basic_user_config.json' IMPORTER_CONFIG: '{{storage_path}}/importer_config.json' diff --git a/learn2rag/ui/templates/models_add.html b/learn2rag/ui/templates/models_add.html index f17d093..15fce6e 100644 --- a/learn2rag/ui/templates/models_add.html +++ b/learn2rag/ui/templates/models_add.html @@ -16,12 +16,14 @@
{{ model.api }}
+ {% endif %}
+ {% else %}
+ {{ model.api }}
⚠️
{% endif %}