api/memoryalpha/ask.py (8 changes: 4 additions & 4 deletions)
@@ -24,14 +24,14 @@ def ask_endpoint_post(request: AskRequest):
Accepts POST requests with JSON payload for cleaner API usage.
"""
try:
answer = rag_instance.ask(
result = rag_instance.ask(
request.question,
max_tokens=request.max_tokens,
top_k=request.top_k,
top_p=request.top_p,
temperature=request.temperature
)
return JSONResponse(content={"response": answer})
return JSONResponse(content=result)
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})

@@ -48,13 +48,13 @@ def ask_endpoint(
Now uses advanced tool-enabled RAG by default for better results.
"""
try:
answer = rag_instance.ask(
result = rag_instance.ask(
question,
max_tokens=max_tokens,
top_k=top_k,
top_p=top_p,
temperature=temperature
)
return JSONResponse(content={"response": answer})
return JSONResponse(content=result)
except Exception as e:
return JSONResponse(status_code=500, content={"error": str(e)})
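Both handlers now pass the dictionary from rag_instance.ask() through unchanged, so a client receives the answer and the token-usage block in a single JSON object. A minimal client sketch, assuming the service listens on localhost:8000 and that the POST route matches the GET route used by chat.sh (host, port, and path are assumptions, not shown in this diff):

# Illustrative client call; URL and port are assumptions, not confirmed by this PR.
import requests

resp = requests.post(
    "http://localhost:8000/memoryalpha/rag/ask",
    json={
        "question": "Who designed the USS Defiant?",
        "max_tokens": 2048,
        "top_k": 10,
        "top_p": 0.8,
        "temperature": 0.3,
    },
)
data = resp.json()
print(data["answer"])                       # final answer text
print(data["token_usage"]["total_tokens"])  # estimated input + output tokens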
api/memoryalpha/rag.py (49 changes: 45 additions & 4 deletions)
@@ -224,15 +224,20 @@ def search_tool(self, query: str, top_k: int = 5) -> str:
return formatted_result

def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float = 0.8, temperature: float = 0.3,
model: str = os.getenv("DEFAULT_MODEL")) -> str:
model: str = os.getenv("DEFAULT_MODEL")) -> Dict[str, Any]:
"""
Ask a question using the advanced Memory Alpha RAG system with tool use.
Returns a dictionary with answer and token usage information.
"""

if not model:
raise ValueError("model must be provided or set in DEFAULT_MODEL environment variable.")

logger.info(f"Starting tool-enabled RAG for query: {query}")

# Initialize token tracking
total_input_tokens = 0
total_output_tokens = 0

# Define the search tool
search_tool_definition = {
@@ -317,6 +322,20 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float
logger.info(f"LLM response type: {type(response_message)}")
logger.debug(f"Response message content: {response_message.get('content', 'No content')[:200]}...")

# Estimate tokens based on content length
# Rough estimation: ~4 characters per token for English text
content = response_message.get('content', '')
estimated_output_tokens = len(content) // 4
total_output_tokens += estimated_output_tokens

# Estimate input tokens from current message content
input_text = ' '.join([msg.get('content', '') for msg in messages])
estimated_input_tokens = len(input_text) // 4
# The joined prompt already spans every prior message, so replace the running total with the latest estimate instead of summing each iteration (avoids double counting)
total_input_tokens = estimated_input_tokens

logger.info(f"Estimated tokens - Input: {estimated_input_tokens}, Output: {estimated_output_tokens}")

# Check if the model wants to use a tool
tool_calls = getattr(response_message, 'tool_calls', None) or response_message.get('tool_calls')
if tool_calls:
@@ -377,15 +396,37 @@ def ask(self, query: str, max_tokens: int = 2048, top_k: int = 10, top_p: float

self._update_history(query, final_response)
logger.info("Returning final answer")
return final_response

return {
"answer": final_response,
"token_usage": {
"input_tokens": total_input_tokens,
"output_tokens": total_output_tokens,
"total_tokens": total_input_tokens + total_output_tokens
}
}

except Exception as e:
logger.error(f"Chat failed: {e}")
return f"Error processing query: {str(e)}"
return {
"answer": f"Error processing query: {str(e)}",
"token_usage": {
"input_tokens": total_input_tokens,
"output_tokens": total_output_tokens,
"total_tokens": total_input_tokens + total_output_tokens
}
}

# Fallback if max iterations reached
logger.warning(f"Max iterations reached for query: {query}")
return "Query processing exceeded maximum iterations. Please try a simpler question."
return {
"answer": "Query processing exceeded maximum iterations. Please try a simpler question.",
"token_usage": {
"input_tokens": total_input_tokens,
"output_tokens": total_output_tokens,
"total_tokens": total_input_tokens + total_output_tokens
}
}

def _update_history(self, question: str, answer: str):
"""Update conversation history."""
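Token usage here comes from a rough characters-to-tokens heuristic (about 4 characters per token of English text) applied to the accumulated prompt and to each model reply, not from counts reported by the model runtime. A minimal standalone sketch of that heuristic; estimate_tokens, estimate_usage, and the message structure are illustrative names, not code from this PR:

from typing import Any, Dict, List

def estimate_tokens(text: str) -> int:
    # Rough estimate: ~4 characters per token for English text.
    return len(text) // 4

def estimate_usage(messages: List[Dict[str, Any]], reply: str) -> Dict[str, int]:
    # Input covers the full prompt (every message sent so far); output covers the new reply.
    input_tokens = estimate_tokens(" ".join(m.get("content", "") for m in messages))
    output_tokens = estimate_tokens(reply)
    return {
        "input_tokens": input_tokens,
        "output_tokens": output_tokens,
        "total_tokens": input_tokens + output_tokens,
    }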
chat.sh (22 changes: 21 additions & 1 deletion)
@@ -22,10 +22,30 @@ ask_question() {
local response
response=$(curl -s \
"${BASE_URL}/memoryalpha/rag/ask?question=${encoded_question}&thinkingmode=${THINKING_MODE}&max_tokens=${MAX_TOKENS}&top_k=${TOP_K}&top_p=${TOP_P}&temperature=${TEMPERATURE}")

# Check if response is valid JSON
if ! echo "$response" | jq . >/dev/null 2>&1; then
printf "Error: Invalid response received.\n"
printf "Raw response: %s\n" "$response"
echo "----------------------------------------"
return
fi

local answer
answer=$(echo "$response" | jq -r '.response // empty')
answer=$(echo "$response" | jq -r '.answer // empty')
if [[ -n "$answer" ]]; then
printf "%s\n" "$answer"

# Display token usage if available
local input_tokens output_tokens total_tokens
input_tokens=$(echo "$response" | jq -r '.token_usage.input_tokens // empty')
output_tokens=$(echo "$response" | jq -r '.token_usage.output_tokens // empty')
total_tokens=$(echo "$response" | jq -r '.token_usage.total_tokens // empty')

if [[ -n "$input_tokens" && -n "$output_tokens" && -n "$total_tokens" ]]; then
echo
printf "📊 Token Usage: Input: %s | Output: %s | Total: %s\n" "$input_tokens" "$output_tokens" "$total_tokens"
fi
else
local error
error=$(echo "$response" | jq -r '.error // empty')
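For reference: on a successful run the updated ask_question prints the answer and then, when the token_usage block is present, a summary line of the form "📊 Token Usage: Input: 812 | Output: 143 | Total: 955" (numbers illustrative); a non-JSON response is now echoed verbatim instead of being silently dropped.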