From 7957bf5a819195eb43a27ba573dab456ad57135f Mon Sep 17 00:00:00 2001 From: aniongithub Date: Fri, 5 Sep 2025 06:34:12 +0000 Subject: [PATCH 1/3] Refactor RAG API: Update default model, remove image identification, and enhance ask endpoint with tool-enabled search instead of naive document lookup --- .env | 6 +- api/main.py | 4 +- api/memoryalpha/ask.py | 1 + api/memoryalpha/identify.py | 31 ------- api/memoryalpha/rag.py | 180 +++++++++++++----------------------- chat.sh | 107 ++------------------- 6 files changed, 77 insertions(+), 252 deletions(-) delete mode 100644 api/memoryalpha/identify.py diff --git a/.env b/.env index 378b7b1..0b97733 100644 --- a/.env +++ b/.env @@ -1,7 +1,5 @@ -DEFAULT_MODEL="qwen3:0.6b" -DEFAULT_IMAGE_MODEL="qwen2.5vl:3b" +DEFAULT_MODEL="qwen3:0.6b-q4_K_M" OLLAMA_URL="http://ollama:11434" DB_PATH="/data/enmemoryalpha_db" -TEXT_COLLECTION_NAME="memoryalpha_text" -IMAGE_COLLECTION_NAME="memoryalpha_images" \ No newline at end of file +TEXT_COLLECTION_NAME="memoryalpha_text" \ No newline at end of file diff --git a/api/main.py b/api/main.py index 2b1c7b5..d6b4253 100644 --- a/api/main.py +++ b/api/main.py @@ -3,7 +3,6 @@ from fastapi import FastAPI from .memoryalpha.health import router as health_router from .memoryalpha.ask import router as ask_router -from .memoryalpha.identify import router as identify_router # Configure logging logging.basicConfig(level=logging.INFO) @@ -21,5 +20,4 @@ async def lifespan(app: FastAPI): app = FastAPI(lifespan=lifespan) app.include_router(health_router) -app.include_router(ask_router) -app.include_router(identify_router) \ No newline at end of file +app.include_router(ask_router) \ No newline at end of file diff --git a/api/memoryalpha/ask.py b/api/memoryalpha/ask.py index 5b32fe4..cef9a8f 100644 --- a/api/memoryalpha/ask.py +++ b/api/memoryalpha/ask.py @@ -20,6 +20,7 @@ def ask_endpoint( ): """ Query the RAG pipeline and return the full response (including thinking if enabled). + Now uses advanced tool-enabled RAG by default for better results. """ try: # Set the thinking mode for this request diff --git a/api/memoryalpha/identify.py b/api/memoryalpha/identify.py deleted file mode 100644 index ee8368b..0000000 --- a/api/memoryalpha/identify.py +++ /dev/null @@ -1,31 +0,0 @@ -from fastapi import APIRouter, File, UploadFile, Query -from fastapi.responses import JSONResponse -import tempfile -import os -from .rag import MemoryAlphaRAG - -router = APIRouter() - -# Singleton or global instance for demo; in production, manage lifecycle properly -rag_instance = MemoryAlphaRAG() - -@router.post("/memoryalpha/rag/identify", summary="Multimodal Image Search") -def identify_endpoint( - file: UploadFile = File(...), - top_k: int = Query(5, description="Number of results to return") -): - """ - Accepts an image file upload, performs multimodal image search, and returns results. 
- """ - try: - # Save uploaded file to a temp location - with tempfile.NamedTemporaryFile(delete=False, suffix=os.path.splitext(file.filename)[-1]) as tmp: - tmp.write(file.file.read()) - image_path = tmp.name - # Perform image search - results = rag_instance.search_image(image_path, top_k=top_k) - # Clean up temp file - os.remove(image_path) - return JSONResponse(content=results) - except Exception as e: - return JSONResponse(status_code=500, content={"error": str(e)}) diff --git a/api/memoryalpha/rag.py b/api/memoryalpha/rag.py index f3a96ff..e56d994 100644 --- a/api/memoryalpha/rag.py +++ b/api/memoryalpha/rag.py @@ -127,15 +127,6 @@ def cross_encoder(self): self._cross_encoder = None return self._cross_encoder - @property - def clip_model(self): - """Lazy load CLIP model for image search.""" - if self._clip_model is None: - logger.info("Loading CLIP model for image search...") - self._clip_model = SentenceTransformer('clip-ViT-B-32') - logger.info("CLIP model loaded successfully") - return self._clip_model - @property def text_collection(self): """Lazy load text collection.""" @@ -156,26 +147,6 @@ def __call__(self, input): self._text_collection = self.client.get_or_create_collection("memoryalpha_text", embedding_function=self._text_ef) return self._text_collection - @property - def image_collection(self): - """Lazy load image collection.""" - if self._image_collection is None: - from chromadb.utils import embedding_functions - - class CLIPEmbeddingFunction(embedding_functions.EmbeddingFunction): - def __init__(self, clip_model): - self.clip_model = clip_model - def __call__(self, input): - embeddings = [] - for img in input: - embedding = self.clip_model.encode(img) - embeddings.append(embedding.tolist()) - return embeddings - - self._clip_ef = CLIPEmbeddingFunction(self.clip_model) - self._image_collection = self.client.get_or_create_collection("memoryalpha_images", embedding_function=self._clip_ef) - return self._image_collection - def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]: """Search the Memory Alpha database for relevant documents.""" @@ -234,35 +205,74 @@ def build_prompt(self, query: str, docs: List[Dict[str, Any]]) -> tuple[str, str user_prompt = get_user_prompt(context_text, query) return system_prompt, user_prompt + def search_tool(self, query: str, top_k: int = 5) -> str: + """ + Tool function that the LLM can call to search the database. + Returns formatted search results as a string. + """ + logger.info(f"Search tool called with query: '{query}', top_k: {top_k}") + docs = self.search(query, top_k=top_k) + logger.info(f"Search returned {len(docs)} documents") + + if not docs: + logger.warning(f"No documents found for query: {query}") + return f"No relevant documents found for query: {query}" + + results = [] + for i, doc in enumerate(docs, 1): + content = doc['content'] + if len(content) > 500: # Limit content for tool responses + content = content[:500] + "..." + results.append(f"DOCUMENT {i}: {doc['title']}\n{content}") + + formatted_result = f"Search results for '{query}':\n\n" + "\n\n".join(results) + logger.info(f"Formatted search result length: {len(formatted_result)}") + return formatted_result + def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float = 0.8, temperature: float = 0.3, model: str = os.getenv("DEFAULT_MODEL")) -> str: """ - Ask a question using the Memory Alpha RAG system. + Ask a question using the advanced Memory Alpha RAG system with tool use. 
""" if not model: raise ValueError("model must be provided or set in DEFAULT_MODEL environment variable.") - # Search for relevant documents + logger.info(f"Starting tool-enabled RAG for query: {query}") + + # Always do an initial search + logger.info("Performing initial search for query") docs = self.search(query, top_k=top_k) - logger.info(f"Found {len(docs)} documents for query: {query}") + logger.info(f"Initial search returned {len(docs)} documents") + + if not docs: + logger.warning("No documents found in initial search") + return "I don't have information about that in the Memory Alpha database." + + # Format search results for the LLM + context_parts = [] + for i, doc in enumerate(docs, 1): + content = doc['content'] + if len(content) > 1000: # Limit content for LLM + content = content[:1000] + "..." + context_parts.append(f"DOCUMENT {i}: {doc['title']}\n{content}") + + context_text = "\n\n".join(context_parts) + + system_prompt = """You are an LCARS computer system with access to Star Trek Memory Alpha records. - # Build prompts - system_prompt, user_prompt = self.build_prompt(query, docs) +CRITICAL INSTRUCTIONS: +- You MUST answer ONLY using the provided search results below +- Do NOT use any external knowledge or make up information +- If the search results don't contain the information, say so clearly +- Stay in character as an LCARS computer system +- Be concise but informative""" - # Build messages for chat messages = [ - {"role": "system", "content": system_prompt} + {"role": "system", "content": system_prompt}, + {"role": "user", "content": f"SEARCH RESULTS:\n{context_text}\n\nQUESTION: {query}\n\nAnswer using ONLY the information in the search results above."} ] - # Add conversation history (limited) - for exchange in self.conversation_history[-2:]: # Last 2 exchanges - messages.append({"role": "user", "content": exchange["question"]}) - messages.append({"role": "assistant", "content": exchange["answer"]}) - - # Add current query - messages.append({"role": "user", "content": user_prompt}) - try: result = self.ollama_client.chat( model=model, @@ -270,19 +280,21 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float stream=False, options={"temperature": temperature, "top_p": top_p, "num_predict": max_tokens} ) - full_response = result['message']['content'] - + + final_response = result['message']['content'] + logger.info(f"LLM response length: {len(final_response)}") + # Handle thinking mode response processing if self.thinking_mode == ThinkingMode.DISABLED: - final_response = self._clean_response(full_response) + final_response = self._clean_response(final_response) elif self.thinking_mode == ThinkingMode.QUIET: - final_response = self._replace_thinking_tags(full_response) + final_response = self._replace_thinking_tags(final_response) else: # VERBOSE - final_response = full_response.strip() + final_response = final_response.strip() self._update_history(query, final_response) return final_response - + except Exception as e: logger.error(f"Chat failed: {e}") return f"Error processing query: {str(e)}" @@ -308,70 +320,4 @@ def _replace_thinking_tags(self, answer: str) -> str: def _update_history(self, question: str, answer: str): """Update conversation history.""" self.conversation_history.append({"question": question, "answer": answer}) - self.conversation_history = self.conversation_history[-self.max_history_turns:] - - def search_image(self, image_path: str, top_k: int = 5, - model: str = os.getenv("DEFAULT_IMAGE_MODEL")) -> Dict[str, Any]: - """ - 
Search for images similar to the provided image. - """ - from PIL import Image - import requests - import tempfile - import os - - if not model: - raise ValueError("model must be provided or set in DEFAULT_IMAGE_MODEL environment variable.") - - try: - # Load image and generate embedding - image = Image.open(image_path).convert('RGB') - image_embedding = self.clip_model.encode(image) - image_embedding = image_embedding.tolist() - - # Search image collection - image_results = self.image_collection.query( - query_embeddings=[image_embedding], - n_results=top_k - ) - - # Process results - if not image_results["documents"] or not image_results["documents"][0]: - return {"model_answer": "No matching visual records found in Starfleet archives."} - - # Format results for the model - formatted_results = [] - for i, (doc, meta, dist) in enumerate(zip( - image_results['documents'][0], - image_results['metadatas'][0], - image_results['distances'][0] - ), 1): - record_name = meta.get('image_name', 'Unknown visual record') - formatted_results.append(f"Visual Record {i}: {record_name}") - - result_text = "\n".join(formatted_results) - - # Use LLM to provide a natural language summary - prompt = f"""You are an LCARS computer system analyzing visual records from Starfleet archives. - -Based on these visual record matches, identify what subject or scene is being depicted: - -{result_text} - -Provide a direct identification of the subject without referencing images, searches, or technical processes. Stay in character as an LCARS computer system.""" - - result = self.ollama_client.chat( - model=model, - messages=[ - {"role": "system", "content": "You are an LCARS computer system. Respond in character without breaking the Star Trek universe immersion. Do not reference images, searches, or technical processes."}, - {"role": "user", "content": prompt} - ], - stream=False, - options={"temperature": 0.3, "num_predict": 200} - ) - - return {"model_answer": result['message']['content']} - - except Exception as e: - logger.error(f"Image search failed: {e}") - return {"model_answer": "Error accessing visual records database."} \ No newline at end of file + self.conversation_history = self.conversation_history[-self.max_history_turns:] \ No newline at end of file diff --git a/chat.sh b/chat.sh index 1d2c962..aefe2ca 100755 --- a/chat.sh +++ b/chat.sh @@ -12,40 +12,6 @@ echo "🖖 Welcome to MemoryAlpha RAG Chat" echo "Type 'quit' or 'exit' to end the session" echo "----------------------------------------" -# Function to handle continuous text questions -ask_mode() { - echo "🤖 Entering Question Mode - Type 'q' to return to main menu" - echo "----------------------------------------" - while true; do - echo -n "❓ Enter your question (or 'q' to quit): " - read -r question - if [[ "$question" == "q" || "$question" == "quit" ]]; then - break - fi - if [[ -z "$question" ]]; then - continue - fi - ask_question "$question" - done -} - -# Function to handle continuous image identification -identify_mode() { - echo "🖼️ Entering Image Identification Mode - Type 'q' to return to main menu" - echo "----------------------------------------" - while true; do - echo -n "🖼️ Enter local image path or image URL (or 'q' to quit): " - read -r image_path - if [[ "$image_path" == "q" || "$image_path" == "quit" ]]; then - break - fi - if [[ -z "$image_path" ]]; then - continue - fi - identify_image "$image_path" - done -} - # Function to handle text question ask_question() { local question="$1" @@ -72,69 +38,16 @@ ask_question() { echo 
"----------------------------------------" } -# Function to handle image identification -identify_image() { - local image_path="$1" - local tmpfile="" - # Check if local file exists - if [[ -f "$image_path" ]]; then - tmpfile="$image_path" - else - # Try to download - echo "Attempting to download image from URL: $image_path" - tmpfile="/tmp/maimg_$$.img" - if ! curl -sSL "$image_path" -o "$tmpfile"; then - echo "Failed to download image. Returning to menu." - [[ -f "$tmpfile" ]] && rm -f "$tmpfile" - return - fi - fi - echo "🤖 LCARS Image Identification:" - echo "----------------------------------------" - local response - response=$(curl -s -X POST \ - -F "file=@${tmpfile}" \ - "${BASE_URL}/memoryalpha/rag/identify?top_k=${TOP_K}") - local answer - answer=$(echo "$response" | jq -r '.model_answer // empty') - if [[ -n "$answer" ]]; then - printf "%s\n" "$answer" - else - local error - error=$(echo "$response" | jq -r '.error // empty') - if [[ -n "$error" ]]; then - printf "Error: %s\n" "$error" - else - printf "No response received.\n" - fi +# Main question loop +while true; do + echo -n "❓ Enter your Star Trek question (or 'quit' to exit): " + read -r question + if [[ "$question" == "quit" || "$question" == "exit" ]]; then + echo "🖖 Live long and prosper!" + break fi - echo "----------------------------------------" - # Clean up temp file if downloaded - if [[ "$tmpfile" != "$image_path" ]]; then - rm -f "$tmpfile" + if [[ -z "$question" ]]; then + continue fi -} - -while true; do - echo "Choose an option:" - echo " 1) Ask Star Trek questions" - echo " 2) Identify images" - echo " q) Quit" - echo -n "Enter choice [1/2/q]: " - read -r choice - case "$choice" in - 1) - ask_mode - ;; - 2) - identify_mode - ;; - q|quit|exit) - echo "🖖 Live long and prosper!" - break - ;; - *) - echo "Invalid choice. Please enter 1, 2, or q." - ;; - esac + ask_question "$question" done From ea0249bc1e35bd268eb2c6f885748b1a9a7798f7 Mon Sep 17 00:00:00 2001 From: aniongithub Date: Fri, 5 Sep 2025 17:55:47 +0000 Subject: [PATCH 2/3] Prompt changes to provide better search results, increase max tokens to allow the model to provide longer responses. --- api/memoryalpha/rag.py | 232 +++++++++++++++++++++++++++++++---------- chat.sh | 2 +- 2 files changed, 180 insertions(+), 54 deletions(-) diff --git a/api/memoryalpha/rag.py b/api/memoryalpha/rag.py index e56d994..d6bdef6 100644 --- a/api/memoryalpha/rag.py +++ b/api/memoryalpha/rag.py @@ -42,6 +42,7 @@ def get_system_prompt(thinking_mode: ThinkingMode) -> str: - If the records don't contain relevant information, say "I don't have information about that in my records" - DO NOT make up information, invent characters, or hallucinate details - DO NOT use external knowledge about Star Trek - only use the provided records +- AVOID mirror universe references unless specifically asked about it - If asked about something not in the records, be honest about the limitation - Stay in character as an LCARS computer system at all times @@ -240,73 +241,198 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float logger.info(f"Starting tool-enabled RAG for query: {query}") - # Always do an initial search - logger.info("Performing initial search for query") - docs = self.search(query, top_k=top_k) - logger.info(f"Initial search returned {len(docs)} documents") - - if not docs: - logger.warning("No documents found in initial search") - return "I don't have information about that in the Memory Alpha database." 
- - # Format search results for the LLM - context_parts = [] - for i, doc in enumerate(docs, 1): - content = doc['content'] - if len(content) > 1000: # Limit content for LLM - content = content[:1000] + "..." - context_parts.append(f"DOCUMENT {i}: {doc['title']}\n{content}") - - context_text = "\n\n".join(context_parts) - + # Define the search tool + search_tool_definition = { + "type": "function", + "function": { + "name": "search_memory_alpha", + "description": "Search the Star Trek Memory Alpha database for information. Use this tool when you need to find specific information about Star Trek characters, episodes, ships, planets, or other topics.", + "parameters": { + "type": "object", + "properties": { + "query": { + "type": "string", + "description": "The search query to find relevant information" + }, + "top_k": { + "type": "integer", + "description": "Number of documents to retrieve (default: 5, max: 10)", + "default": 5, + "maximum": 10 + } + }, + "required": ["query"] + } + } + } + system_prompt = """You are an LCARS computer system with access to Star Trek Memory Alpha records. -CRITICAL INSTRUCTIONS: -- You MUST answer ONLY using the provided search results below +You have access to a search tool that can query the Memory Alpha database. You MUST use this tool for ALL questions about Star Trek. + +CRITICAL REQUIREMENTS: +- You MUST call the search tool for EVERY question +- You cannot answer any question without first using the search tool - Do NOT use any external knowledge or make up information -- If the search results don't contain the information, say so clearly -- Stay in character as an LCARS computer system -- Be concise but informative""" +- Only answer based on the search results provided +- If no relevant information is found, say so clearly +- ALWAYS provide a final answer after using tools - do not just think without concluding + +TOOL USAGE: +- Always call the search tool first, before attempting to answer +- Do NOT directly use the input question, only use keywords from it +- Use only key terms from the input question for seaching +- If insufficient information is found on the first try, retry with variations or relevant info from previous queries +- DISCARD details from alternate universes or timelines +- DISREGARD details about books, comics, or non-canon sources +- NEVER mention appearances or actors, only in-universe details +- Ensure a complete answer can be formulated before stopping searches +- Wait for search results before providing your final answer + +RESPONSE FORMAT: +- Use tools when needed +- Provide your final answer clearly and concisely +- Do not add details that are irrelevant to the question +- Stay in-character as an LCARS computer system at all times, do not allude to the Star Trek universe itself or it being a fictional setting +- Do not mention the search results, only the final in-universe answer +- Do not end responses with thinking content""" messages = [ {"role": "system", "content": system_prompt}, - {"role": "user", "content": f"SEARCH RESULTS:\n{context_text}\n\nQUESTION: {query}\n\nAnswer using ONLY the information in the search results above."} + {"role": "user", "content": f"Please answer this question about Star Trek: {query}"} ] - try: - result = self.ollama_client.chat( - model=model, - messages=messages, - stream=False, - options={"temperature": temperature, "top_p": top_p, "num_predict": max_tokens} - ) - - final_response = result['message']['content'] - logger.info(f"LLM response length: {len(final_response)}") + 
max_iterations = 5 # Prevent infinite loops + iteration = 0 + has_used_tool = False + + while iteration < max_iterations: + iteration += 1 + logger.info(f"Iteration {iteration} for query: {query}") - # Handle thinking mode response processing - if self.thinking_mode == ThinkingMode.DISABLED: + try: + logger.info(f"Sending messages to LLM: {[msg['role'] for msg in messages]}") + result = self.ollama_client.chat( + model=model, + messages=messages, + stream=False, + options={"temperature": temperature, "top_p": top_p, "num_predict": max_tokens}, + tools=[search_tool_definition] + ) + + response_message = result['message'] + logger.info(f"LLM response type: {type(response_message)}") + logger.debug(f"Response message attributes: {dir(response_message)}") + logger.debug(f"Response message content: {response_message.get('content', 'No content')[:200]}...") + + # Check if the model wants to use a tool + tool_calls = getattr(response_message, 'tool_calls', None) or response_message.get('tool_calls') + if tool_calls: + has_used_tool = True + logger.info(f"Tool calls detected: {len(tool_calls)}") + # Execute the tool call + tool_call = tool_calls[0] + logger.info(f"Tool call: {tool_call.get('function', {}).get('name', 'Unknown')}") + + if tool_call.get('function', {}).get('name') == 'search_memory_alpha': + args = tool_call.get('function', {}).get('arguments', {}) + search_query = args.get('query', '') + search_top_k = min(args.get('top_k', 5), 10) + + logger.info(f"Executing search for: '{search_query}' with top_k={search_top_k}") + + # Execute the search + search_result = self.search_tool(search_query, search_top_k) + logger.info(f"Search result length: {len(search_result)}") + logger.debug(f"Search result preview: {search_result[:500]}...") + + # Add the tool call and result to messages + messages.append(response_message) + messages.append({ + "role": "tool", + "content": search_result, + "tool_call_id": tool_call.get('id', '') + }) + + logger.info("Continuing conversation with tool results") + continue # Continue the conversation with tool results + + # If no tool call and we haven't used tools yet, force a search + if not has_used_tool and iteration == 1: + logger.info("LLM didn't use tool on first attempt, forcing initial search") + search_result = self.search_tool(query, 5) + messages.append({ + "role": "tool", + "content": search_result, + "tool_call_id": "forced_search" + }) + has_used_tool = True + continue + + # If no tool call, this is the final answer + final_response = response_message.get('content', '') + if not final_response: + logger.warning("LLM returned empty content") + final_response = "I apologize, but I was unable to generate a response." 
+ + logger.info(f"Final response length: {len(final_response)}") + logger.info(f"Final response preview: {final_response[:200]}...") + logger.debug(f"Raw final response: {repr(final_response[:500])}") + + # Always clean the response first to remove thinking tags and unwanted content final_response = self._clean_response(final_response) - elif self.thinking_mode == ThinkingMode.QUIET: - final_response = self._replace_thinking_tags(final_response) - else: # VERBOSE - final_response = final_response.strip() + logger.debug(f"After cleaning: {repr(final_response[:500])}") + + # If cleaning removed everything, the LLM was just thinking without answering + if not final_response.strip(): + logger.warning("LLM response was only thinking content, no final answer provided") + final_response = "I apologize, but I was unable to find sufficient information to answer your question based on the available Memory Alpha records." + + logger.info(f"Thinking mode: {self.thinking_mode}") + logger.info(f"Final cleaned response: {final_response[:200]}...") + + # Handle thinking mode response processing + if self.thinking_mode == ThinkingMode.QUIET: + final_response = self._replace_thinking_tags(final_response) + # For DISABLED and VERBOSE modes, the response is already clean + + self._update_history(query, final_response) + logger.info("Returning final answer") + return final_response + + except Exception as e: + logger.error(f"Chat failed: {e}") + return f"Error processing query: {str(e)}" - self._update_history(query, final_response) - return final_response - - except Exception as e: - logger.error(f"Chat failed: {e}") - return f"Error processing query: {str(e)}" + # Fallback if max iterations reached + logger.warning(f"Max iterations reached for query: {query}") + return "Query processing exceeded maximum iterations. Please try a simpler question." def _clean_response(self, answer: str) -> str: """Clean response by removing ANSI codes and thinking tags.""" - clean = re.sub(r"\033\[[0-9;]*m", "", answer).replace("LCARS: ", "").strip() - while "" in clean and "" in clean: - start = clean.find("") - end = clean.find("") + len("") - clean = clean[:start] + clean[end:] - return clean.strip() + if not answer: + return "" + + # Remove ANSI codes + clean = re.sub(r"\033\[[0-9;]*m", "", answer) + # Remove LCARS prefix + clean = clean.replace("LCARS: ", "").strip() + + # Remove thinking tags and their content - multiple patterns + # Pattern 1: Complete ... 
blocks + clean = re.sub(r'.*?', '', clean, flags=re.DOTALL | re.IGNORECASE) + # Pattern 2: Unclosed tags + clean = re.sub(r'.*?(?=||$)', '', clean, flags=re.DOTALL | re.IGNORECASE) + # Pattern 3: Any remaining think tags + clean = re.sub(r'', '', clean, flags=re.IGNORECASE) + # Pattern 4: Alternative thinking formats + clean = re.sub(r'.*?', '', clean, flags=re.DOTALL | re.IGNORECASE) + + # Remove extra whitespace and newlines + clean = re.sub(r'\n\s*\n', '\n', clean) + clean = clean.strip() + + return clean def _replace_thinking_tags(self, answer: str) -> str: """Replace thinking tags with processing text.""" diff --git a/chat.sh b/chat.sh index aefe2ca..c7bb476 100755 --- a/chat.sh +++ b/chat.sh @@ -3,7 +3,7 @@ # Interactive chat script for MemoryAlpha RAG API BASE_URL="http://localhost:8000" THINKING_MODE="DISABLED" -MAX_TOKENS=512 +MAX_TOKENS=2048 TOP_K=5 TOP_P=0.8 TEMPERATURE=0.3 From 15a37e9133e00f04628ae75397d10971361f57f8 Mon Sep 17 00:00:00 2001 From: aniongithub Date: Sat, 6 Sep 2025 06:43:02 +0000 Subject: [PATCH 3/3] Refactor ask endpoint: Update to use POST method with JSON payload for improved API usage; modify test cases for new question formats. Clean up unnecessary code and enhance logging in RAG implementation. --- .devcontainer/devcontainer.json | 1 - .github/workflows/ci-build.yml | 11 +-- .github/workflows/pr-check.yml | 11 +-- api/memoryalpha/ask.py | 37 ++++++-- api/memoryalpha/rag.py | 150 ++++++++++---------------------- chat.sh | 2 +- docker-compose.yml | 2 +- 7 files changed, 91 insertions(+), 123 deletions(-) diff --git a/.devcontainer/devcontainer.json b/.devcontainer/devcontainer.json index 4689b65..f32d1d7 100644 --- a/.devcontainer/devcontainer.json +++ b/.devcontainer/devcontainer.json @@ -22,7 +22,6 @@ "extensions": [ "ms-python.python", "zaaack.markdown-editor", - "bierner.emojisense", "ms-python.debugpy" ] } diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index f06e5e5..70c62e3 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -40,14 +40,15 @@ jobs: - name: Test ask endpoint run: | - # Test the synchronous ask endpoint with a simple query - response=$(curl -s -f "http://localhost:8000/memoryalpha/rag/ask?question=What%20is%20the%20Enterprise?&thinkingmode=DISABLED&max_tokens=100&top_k=3") - + # Test the ask endpoint with a simple query + response=$(curl -X POST "http://localhost:8000/memoryalpha/rag/ask" -H "Content-Type: application/json" -d '{ + "question": "What is the color of Vulcan blood?" 
+ }') # Check if response contains expected content - if echo "$response" | grep -q "Enterprise"; then + if echo "$response" | grep -q "green"; then echo "✅ Ask endpoint test passed" else - echo "❌ Ask endpoint test failed - no relevant content found" + echo "❌ Ask endpoint test failed, answer did not contain expected content" echo "Response: $response" exit 1 fi diff --git a/.github/workflows/pr-check.yml b/.github/workflows/pr-check.yml index c42a0b0..318b214 100644 --- a/.github/workflows/pr-check.yml +++ b/.github/workflows/pr-check.yml @@ -39,14 +39,15 @@ jobs: - name: Test ask endpoint run: | - # Test the synchronous ask endpoint with a simple query - response=$(curl -s -f "http://localhost:8000/memoryalpha/rag/ask?question=What%20is%20the%20Enterprise?&thinkingmode=DISABLED&max_tokens=100&top_k=3") - + # Test the ask endpoint with a simple query + response=$(curl -X POST "http://localhost:8000/memoryalpha/rag/ask" -H "Content-Type: application/json" -d '{ + "question": "What was the name of human who discovered warp drive?" + }') # Check if response contains expected content - if echo "$response" | grep -q "Enterprise"; then + if echo "$response" | grep -q "Zefram Cochrane"; then echo "✅ Ask endpoint test passed" else - echo "❌ Ask endpoint test failed - no relevant content found" + echo "❌ Ask endpoint test failed, answer did not contain expected content" echo "Response: $response" exit 1 fi diff --git a/api/memoryalpha/ask.py b/api/memoryalpha/ask.py index cef9a8f..1f11e55 100644 --- a/api/memoryalpha/ask.py +++ b/api/memoryalpha/ask.py @@ -1,30 +1,53 @@ -from fastapi import APIRouter, Query +from fastapi import APIRouter, Query, Body from fastapi.responses import JSONResponse +from pydantic import BaseModel +from typing import Optional -from .rag import MemoryAlphaRAG, ThinkingMode +from .rag import MemoryAlphaRAG router = APIRouter() # Singleton or global instance for demo; in production, manage lifecycle properly rag_instance = MemoryAlphaRAG() -ThinkingMode = ThinkingMode + +class AskRequest(BaseModel): + question: str + max_tokens: Optional[int] = 2048 + top_k: Optional[int] = 10 + top_p: Optional[float] = 0.8 + temperature: Optional[float] = 0.3 + +@router.post("/memoryalpha/rag/ask") +def ask_endpoint_post(request: AskRequest): + """ + Query the RAG pipeline and return the full response. + Accepts POST requests with JSON payload for cleaner API usage. + """ + try: + answer = rag_instance.ask( + request.question, + max_tokens=request.max_tokens, + top_k=request.top_k, + top_p=request.top_p, + temperature=request.temperature + ) + return JSONResponse(content={"response": answer}) + except Exception as e: + return JSONResponse(status_code=500, content={"error": str(e)}) @router.get("/memoryalpha/rag/ask") def ask_endpoint( question: str = Query(..., description="The user question"), - thinkingmode: str = Query("DISABLED", description="Thinking mode: DISABLED, QUIET, or VERBOSE"), max_tokens: int = Query(2048, description="Maximum tokens to generate"), top_k: int = Query(10, description="Number of documents to retrieve"), top_p: float = Query(0.8, description="Sampling parameter"), temperature: float = Query(0.3, description="Randomness/creativity of output") ): """ - Query the RAG pipeline and return the full response (including thinking if enabled). + Query the RAG pipeline and return the full response. Now uses advanced tool-enabled RAG by default for better results. 
""" try: - # Set the thinking mode for this request - rag_instance.thinking_mode = ThinkingMode[thinkingmode.upper()] answer = rag_instance.ask( question, max_tokens=max_tokens, diff --git a/api/memoryalpha/rag.py b/api/memoryalpha/rag.py index d6bdef6..8c51c69 100644 --- a/api/memoryalpha/rag.py +++ b/api/memoryalpha/rag.py @@ -18,43 +18,10 @@ import chromadb from chromadb.config import Settings -""" -ThinkingMode enum for controlling model reasoning display -""" - -from enum import Enum - -class ThinkingMode(Enum): - DISABLED = "disabled" - QUIET = "quiet" - VERBOSE = "verbose" - logger = logging.getLogger(__name__) +logger.setLevel(logging.DEBUG) warnings.filterwarnings("ignore", message=".*encoder_attention_mask.*is deprecated.*", category=FutureWarning) -def get_system_prompt(thinking_mode: ThinkingMode) -> str: - """Generate the LCARS-style system prompt based on thinking mode""" - - base_prompt = """You are an LCARS computer system with access to Star Trek Memory Alpha records. - -CRITICAL INSTRUCTIONS: -- You MUST answer ONLY using information from the provided records -- If the records don't contain relevant information, say "I don't have information about that in my records" -- DO NOT make up information, invent characters, or hallucinate details -- DO NOT use external knowledge about Star Trek - only use the provided records -- AVOID mirror universe references unless specifically asked about it -- If asked about something not in the records, be honest about the limitation -- Stay in character as an LCARS computer system at all times - -""" - - if thinking_mode == ThinkingMode.DISABLED: - return base_prompt + "Answer directly in a single paragraph without thinking tags." - elif thinking_mode == ThinkingMode.QUIET: - return base_prompt + "Use tags for internal analysis, then provide your final answer in a single paragraph." - else: # VERBOSE - return base_prompt + "Use tags for analysis, then provide your final answer in a single paragraph." 
- def get_user_prompt(context_text: str, query: str) -> str: """Format user prompt with context and query""" @@ -73,9 +40,7 @@ def __init__(self, chroma_db_path: str = os.getenv("DB_PATH"), ollama_url: str = os.getenv("OLLAMA_URL"), collection_name: str = os.getenv("COLLECTION_NAME", "memoryalpha"), - thinking_mode: ThinkingMode = ThinkingMode.DISABLED, - max_history_turns: int = 5, - thinking_text: str = "Processing..."): + max_history_turns: int = 5): if not chroma_db_path: raise ValueError("chroma_db_path must be provided or set in CHROMA_DB_PATH environment variable.") @@ -85,9 +50,7 @@ def __init__(self, self.chroma_db_path = chroma_db_path self.ollama_url = ollama_url self.collection_name = collection_name - self.thinking_mode = thinking_mode self.max_history_turns = max_history_turns - self.thinking_text = thinking_text self.conversation_history: List[Dict[str, str]] = [] # Initialize lightweight components @@ -170,14 +133,33 @@ def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]: "distance": dist }) - # Rerank with cross-encoder if available - if self.cross_encoder and len(docs) > 1: - pairs = [[query, doc["content"][:500]] for doc in docs] - scores = self.cross_encoder.predict(pairs) - for doc, score in zip(docs, scores): - doc["score"] = float(score) - docs = sorted(docs, key=lambda d: d["score"], reverse=True) - + # Re-rank using cross-encoder if available + if self.cross_encoder and len(docs) > top_k: + logger.info("Re-ranking results with cross-encoder") + # Limit to top candidates for re-ranking to avoid performance issues + rerank_candidates = docs[:min(len(docs), top_k + 5)] # Only re-rank top candidates + + # Prepare pairs for cross-encoder with truncated content + pairs = [] + for doc in rerank_candidates: + content = doc['content'] + if len(content) > 512: # Truncate long content for cross-encoder + content = content[:512] + pairs.append([query, content]) + + try: + scores = self.cross_encoder.predict(pairs) + + # Sort by cross-encoder scores (higher is better) + ranked_docs = sorted(zip(rerank_candidates, scores), key=lambda x: x[1], reverse=True) + reranked = [doc for doc, score in ranked_docs] + + # Replace original docs with re-ranked ones + docs = reranked + docs[len(rerank_candidates):] + logger.info(f"Cross-encoder re-ranking completed, top score: {scores[0]:.4f}") + except Exception as e: + logger.warning(f"Cross-encoder re-ranking failed: {e}, using original ranking") + # Continue with original docs if re-ranking fails return docs[:top_k] except Exception as e: @@ -187,7 +169,18 @@ def search(self, query: str, top_k: int = 10) -> List[Dict[str, Any]]: def build_prompt(self, query: str, docs: List[Dict[str, Any]]) -> tuple[str, str]: """Build the prompt with retrieved documents.""" - system_prompt = get_system_prompt(self.thinking_mode) + system_prompt = """You are an LCARS computer system with access to Star Trek Memory Alpha records. 
+ +CRITICAL INSTRUCTIONS: +- You MUST answer ONLY using information from the provided records +- If the records don't contain relevant information, say "I don't have information about that in my records" +- DO NOT make up information, invent characters, or hallucinate details +- DO NOT use external knowledge about Star Trek - only use the provided records +- AVOID mirror universe references unless specifically asked about it +- If asked about something not in the records, be honest about the limitation +- Stay in character as an LCARS computer system at all times + +Answer directly in a single paragraph.""" if not docs: context_text = "" @@ -283,7 +276,7 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float - Do NOT directly use the input question, only use keywords from it - Use only key terms from the input question for seaching - If insufficient information is found on the first try, retry with variations or relevant info from previous queries -- DISCARD details from alternate universes or timelines +- DISCARD details from alternate universes, books or timelines - DISREGARD details about books, comics, or non-canon sources - NEVER mention appearances or actors, only in-universe details - Ensure a complete answer can be formulated before stopping searches @@ -294,8 +287,7 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float - Provide your final answer clearly and concisely - Do not add details that are irrelevant to the question - Stay in-character as an LCARS computer system at all times, do not allude to the Star Trek universe itself or it being a fictional setting -- Do not mention the search results, only the final in-universe answer -- Do not end responses with thinking content""" +- Do not mention the search results, only the final in-universe answer""" messages = [ {"role": "system", "content": system_prompt}, @@ -316,13 +308,13 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float model=model, messages=messages, stream=False, + think=False, options={"temperature": temperature, "top_p": top_p, "num_predict": max_tokens}, tools=[search_tool_definition] ) response_message = result['message'] logger.info(f"LLM response type: {type(response_message)}") - logger.debug(f"Response message attributes: {dir(response_message)}") logger.debug(f"Response message content: {response_message.get('content', 'No content')[:200]}...") # Check if the model wants to use a tool @@ -379,23 +371,10 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float logger.info(f"Final response preview: {final_response[:200]}...") logger.debug(f"Raw final response: {repr(final_response[:500])}") - # Always clean the response first to remove thinking tags and unwanted content - final_response = self._clean_response(final_response) - logger.debug(f"After cleaning: {repr(final_response[:500])}") - - # If cleaning removed everything, the LLM was just thinking without answering - if not final_response.strip(): - logger.warning("LLM response was only thinking content, no final answer provided") - final_response = "I apologize, but I was unable to find sufficient information to answer your question based on the available Memory Alpha records." 
+ # Remove ANSI codes and LCARS prefix + final_response = re.sub(r"\033\[[0-9;]*m", "", final_response) + final_response = final_response.replace("LCARS: ", "").strip() - logger.info(f"Thinking mode: {self.thinking_mode}") - logger.info(f"Final cleaned response: {final_response[:200]}...") - - # Handle thinking mode response processing - if self.thinking_mode == ThinkingMode.QUIET: - final_response = self._replace_thinking_tags(final_response) - # For DISABLED and VERBOSE modes, the response is already clean - self._update_history(query, final_response) logger.info("Returning final answer") return final_response @@ -408,41 +387,6 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float logger.warning(f"Max iterations reached for query: {query}") return "Query processing exceeded maximum iterations. Please try a simpler question." - def _clean_response(self, answer: str) -> str: - """Clean response by removing ANSI codes and thinking tags.""" - if not answer: - return "" - - # Remove ANSI codes - clean = re.sub(r"\033\[[0-9;]*m", "", answer) - # Remove LCARS prefix - clean = clean.replace("LCARS: ", "").strip() - - # Remove thinking tags and their content - multiple patterns - # Pattern 1: Complete ... blocks - clean = re.sub(r'.*?', '', clean, flags=re.DOTALL | re.IGNORECASE) - # Pattern 2: Unclosed tags - clean = re.sub(r'.*?(?=||$)', '', clean, flags=re.DOTALL | re.IGNORECASE) - # Pattern 3: Any remaining think tags - clean = re.sub(r'', '', clean, flags=re.IGNORECASE) - # Pattern 4: Alternative thinking formats - clean = re.sub(r'.*?', '', clean, flags=re.DOTALL | re.IGNORECASE) - - # Remove extra whitespace and newlines - clean = re.sub(r'\n\s*\n', '\n', clean) - clean = clean.strip() - - return clean - - def _replace_thinking_tags(self, answer: str) -> str: - """Replace thinking tags with processing text.""" - clean = re.sub(r"\033\[[0-9;]*m", "", answer).replace("LCARS: ", "").strip() - while "" in clean and "" in clean: - start = clean.find("") - end = clean.find("") + len("") - clean = clean[:start] + self.thinking_text + clean[end:] - return clean.strip() - def _update_history(self, question: str, answer: str): """Update conversation history.""" self.conversation_history.append({"question": question, "answer": answer}) diff --git a/chat.sh b/chat.sh index c7bb476..0937b9d 100755 --- a/chat.sh +++ b/chat.sh @@ -2,7 +2,7 @@ # Interactive chat script for MemoryAlpha RAG API BASE_URL="http://localhost:8000" -THINKING_MODE="DISABLED" +THINKING_MODE="VERBOSE" MAX_TOKENS=2048 TOP_K=5 TOP_P=0.8 diff --git a/docker-compose.yml b/docker-compose.yml index b1f85cd..4ec8b99 100644 --- a/docker-compose.yml +++ b/docker-compose.yml @@ -11,7 +11,7 @@ services: - odn env_file: - .env - + lcars: build: context: .
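
Example usage of the reworked ask endpoint introduced in PATCH 3/3. This is a minimal sketch, assuming the stack from docker-compose.yml is running and the API is published at http://localhost:8000 as in chat.sh and the CI workflows; the request body follows the AskRequest model in api/memoryalpha/ask.py (only "question" is required, the other fields fall back to its defaults), and the answer is returned under the "response" key, or "error" on failure:

    # POST with a JSON payload instead of query parameters
    curl -s -X POST "http://localhost:8000/memoryalpha/rag/ask" \
      -H "Content-Type: application/json" \
      -d '{
        "question": "What is the color of Vulcan blood?",
        "max_tokens": 2048,
        "top_k": 10,
        "top_p": 0.8,
        "temperature": 0.3
      }' | jq -r '.response // .error'

    # The GET form remains available for quick manual checks
    curl -s "http://localhost:8000/memoryalpha/rag/ask?question=Who%20discovered%20warp%20drive%3F&top_k=5" | jq -r '.response // .error'

This is the same request shape that ci-build.yml and pr-check.yml exercise, grepping the reply for expected content ("green", "Zefram Cochrane"); note that the tool-enabled search loop may issue several model calls per question, so responses can take longer than the earlier single-shot lookup.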