Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@ __pycache__
/config.yml
#generate during install step
compose.db
compose.yml
/learn2rag/pipeline/data/loaded_documents_kcenter.json
/.env
/.local.envrc
Expand Down
16 changes: 13 additions & 3 deletions learn2rag/compose/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import urllib.request
from typing import Any, Optional

import jinja2
import psutil
import yaml

Expand Down Expand Up @@ -109,9 +110,18 @@ class Project():
content: dict[str, Any]

@staticmethod
def create(project_file: str | Path, name: str) -> 'Project | None':
with open(project_file) as f:
content = yaml.safe_load(f)
def create(
compose_file: str | Path,
name: str,
*,
template: bool = False,
template_context: dict[str, Any] = {},
) -> 'Project | None':
if template:
content = yaml.safe_load(jinja2.Template(Path(compose_file).read_text()).render(template_context))
else:
with open(compose_file) as f:
content = yaml.safe_load(f)
assert len(content['services']) > 0
cur = con.cursor()
cur.execute('BEGIN EXCLUSIVE')
Expand Down
56 changes: 53 additions & 3 deletions learn2rag/pipeline/llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,41 @@
import os
from pydantic import SecretStr
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_ollama import ChatOllama
from langchain_openai import ChatOpenAI
from typing import Any, ClassVar


logger = logging.getLogger(__name__)


class LLMClient():
# ID is used as a key to store in user data, should not be changed
ID: str
# LABEL is a display label for user interface
LABEL: str
'''A key stored in user data, must not be changed'''

LABEL: str | None
'''
A display label for the interface.
If None, the option would be excluded from the interface.
'''

chat_model: BaseChatModel


llms = {}
'''A dict holding supported LLM client classes'''


def llm_client(cls: type[LLMClient]) -> type[LLMClient]:
llms[cls.ID] = cls; return cls


# First @llm_client would be the default in UI when adding an external model
@llm_client
class OpenAIClient(LLMClient):
'''A LLM client based on OpenAI API'''
ID = 'ChatOpenAI'
LABEL = 'OpenAI'

Expand All @@ -39,6 +51,7 @@ def __init__(self, *, url: str, token: SecretStr, model: str, proxy: str | None)

@llm_client
class OllamaClient(LLMClient):
'''A LLM client based on Ollama API'''
ID = 'ChatOllama'
LABEL = 'Ollama'

Expand All @@ -54,7 +67,44 @@ def __init__(self, *, url: str, token: str | None, model: str, proxy: str | None
)


class TestFakeChatModel(BaseChatModel):
'''
A mock BaseChatModel implementation.
Responds with the full content of the system prompt.
'''
hint: ClassVar[str] = 'This is an internal model used for testing only.'

@property
def _llm_type(self) -> str: return 'test_fake_chat_model'

def _generate(
self,
messages: list[BaseMessage],
stop: list[str] | None = None,
run_manager: Any = None,
**kwargs: Any
) -> ChatResult:
assert isinstance(messages[0], SystemMessage)
content = f'{self.hint} {messages[0].content}'
return ChatResult(
generations=[
ChatGeneration(message=AIMessage(content=content)),
],
)


@llm_client
class FakeClient(LLMClient):
'''A mock LLM client to use only in tests'''
ID = 'ChatFake'
LABEL = None

def __init__(self, *, url: str, token: str | None, model: str, proxy: str | None) -> None:
self.chat_model = TestFakeChatModel()


def chat_model_from_env() -> BaseChatModel:
'''Returns an instance of LLM client based on the environment variables'''
default_llm = OpenAIClient
llm_id = os.environ.get('LLM_API_TYPE', default_llm.ID)
logger.debug('Using LLM: %s', llm_id)
Expand Down
11 changes: 7 additions & 4 deletions learn2rag/pipeline/qdrant.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,13 +6,16 @@

from .config import user_config

api_key = os.environ.get('QDRANT__SERVICE__API_KEY')
path = os.environ.get('QDRANT_PATH') or None
location = None if path else os.environ.get('QDRANT_LOCATION', 'http://localhost:6336')


class Qdrant:
client = QdrantClient(
host="localhost",
port=int(os.environ.get('QDRANT__SERVICE__HTTP_PORT', 6336)),
api_key=os.environ.get('QDRANT__SERVICE__API_KEY'),
https=False,
location=location,
api_key=api_key,
path=path,
)

def __init__(self, collection_name: str, opt_config: dict[str, Any]) -> None:
Expand Down
Empty file added learn2rag/tests/__init__.py
Empty file.
7 changes: 7 additions & 0 deletions learn2rag/tests/data/rabbits.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
Rabbits or bunnies are small mammals in the family Leporidae (which also includes the hares), which is in the order Lagomorpha (which also includes pikas). They are familiar throughout the world as a small herbivore, a prey animal, a domesticated form of livestock, and a pet, having a widespread effect on ecologies and cultures. The most widespread rabbit genera are Oryctolagus and Sylvilagus. The former, Oryctolagus, includes the European rabbit, Oryctolagus cuniculus, which is the ancestor of the hundreds of breeds of domestic rabbit and has been introduced on every continent except Antarctica. The latter, Sylvilagus, includes over 13 wild rabbit species, among them the cottontails and tapetis. Wild rabbits not included in Oryctolagus and Sylvilagus include several species of limited distribution, including the pygmy rabbit, volcano rabbit, and Sumatran striped rabbit.

