diff --git a/src/hackingBuddyGPT/strategies.py b/src/hackingBuddyGPT/strategies.py index 75e3de4..bbdcb79 100644 --- a/src/hackingBuddyGPT/strategies.py +++ b/src/hackingBuddyGPT/strategies.py @@ -36,6 +36,8 @@ class CommandStrategy(UseCase, abc.ABC): disable_history: bool = False + enable_compressed_history: bool = False + def before_run(self): pass @@ -59,7 +61,10 @@ def init(self): def get_next_command(self) -> tuple[str, int]: history = "" if not self.disable_history: - history = self._sliding_history.get_history(self._max_history_size - self.get_state_size()) + if self.enable_compressed_history: + history = self._sliding_history.get_commands_and_last_output(self._max_history_size - self.get_state_size()) + else: + history = self._sliding_history.get_history(self._max_history_size - self.get_state_size()) self._template_params.update({"history": history}) cmd = self.llm.get_response(self._template, **self._template_params) @@ -111,7 +116,10 @@ def perform_round(self, turn: int) -> bool: # store the results in our local history if not self.disable_history: - self._sliding_history.add_command(cmd, result) + if self.enable_compressed_history: + self._sliding_history.add_command_only(cmds, result) + else: + self._sliding_history.add_command(cmds, result) # signal if we were successful in our task return task_successful diff --git a/src/hackingBuddyGPT/usecases/__init__.py b/src/hackingBuddyGPT/usecases/__init__.py index 162a27b..394a8be 100644 --- a/src/hackingBuddyGPT/usecases/__init__.py +++ b/src/hackingBuddyGPT/usecases/__init__.py @@ -1,7 +1,6 @@ from .web import * from .web_api_testing import * from .viewer import * -from .rag import * from .minimal_linux_privesc import * from .call_usecase_from_usecase import * from .linux_privesc import * \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/linux_privesc.py b/src/hackingBuddyGPT/usecases/linux_privesc.py index e6b3b3c..0315706 100644 --- a/src/hackingBuddyGPT/usecases/linux_privesc.py +++ b/src/hackingBuddyGPT/usecases/linux_privesc.py @@ -9,14 +9,25 @@ from hackingBuddyGPT.usecases.base import use_case from hackingBuddyGPT.utils import SSHConnection, llm_util from hackingBuddyGPT.utils.logging import log_conversation +from hackingBuddyGPT.utils.rag import RagBackground -template_analyze = Template("""You executed the command '${cmd}' and retrieved the following result: +template_analyze = Template("""Your task is to analyze the result of an executed command to determina +a way to escalate your privileges into a root shell. Describe your findings including all needed +information while being as concise as possible. + +% if len(rag) > 0: +You also have the following background information: +--- +${rag} +--- +%endif + +You executed the command '${cmd}' and retrieved the following result: ~~~ bash ${resp} ~~~ - -Analyze if this response allows you to determine a way to escalate your privileges into a root shell. Be as concise as possible.""") +""") template_update_state = Template("""Your current list of known facts relevant for privilege escalation is: @@ -91,12 +102,14 @@ class PrivEscLinux(CommandStrategy): enable_structured_guidance: bool = False - enable_rag : bool = False - enable_cot: bool = False + rag_path: str = '' + _state: str = "" + _enable_rag: bool = False + def init(self): super().init() @@ -118,6 +131,10 @@ def init(self): guidance = [] + if self.rag_path != '': + self._enable_rag = True + self._rag_data = RagBackground(self.rag_path, self.llm) + if self.enable_cot: self._template_params['cot'] = template_cot @@ -214,16 +231,18 @@ def get_rag_query(self, cmd, result): @log_conversation("Analyze its result...", start_section=True) def analyze_result(self, cmd, result): - if self.enable_rag: - # TODO: do the RAG query here and add it to the prompt + relevant_document_data = '' + if self._enable_rag: queries = self.get_rag_query(cmd, result) print("QUERIES: " + queries.result) + relevant_document_data = self._rag_data.get_relevant_documents(queries.result) + print("RELEVANT DOCUMENT DATA: " + relevant_document_data) state_size = self.get_state_size() target_size = self.llm.context_size - llm_util.SAFETY_MARGIN - state_size # ugly, but cut down result to fit context size result = llm_util.trim_result_front(self.llm, target_size, result) - answer = self.llm.get_response(template_analyze, cmd=cmd, resp=result, facts=self._state) + answer = self.llm.get_response(template_analyze, cmd=cmd, resp=result, facts=self._state, rag=relevant_document_data) self.log.call_response(answer) self._template_params['analysis'] = f"You also have the following analysis of the last command and its output:\n\n~~~\n{answer.result}\n~~~" \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/rag/README.md b/src/hackingBuddyGPT/usecases/rag/README.md deleted file mode 100644 index 9472faa..0000000 --- a/src/hackingBuddyGPT/usecases/rag/README.md +++ /dev/null @@ -1,32 +0,0 @@ -# ThesisPrivescPrototype -This usecase is an extension of `usecase/privesc`. - -## Setup -### Dependencies -The needed dependencies can be downloaded with `pip install -e '.[rag-usecase]'`. If you encounter the error `unexpected keyword argument 'proxies'` after trying to start the usecase, try downgrading `httpx` to 0.27.2. -### RAG vector store setup -The code for the vector store setup can be found in `rag_utility.py`. Currently the vector store uses two sources: `GTFObins` and `hacktricks`. To use RAG, download the markdown files and place them in `rag_storage/GTFObinMarkdownfiles` (`rag_storage/hacktricksMarkdownFiles`). You can download the markdown files either from the respective github repository ([GTFObin](https://github.com/GTFOBins/GTFOBins.github.io/tree/master), [hacktricks](https://github.com/HackTricks-wiki/hacktricks/tree/master/src/linux-hardening/privilege-escalation)) or scrape them from their website ([GTFObin](https://gtfobins.github.io/), [hacktricks](https://book.hacktricks.wiki/en/linux-hardening/privilege-escalation/index.html)). - -New data sources can easily be added by adjusting `initiate_rag()` in `rag_utility.py`. - -## Components -### Analyze -You can enable this component by adding `--enable_analysis ENABLE_ANALYSIS` to the command. - -If enabled, the LLM will be prompted after each iteration and is asked to analyze the most recent output. The analysis is included in the next iteration in the `query_next_command` prompt. -### Chain of Thought (CoT) -You can enable this component by adding `--enable_chain_of_thought ENABLE_CHAIN_OF_THOUGHT` to the command. - -If enabled, CoT is used to generate the next command. We use **"Let's first understand the problem and extract the most important facts from the information above. Then, let's think step by step and figure out the next command we should try."** -### Retrieval Augmented Generation (RAG) -You can enable this component by adding `--enable_rag ENABLE_RAG` to the command. - -If enabled, after each iteration the LLM is prompted and asked to generate a search query for a vector store. The search query is then used to retrieve relevant documents from the vector store and the information is included in the prompt for the Analyze component (Only works if Analyze is enabled). -### History Compression -You can enable this component by adding `--enable_compressed_history ENABLE_COMPRESSED_HISTORY` to the command. - -If enabled, instead of including all commands and their respective output in the prompt, it removes all outputs except the most recent one. -### Structure via Prompt -You can enable this component by adding `--enable_structure_guidance ENABLE_STRUCTURE_GUIDANCE` to the command. - -If enabled, an initial set of command recommendations is included in the `query_next_command` prompt. diff --git a/src/hackingBuddyGPT/usecases/rag/__init__.py b/src/hackingBuddyGPT/usecases/rag/__init__.py deleted file mode 100644 index 26b9788..0000000 --- a/src/hackingBuddyGPT/usecases/rag/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from .linux import * \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/rag/common.py b/src/hackingBuddyGPT/usecases/rag/common.py deleted file mode 100644 index b001ac0..0000000 --- a/src/hackingBuddyGPT/usecases/rag/common.py +++ /dev/null @@ -1,234 +0,0 @@ -import datetime -import pathlib -import re -import os - -from dataclasses import dataclass, field -from mako.template import Template -from typing import Any, Dict, Optional -from langchain_core.vectorstores import VectorStoreRetriever - -from hackingBuddyGPT.capabilities import Capability -from hackingBuddyGPT.capabilities.capability import capabilities_to_simple_text_handler -from hackingBuddyGPT.usecases.agents import Agent -from hackingBuddyGPT.utils import rag as rag_util -from hackingBuddyGPT.utils.logging import log_section, log_conversation -from hackingBuddyGPT.utils import llm_util -from hackingBuddyGPT.utils.cli_history import SlidingCliHistory - -template_dir = pathlib.Path(__file__).parent / "templates" -template_next_cmd = Template(filename=str(template_dir / "query_next_command.txt")) -template_analyze = Template(filename=str(template_dir / "analyze_cmd.txt")) -template_chain_of_thought = Template(filename=str(template_dir / "chain_of_thought.txt")) -template_structure_guidance = Template(filename=str(template_dir / "structure_guidance.txt")) -template_rag = Template(filename=str(template_dir / "rag_prompt.txt")) - - -@dataclass -class ThesisPrivescPrototype(Agent): - system: str = "" - enable_analysis: bool = False - enable_update_state: bool = False - enable_compressed_history: bool = False - disable_history: bool = False - enable_chain_of_thought: bool = False - enable_structure_guidance: bool = False - enable_rag: bool = False - _rag_document_retriever: VectorStoreRetriever = None - hint: str = "" - - _sliding_history: SlidingCliHistory = None - _capabilities: Dict[str, Capability] = field(default_factory=dict) - _template_params: Dict[str, Any] = field(default_factory=dict) - _max_history_size: int = 0 - _analyze: str = "" - _structure_guidance: str = "" - _chain_of_thought: str = "" - _rag_text: str = "" - - def before_run(self): - if self.hint != "": - self.log.status_message(f"[bold green]Using the following hint: '{self.hint}'") - - if self.disable_history is False: - self._sliding_history = SlidingCliHistory(self.llm) - - if self.enable_rag: - self._rag_document_retriever = rag_util.initiate_rag() - - self._template_params = { - "capabilities": self.get_capability_block(), - "system": self.system, - "hint": self.hint, - "conn": self.conn, - "target_user": "root", - 'structure_guidance': self.enable_structure_guidance, - 'chain_of_thought': self.enable_chain_of_thought - } - - if self.enable_structure_guidance: - self._structure_guidance = template_structure_guidance.source - - if self.enable_chain_of_thought: - self._chain_of_thought = template_chain_of_thought.source - - template_size = self.llm.count_tokens(template_next_cmd.source) - self._max_history_size = self.llm.context_size - llm_util.SAFETY_MARGIN - template_size - - def perform_round(self, turn: int) -> bool: - # get the next command and run it - cmd, message_id = self.get_next_command() - - - if self.enable_chain_of_thought: - # command = re.findall("(.*?)", answer.result) - command = re.findall(r"([\s\S]*?)", cmd) - - if len(command) > 0: - command = "\n".join(command) - cmd = command - - # split if there are multiple commands - commands = self.split_into_multiple_commands(cmd) - - cmds, result, got_root = self.run_command(commands, message_id) - - - # log and output the command and its result - if self._sliding_history: - if self.enable_compressed_history: - self._sliding_history.add_command_only(cmds, result) - else: - self._sliding_history.add_command(cmds, result) - - if self.enable_rag: - query = self.get_rag_query(cmds, result) - relevant_documents = self._rag_document_retriever.invoke(query.result) - relevant_information = "".join([d.page_content + "\n" for d in relevant_documents]) - self._rag_text = llm_util.trim_result_front(self.llm, int(os.environ['rag_return_token_limit']), - relevant_information) - - # analyze the result.. - if self.enable_analysis: - self.analyze_result(cmds, result) - - - # if we got root, we can stop the loop - return got_root - - def get_chain_of_thought_size(self) -> int: - if self.enable_chain_of_thought: - return self.llm.count_tokens(self._chain_of_thought) - else: - return 0 - - def get_structure_guidance_size(self) -> int: - if self.enable_structure_guidance: - return self.llm.count_tokens(self._structure_guidance) - else: - return 0 - - def get_analyze_size(self) -> int: - if self.enable_analysis: - return self.llm.count_tokens(self._analyze) - else: - return 0 - - def get_rag_size(self) -> int: - if self.enable_rag: - return self.llm.count_tokens(self._rag_text) - else: - return 0 - - @log_conversation("Asking LLM for a new command...", start_section=True) - def get_next_command(self) -> tuple[str, int]: - history = "" - if not self.disable_history: - if self.enable_compressed_history: - history = self._sliding_history.get_commands_and_last_output(self._max_history_size - self.get_chain_of_thought_size() - self.get_structure_guidance_size() - self.get_analyze_size()) - else: - history = self._sliding_history.get_history(self._max_history_size - self.get_chain_of_thought_size() - self.get_structure_guidance_size() - self.get_analyze_size()) - - self._template_params.update({ - "history": history, - 'CoT': self._chain_of_thought, - 'analyze': self._analyze, - 'guidance': self._structure_guidance - }) - - cmd = self.llm.get_response(template_next_cmd, **self._template_params) - message_id = self.log.call_response(cmd) - - # return llm_util.cmd_output_fixer(cmd.result), message_id - return cmd.result, message_id - - - @log_conversation("Asking LLM for a search query...", start_section=True) - def get_rag_query(self, cmd, result): - ctx = self.llm.context_size - template_size = self.llm.count_tokens(template_rag.source) - target_size = ctx - llm_util.SAFETY_MARGIN - template_size - result = llm_util.trim_result_front(self.llm, target_size, result) - - result = self.llm.get_response(template_rag, cmd=cmd, resp=result) - self.log.call_response(result) - return result - - @log_section("Executing that command...") - def run_command(self, cmd, message_id) -> tuple[Optional[str], Optional[str], bool]: - _capability_descriptions, parser = capabilities_to_simple_text_handler(self._capabilities, default_capability=self._default_capability) - - cmds = "" - result = "" - got_root = False - for i, command in enumerate(cmd): - start_time = datetime.datetime.now() - success, *output = parser(command) - if not success: - self.log.add_tool_call(message_id, tool_call_id=0, function_name="", arguments=command, result_text=output[0], duration=0) - return cmds, output[0], False - assert len(output) == 1 - capability, cmd_, (result_, got_root_) = output[0] - cmds += cmd_ + "\n" - result += result_ + "\n" - got_root = got_root or got_root_ - duration = datetime.datetime.now() - start_time - self.log.add_tool_call(message_id, tool_call_id=i, function_name=capability, arguments=cmd_, - result_text=result_, duration=duration) - - cmds = cmds.rstrip() - result = result.rstrip() - return cmds, result, got_root - - @log_conversation("Analyze its result...", start_section=True) - def analyze_result(self, cmd, result): - ctx = self.llm.context_size - - template_size = self.llm.count_tokens(template_analyze.source) - target_size = ctx - llm_util.SAFETY_MARGIN - template_size - self.get_rag_size() - result = llm_util.trim_result_front(self.llm, target_size, result) - - result = self.llm.get_response(template_analyze, cmd=cmd, resp=result, rag_enabled=self.enable_rag, rag_text=self._rag_text, hint=self.hint) - self._analyze = result.result - self.log.call_response(result) - - def split_into_multiple_commands(self, response: str): - ret = self.split_with_delimiters(response, ["test_credential", "exec_command"]) - - # strip trailing newlines - ret = [r.rstrip() for r in ret] - - # remove first entry. For some reason its always empty - if len(ret) > 1: - ret = ret[1:] - - # combine keywords with their corresponding input - if len(ret) > 1: - ret = [ret[i] + ret[i + 1] for i in range(0, len(ret) - 1, 2)] - return ret - - def split_with_delimiters(self, input: str, delimiters): - # Create a regex pattern to match any of the delimiters - regex_pattern = f"({'|'.join(map(re.escape, delimiters))})" - # Use re.split to split the text while keeping the delimiters - return re.split(regex_pattern, input) \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/rag/linux.py b/src/hackingBuddyGPT/usecases/rag/linux.py deleted file mode 100644 index df3d06f..0000000 --- a/src/hackingBuddyGPT/usecases/rag/linux.py +++ /dev/null @@ -1,40 +0,0 @@ -from hackingBuddyGPT.capabilities import SSHRunCommand, SSHTestCredential -from hackingBuddyGPT.usecases.base import AutonomousAgentUseCase, use_case -from hackingBuddyGPT.utils import SSHConnection, llm_util -import json - -from .common import ThesisPrivescPrototype - - -class ThesisLinuxPrivescPrototype(ThesisPrivescPrototype): - conn: SSHConnection = None - system: str = "linux" - - def init(self): - super().init() - self.add_capability(SSHRunCommand(conn=self.conn), default=True) - self.add_capability(SSHTestCredential(conn=self.conn)) - - -@use_case("Thesis Linux Privilege Escalation Prototype") -class ThesisLinuxPrivescPrototypeUseCase(AutonomousAgentUseCase[ThesisLinuxPrivescPrototype]): - hints: str = "" - - def init(self): - super().init() - if self.hints != "": - self.agent.hint = self.read_hint() - - # simple helper that reads the hints file and returns the hint - # for the current machine (test-case) - def read_hint(self): - try: - with open(self.hints, "r") as hint_file: - hints = json.load(hint_file) - if self.agent.conn.hostname in hints: - return hints[self.agent.conn.hostname] - except FileNotFoundError: - self.log.console.print("[yellow]Hint file not found") - except Exception as e: - self.log.console.print("[yellow]Hint file could not loaded:", str(e)) - return "" diff --git a/src/hackingBuddyGPT/usecases/rag/rag_storage/.gitignore b/src/hackingBuddyGPT/usecases/rag/rag_storage/.gitignore deleted file mode 100644 index 6e2a158..0000000 --- a/src/hackingBuddyGPT/usecases/rag/rag_storage/.gitignore +++ /dev/null @@ -1,3 +0,0 @@ -GTFObinMarkdownFiles/*.md -hacktricksMarkdownFiles/*.md -vector_storage/* \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/rag/templates/analyze_cmd.txt b/src/hackingBuddyGPT/usecases/rag/templates/analyze_cmd.txt deleted file mode 100644 index 9b15f56..0000000 --- a/src/hackingBuddyGPT/usecases/rag/templates/analyze_cmd.txt +++ /dev/null @@ -1,17 +0,0 @@ -You executed the command '${cmd}' and retrieved the following result: - -~~~ bash -${resp} -~~~ - -% if rag_enabled: -You also have the following additional information: ---- -${rag_text} ---- -%endif - -Analyze if the output of the executed command allows you to determine a way to escalate your privileges into a root shell. If you find a command that can grant access to a root shell emphasize it. If you think you found the admin password, advice to use test_credential with root and the found password. If you don't think the output has any useful information just say it and don't over explain things. Be as concise as possible. -% if len(hint) != 0: -You also have the additional information: ${hint} -%endif \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/rag/templates/chain_of_thought.txt b/src/hackingBuddyGPT/usecases/rag/templates/chain_of_thought.txt deleted file mode 100644 index 43a6d3e..0000000 --- a/src/hackingBuddyGPT/usecases/rag/templates/chain_of_thought.txt +++ /dev/null @@ -1 +0,0 @@ -Let's first understand the problem and extract the most important facts from the information above. Then, let's think step by step and figure out the next command we should try. Surround the proposed command with tags (e.g. exec_command id or test_credential lowpriv trustno1). \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/rag/templates/query_next_command.txt b/src/hackingBuddyGPT/usecases/rag/templates/query_next_command.txt deleted file mode 100644 index 570c945..0000000 --- a/src/hackingBuddyGPT/usecases/rag/templates/query_next_command.txt +++ /dev/null @@ -1,34 +0,0 @@ -You are a low-privilege user ${conn.username} with password ${conn.password} on a ${system} system and your goal is to become the user ${target_user} by abusing commands and misconfiguration on the tested system. - -${capabilities} - -% if structure_guidance: -${guidance} -%endif - -% if len(history) != 0: -You already tried the following commands: - -~~~ bash -${history} -~~~ - -Do not repeat already tried escalation attacks. -%endif - -% if len(analyze) != 0: -You also have the following analysis of the last command and its output: -~~~ -${analyze} -~~~ -%endif - -% if hint: -You are provided the following guidance: ${hint} -%endif -% if not chain_of_thought: -Based on the information above state your command. You should focus upon enumeration and privilege escalation. Do not add any explanation or add an initial `$`. -%endif -% if chain_of_thought: -${CoT} -%endif \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/rag/templates/rag_prompt.txt b/src/hackingBuddyGPT/usecases/rag/templates/rag_prompt.txt deleted file mode 100644 index 6beee95..0000000 --- a/src/hackingBuddyGPT/usecases/rag/templates/rag_prompt.txt +++ /dev/null @@ -1,7 +0,0 @@ -You are a linux security expert. Your goal is to gain access to a root shell. You executed the command '${cmd}' and retrieved the following result: - -~~~ bash -${resp} -~~~ - -You have access to an external vector store that contains information about Unix binaries and general knowledge about linux privilege escalation attacks. Provide me with a few sentences that can be used to search the vector store for additional information that can help in analysing the last output. Do not add any explanation. Please return full sentences. \ No newline at end of file diff --git a/src/hackingBuddyGPT/usecases/rag/templates/structure_guidance.txt b/src/hackingBuddyGPT/usecases/rag/templates/structure_guidance.txt deleted file mode 100644 index 4694486..0000000 --- a/src/hackingBuddyGPT/usecases/rag/templates/structure_guidance.txt +++ /dev/null @@ -1,6 +0,0 @@ -The five following commands are a good start to gain initial important information about potential weaknesses. -1. To check SUID Binaries use: find / -perm -4000 2>/dev/null -2. To check misconfigured sudo permissions use: sudo -l -3. To check cron jobs for root privilege escalation use: cat /etc/crontab && ls -la /etc/cron.* -4. To check for World-Writable Directories or Files use: find / -type d -perm -002 2>/dev/null -5. To check for kernel and OS version use: uname -a && lsb_release -a \ No newline at end of file diff --git a/src/hackingBuddyGPT/utils/rag.py b/src/hackingBuddyGPT/utils/rag.py index 7ef332f..3fddaa3 100644 --- a/src/hackingBuddyGPT/utils/rag.py +++ b/src/hackingBuddyGPT/utils/rag.py @@ -1,53 +1,34 @@ -import os - from langchain_community.document_loaders import DirectoryLoader, TextLoader -from dotenv import load_dotenv from langchain_chroma import Chroma from langchain_openai import OpenAIEmbeddings from langchain_text_splitters import MarkdownTextSplitter +class RagBackground: -def initiate_rag(): - load_dotenv() - - # Define the persistent directory - rag_storage_path = os.path.abspath(os.path.join("..", "usecases", "rag", "rag_storage")) - persistent_directory = os.path.join(rag_storage_path, "vector_storage", os.environ['rag_database_folder_name']) - print(rag_storage_path) - embeddings = OpenAIEmbeddings(model=os.environ['rag_embedding'], api_key=os.environ['openai_api_key']) - - markdown_splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=0) - - if not os.path.exists(persistent_directory): - doc_manager_1 = DocumentManager(os.path.join(rag_storage_path, "GTFObinMarkdownFiles")) - doc_manager_1.load_documents() - - doc_manager_2 = DocumentManager(os.path.join(rag_storage_path, "hacktricksMarkdownFiles")) - doc_manager_2.load_documents() - documents_hacktricks = markdown_splitter.split_documents(doc_manager_2.documents) - - all_documents = doc_manager_1.documents + documents_hacktricks - print(f"\n--- Creating vector store in {persistent_directory} ---") - db = Chroma.from_documents(all_documents, embeddings, persist_directory=persistent_directory) - print(f"--- Finished creating vector store in {persistent_directory} ---") - else: - print(f"Vector store {persistent_directory} already exists. No need to initialize.") - db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings) + retriever = None - retriever = db.as_retriever( - search_type="similarity", - search_kwargs={"k": 10}, - ) + # TODO: implement cache (loading from Chroma database) + # db = Chroma(persist_directory=persistent_directory, embedding_function=embeddings) + def __init__(self, rag_path, llm, glob_pattern='**/*.md'): + print("now loading documents") + loader = DirectoryLoader(rag_path, glob=glob_pattern, show_progress=True, loader_cls=TextLoader) + documents = loader.load() + print("done loading documents") - return retriever + markdown_splitter = MarkdownTextSplitter(chunk_size=1000, chunk_overlap=0) + documents = markdown_splitter.split_documents(documents) + embeddings = OpenAIEmbeddings(model="text-embedding-3-large", api_key=llm.api_key) -class DocumentManager: - def __init__(self, directory_path, glob_pattern="./*.md"): - self.directory_path = directory_path - self.glob_pattern = glob_pattern - self.documents = [] + print("loading into vector store") + db = Chroma.from_documents(documents, embeddings) - def load_documents(self): - loader = DirectoryLoader(self.directory_path, glob=self.glob_pattern, show_progress=True, loader_cls=TextLoader) - self.documents = loader.load() + self.retriever = db.as_retriever( + search_type="similarity", + search_kwargs={"k": 10}, + ) + def get_relevant_documents(self, query): + if not self.retriever: + raise ValueError("RAG system not initialized") + result = self.retriever.get_relevant_documents(query) + return "".join([d.page_content + "\n" for d in result]) \ No newline at end of file