From 2498c54015c94de94e942e7151f352b85300f530 Mon Sep 17 00:00:00 2001
From: aniongithub
Date: Mon, 8 Sep 2025 19:23:18 +0000
Subject: [PATCH] Enhance RAG API to return token usage information and
 improve error handling in chat script

---
 api/memoryalpha/ask.py |  8 +++----
 api/memoryalpha/rag.py | 49 ++++++++++++++++++++++++++++++++++++++----
 chat.sh                | 22 ++++++++++++++++++-
 3 files changed, 70 insertions(+), 9 deletions(-)

diff --git a/api/memoryalpha/ask.py b/api/memoryalpha/ask.py
index 1f11e55..518927c 100644
--- a/api/memoryalpha/ask.py
+++ b/api/memoryalpha/ask.py
@@ -24,14 +24,14 @@ def ask_endpoint_post(request: AskRequest):
     Accepts POST requests with JSON payload for cleaner API usage.
     """
     try:
-        answer = rag_instance.ask(
+        result = rag_instance.ask(
             request.question,
             max_tokens=request.max_tokens,
             top_k=request.top_k,
             top_p=request.top_p,
             temperature=request.temperature
         )
-        return JSONResponse(content={"response": answer})
+        return JSONResponse(content=result)
     except Exception as e:
         return JSONResponse(status_code=500, content={"error": str(e)})
 
@@ -48,13 +48,13 @@ def ask_endpoint(
     Now uses advanced tool-enabled RAG by default for better results.
     """
     try:
-        answer = rag_instance.ask(
+        result = rag_instance.ask(
             question,
             max_tokens=max_tokens,
             top_k=top_k,
             top_p=top_p,
             temperature=temperature
         )
-        return JSONResponse(content={"response": answer})
+        return JSONResponse(content=result)
     except Exception as e:
         return JSONResponse(status_code=500, content={"error": str(e)})
 
diff --git a/api/memoryalpha/rag.py b/api/memoryalpha/rag.py
index 8c51c69..7d52beb 100644
--- a/api/memoryalpha/rag.py
+++ b/api/memoryalpha/rag.py
@@ -224,15 +224,20 @@ def search_tool(self, query: str, top_k: int = 5) -> str:
         return formatted_result
 
     def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float = 0.8, temperature: float = 0.3,
-            model: str = os.getenv("DEFAULT_MODEL")) -> str:
+            model: str = os.getenv("DEFAULT_MODEL")) -> Dict[str, Any]:
         """
         Ask a question using the advanced Memory Alpha RAG system with tool use.
+        Returns a dictionary with answer and token usage information.
""" if not model: raise ValueError("model must be provided or set in DEFAULT_MODEL environment variable.") logger.info(f"Starting tool-enabled RAG for query: {query}") + + # Initialize token tracking + total_input_tokens = 0 + total_output_tokens = 0 # Define the search tool search_tool_definition = { @@ -317,6 +322,20 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float logger.info(f"LLM response type: {type(response_message)}") logger.debug(f"Response message content: {response_message.get('content', 'No content')[:200]}...") + # Estimate tokens based on content length + # Rough estimation: ~4 characters per token for English text + content = response_message.get('content', '') + estimated_output_tokens = len(content) // 4 + total_output_tokens += estimated_output_tokens + + # Estimate input tokens from current message content + input_text = ' '.join([msg.get('content', '') for msg in messages]) + estimated_input_tokens = len(input_text) // 4 + # Only add the increment from this iteration to avoid double counting + total_input_tokens = estimated_input_tokens + + logger.info(f"Estimated tokens - Input: {estimated_input_tokens}, Output: {estimated_output_tokens}") + # Check if the model wants to use a tool tool_calls = getattr(response_message, 'tool_calls', None) or response_message.get('tool_calls') if tool_calls: @@ -377,15 +396,37 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float self._update_history(query, final_response) logger.info("Returning final answer") - return final_response + + return { + "answer": final_response, + "token_usage": { + "input_tokens": total_input_tokens, + "output_tokens": total_output_tokens, + "total_tokens": total_input_tokens + total_output_tokens + } + } except Exception as e: logger.error(f"Chat failed: {e}") - return f"Error processing query: {str(e)}" + return { + "answer": f"Error processing query: {str(e)}", + "token_usage": { + "input_tokens": total_input_tokens, + "output_tokens": total_output_tokens, + "total_tokens": total_input_tokens + total_output_tokens + } + } # Fallback if max iterations reached logger.warning(f"Max iterations reached for query: {query}") - return "Query processing exceeded maximum iterations. Please try a simpler question." + return { + "answer": "Query processing exceeded maximum iterations. Please try a simpler question.", + "token_usage": { + "input_tokens": total_input_tokens, + "output_tokens": total_output_tokens, + "total_tokens": total_input_tokens + total_output_tokens + } + } def _update_history(self, question: str, answer: str): """Update conversation history.""" diff --git a/chat.sh b/chat.sh index 0937b9d..ce76fd7 100755 --- a/chat.sh +++ b/chat.sh @@ -22,10 +22,30 @@ ask_question() { local response response=$(curl -s \ "${BASE_URL}/memoryalpha/rag/ask?question=${encoded_question}&thinkingmode=${THINKING_MODE}&max_tokens=${MAX_TOKENS}&top_k=${TOP_K}&top_p=${TOP_P}&temperature=${TEMPERATURE}") + + # Check if response is valid JSON + if ! echo "$response" | jq . 
+        printf "Error: Invalid response received.\n"
+        printf "Raw response: %s\n" "$response"
+        echo "----------------------------------------"
+        return
+    fi
+
     local answer
-    answer=$(echo "$response" | jq -r '.response // empty')
+    answer=$(echo "$response" | jq -r '.answer // empty')
     if [[ -n "$answer" ]]; then
         printf "%s\n" "$answer"
+
+        # Display token usage if available
+        local input_tokens output_tokens total_tokens
+        input_tokens=$(echo "$response" | jq -r '.token_usage.input_tokens // empty')
+        output_tokens=$(echo "$response" | jq -r '.token_usage.output_tokens // empty')
+        total_tokens=$(echo "$response" | jq -r '.token_usage.total_tokens // empty')
+
+        if [[ -n "$input_tokens" && -n "$output_tokens" && -n "$total_tokens" ]]; then
+            echo
+            printf "📊 Token Usage: Input: %s | Output: %s | Total: %s\n" "$input_tokens" "$output_tokens" "$total_tokens"
+        fi
     else
         local error
         error=$(echo "$response" | jq -r '.error // empty')
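
Note on consuming the new response shape: with this patch the ask endpoints
return a JSON object, {"answer": ..., "token_usage": {"input_tokens": ...,
"output_tokens": ..., "total_tokens": ...}}, instead of a bare string, and the
token counts are character-based estimates (~4 characters per token), not
tokenizer-accurate numbers. A minimal Python client sketch against the GET
endpoint follows; the base URL is an assumption, since the patch does not pin
down where the service runs:

    # Hypothetical client for the updated GET endpoint.
    # BASE_URL is an assumption; the patch does not specify host or port.
    import requests

    BASE_URL = "http://localhost:8000"

    resp = requests.get(
        f"{BASE_URL}/memoryalpha/rag/ask",
        params={"question": "Who is Jean-Luc Picard?", "max_tokens": 2048},
        timeout=120,
    )
    data = resp.json()

    if "answer" in data:
        print(data["answer"])
        usage = data.get("token_usage", {})
        # Estimated counts only (~4 chars/token), per the heuristic in rag.py
        print(f"tokens: input={usage.get('input_tokens')} "
              f"output={usage.get('output_tokens')} total={usage.get('total_tokens')}")
    else:
        # Failures return {"error": ...} with HTTP status 500
        print("error:", data.get("error"))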