Rabbits are a paraphyletic grouping, and do not constitute a clade, as hares (belonging to the genus Lepus) are nested within the Leporidae clade and are not described as rabbits. Although once considered rodents, lagomorphs diverged earlier and have a number of traits rodents lack, including two extra incisors. Similarities between rabbits and rodents were once attributed to convergent evolution, but studies in molecular biology have found a common ancestor between lagomorphs and rodents and place them in the clade Glires.

Rabbit physiology is suited to escaping predators and surviving in various habitats, living either alone or in groups in nests or burrows. As prey animals, rabbits are constantly aware of their surroundings, having a wide field of vision and ears with high surface area to detect potential predators. The ears of a rabbit are essential for thermoregulation and contain a high density of blood vessels. The bone structure of a rabbit's hind legs, which is longer than that of the fore legs, allows for quick hopping, which is beneficial for escaping predators and can provide powerful kicks if captured. Rabbits are typically nocturnal and often sleep with their eyes open. They reproduce quickly, having short pregnancies, large litters of four to twelve kits, and no particular mating season; however, the mortality rate of rabbit embryos is high, and there exist several widespread diseases that affect rabbits, such as rabbit hemorrhagic disease and myxomatosis. In some regions, especially Australia, rabbits have caused ecological problems and are regarded as a pest.

Humans have used rabbits as livestock since at least the first century BC in ancient Rome, raising them for their meat, fur and wool. The various breeds of the European rabbit have been developed to suit each of these products; the practice of raising and breeding rabbits as livestock is known as cuniculture. Rabbits are seen in human culture globally, appearing as a symbol of fertility, cunning, and innocence in major religions, historical and contemporary art.
102 changes: 102 additions & 0 deletions learn2rag/tests/test_learn2rag.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
import shutil
from pathlib import Path
from unittest import TestCase
from typing import Any

from ..compose import Project
from ..utils import is_windows, save_data_path, waitUntil

from openai import APIConnectionError, OpenAI

template_dir = Path(__file__).resolve().parent.parent / 'ui' / 'templates' / 'compose' / 'pipelines'
data_dir = Path(__file__).resolve().parent / 'data'


class Learn2RAGTestCase(TestCase):
openai_client: Any
project_name: str
rag_port: int
storage_path: Path

def setUp(self) -> None:
self.project_name = 'test'
self.rag_port = 5002
self.storage_path = Path(save_data_path('Learn2RAG', 'tests'))
self.storage_path.mkdir(parents=True, exist_ok=True)
self.openai_client = OpenAI(
api_key='',
base_url=f'http://localhost:{self.rag_port}',
max_retries=0,
)
if project := Project.get(self.project_name):
if project.running:
project.stop()
project.remove()

def tearDown(self) -> None:
if self.storage_path is not None:
shutil.rmtree(self.storage_path, ignore_errors=True)
if project := Project.get(self.project_name):
if project.running:
project.stop()
project.remove()

def test_learn2rag(self) -> None:
template_context = {
'is_windows': is_windows(),
'learn2rag_path': Path('.').absolute(),
'storage_path': self.storage_path,
'ports': {
'pipeline': self.rag_port,
},
'qdrant_api_key': '',
'language_model': {'api': 'ChatFake'},
'pipeline': {
'qdrant_path': self.storage_path / 'qdrant_persistence',
},
'import_config': {
'loaders': [
{
'loader_id': 'local_test',
'loader_type': 'DirectoryLoader',
'recursive': 'True',
'path': str(data_dir),
},
],
},
}

project = Project.create(template_dir / 'import.yml', self.project_name, template=True, template_context=template_context)
assert project is not None, 'project should not be None'
project.start()
assert project.running

def check_import() -> None:
project = Project.get(self.project_name)
assert project is not None
assert not project.running
waitUntil(check_import, timeout=1 * 60 * 1000)

project.remove()

project = Project.create(template_dir / 'pipeline.yml', self.project_name, template=True, template_context=template_context)
assert project is not None, 'project should not be None'
project.start()
assert project.running

def check_rag() -> None:
try:
completion = self.openai_client.chat.completions.create(
model='learn2rag',
messages=[
{'role': 'user', 'content': f'What are rabbits?'},
],
)
content = completion.choices[-1].message.content
assert 'for testing only' in content, 'contains test marker'
assert "Information:\\n" in content, 'contains the prompt'
assert not content.endswith("Information:\\n"), 'contains any document chunks in the prompt'
assert 'Lagomorpha' in content, 'specific text from a test file'
except APIConnectionError:
assert False
waitUntil(check_rag, timeout=1 * 60 * 1000)
Loading
Loading