From 47295befde3a1aa4b8072f3a70646375169f9dbc Mon Sep 17 00:00:00 2001 From: Matteo Cargnelutti Date: Fri, 29 Sep 2023 13:10:18 -0400 Subject: [PATCH] Update chains.py Allows for replacing LangChain's default prompt via the `retriever.custom_prompt` property of the config object. --- chatdocs/chains.py | 12 ++++++++++++ 1 file changed, 12 insertions(+) diff --git a/chatdocs/chains.py b/chatdocs/chains.py index 9af8dd8..dda5747 100644 --- a/chatdocs/chains.py +++ b/chatdocs/chains.py @@ -1,6 +1,7 @@ from typing import Any, Callable, Dict, Optional from langchain.chains import RetrievalQA +from langchain.prompts import PromptTemplate from .llms import get_llm from .vectorstores import get_vectorstore @@ -14,8 +15,19 @@ def get_retrieval_qa( db = get_vectorstore(config) retriever = db.as_retriever(**config["retriever"]) llm = get_llm(config, callback=callback) + chain_type_kwargs = {} + + # Prepare and pass custom prompt if provided + if "retriever" in config and "custom_prompt" in config["retriever"]: + custom_prompt = config["retriever"]["custom_prompt"] + + chain_type_kwargs["prompt"] = PromptTemplate( + template=custom_prompt, input_variables=["context", "question"] + ) + return RetrievalQA.from_chain_type( llm=llm, retriever=retriever, return_source_documents=True, + chain_type_kwargs=chain_type_kwargs )