Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 18 additions & 9 deletions src/praisonai-agents/praisonaiagents/mcp/mcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ def __init__(self, server_params, timeout=60):
super().__init__(daemon=True)
self.server_params = server_params
self.queue = queue.Queue()
self.result_queue = queue.Queue()
self.initialized = threading.Event()
self._init_error = None
self.tools = []
self.timeout = timeout
self._tool_timings = {}
Expand Down Expand Up @@ -66,12 +66,12 @@ async def _run_async(self):
if item is None: # Shutdown signal
break

tool_name, arguments = item
response_queue, tool_name, arguments = item
try:
result = await session.call_tool(tool_name, arguments)
self.result_queue.put((True, result))
response_queue.put((True, result))
except Exception as e:
self.result_queue.put((False, str(e)))
response_queue.put((False, str(e)))
except queue.Empty:
pass

Expand All @@ -80,8 +80,8 @@ async def _run_async(self):
except asyncio.CancelledError:
break
except Exception as e:
self._init_error = f"MCP initialization error: {str(e)}"
self.initialized.set() # Ensure we don't hang
self.result_queue.put((False, f"MCP initialization error: {str(e)}"))

def call_tool(self, tool_name, arguments):
"""Call an MCP tool and wait for the result."""
Expand All @@ -100,16 +100,25 @@ def call_tool(self, tool_name, arguments):
if telemetry:
telemetry.track_tool_usage(tool_name, success=False, execution_time=0)
return f"Error: MCP initialization timed out after {self.timeout} seconds"

if self._init_error:
if telemetry:
telemetry.track_tool_usage(tool_name, success=False, execution_time=0)
return f"Error: {self._init_error}"

# Start timing after initialization check
start_time = time.time()
is_success = False
response_queue = queue.Queue(maxsize=1)
try:
# Put request in queue
self.queue.put((tool_name, arguments))
# Put request in queue with caller-specific response channel
self.queue.put((response_queue, tool_name, arguments))

# Wait for result
success, result = self.result_queue.get()
# Wait for result with timeout
try:
success, result = response_queue.get(timeout=self.timeout)
except queue.Empty:
return f"Error: MCP tool call timed out after {self.timeout} seconds"
if not success:
return f"Error: {result}"

Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
"""Regression tests for MCPToolRunner concurrent call routing."""

import queue
import threading
import time
from unittest.mock import Mock, patch

import pytest


class TestMCPToolRunnerConcurrency:
def test_concurrent_calls_receive_matching_results(self):
from praisonaiagents.mcp.mcp import MCPToolRunner

with patch.object(MCPToolRunner, "start", lambda self: None):
runner = MCPToolRunner(server_params=Mock(), timeout=5)
runner.initialized.set()

results = {}
barrier = threading.Barrier(2)

def slow_worker():
while True:
item = runner.queue.get()
if item is None:
break
response_queue, tool_name, _arguments = item
if tool_name == "slow_tool":
time.sleep(0.05)
response_queue.put((True, "slow-result"))
else:
response_queue.put((True, "fast-result"))

worker = threading.Thread(target=slow_worker, daemon=True)
worker.start()

def call_tool(name):
barrier.wait()
results[name] = runner.call_tool(name, {})

threads = [
threading.Thread(target=call_tool, args=("slow_tool",)),
threading.Thread(target=call_tool, args=("fast_tool",)),
]
for thread in threads:
thread.start()
for thread in threads:
thread.join(timeout=5)

runner.queue.put(None)
worker.join(timeout=2)

assert results["slow_tool"] == "slow-result"
assert results["fast_tool"] == "fast-result"

def test_call_tool_times_out_when_worker_stalls(self):
from praisonaiagents.mcp.mcp import MCPToolRunner

with patch.object(MCPToolRunner, "start", lambda self: None):
runner = MCPToolRunner(server_params=Mock(), timeout=1)
runner.initialized.set()

result = runner.call_tool("stalled_tool", {})
assert "timed out" in result.lower()

def test_init_error_is_not_returned_to_unrelated_callers(self):
from praisonaiagents.mcp.mcp import MCPToolRunner

with patch.object(MCPToolRunner, "start", lambda self: None):
runner = MCPToolRunner(server_params=Mock(), timeout=5)
runner.initialized.set()
runner._init_error = "MCP initialization error: boom"

result = runner.call_tool("any_tool", {})
assert result == "Error: MCP initialization error: boom"
16 changes: 13 additions & 3 deletions src/praisonai/praisonai/bots/telegram.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,8 +277,11 @@ async def handle_voice(update: Update, context: ContextTypes.DEFAULT_TYPE):
async def handle_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
if not update.message or not update.message.text:
return

message = self._convert_update_to_message(update)

message = await process_inbound_telegram_message(update, self)
if not message:
return

command = message.command

if command and command in self._command_handlers:
Expand All @@ -294,18 +297,25 @@ async def handle_command(update: Update, context: ContextTypes.DEFAULT_TYPE):
async def handle_status(update: Update, context: ContextTypes.DEFAULT_TYPE):
if not update.message:
return
if not await process_inbound_telegram_message(update, self):
return
await update.message.reply_text(self._format_status())

async def handle_new(update: Update, context: ContextTypes.DEFAULT_TYPE):
if not update.message:
return
user_id = str(update.message.from_user.id) if update.message.from_user else "unknown"
message = await process_inbound_telegram_message(update, self)
if not message:
return
user_id = message.sender.user_id if message.sender else "unknown"
self._session.reset(user_id)
await update.message.reply_text("Session reset. Starting fresh conversation.")

async def handle_help(update: Update, context: ContextTypes.DEFAULT_TYPE):
if not update.message:
return
if not await process_inbound_telegram_message(update, self):
return
await update.message.reply_text(self._format_help())

self._application.add_handler(CommandHandler("status", handle_status))
Expand Down
12 changes: 11 additions & 1 deletion src/praisonai/praisonai/gateway/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -1959,18 +1959,28 @@ async def handle_voice(update: Update, context: Any):
async def handle_status(update: Update, context: Any):
if not update.message:
return
from praisonai.bots.telegram import process_inbound_telegram_message
if not await process_inbound_telegram_message(update, bot):
return
await update.message.reply_text(bot._format_status())

async def handle_new(update: Update, context: Any):
if not update.message:
return
user_id = str(update.message.from_user.id) if update.message.from_user else "unknown"
from praisonai.bots.telegram import process_inbound_telegram_message
message = await process_inbound_telegram_message(update, bot)
if not message:
return
user_id = message.sender.user_id if message.sender else "unknown"
bot._session.reset(user_id)
await update.message.reply_text("Session reset. Starting fresh conversation.")

async def handle_help(update: Update, context: Any):
if not update.message:
return
from praisonai.bots.telegram import process_inbound_telegram_message
if not await process_inbound_telegram_message(update, bot):
return
await update.message.reply_text(bot._format_help())

Comment on lines 1961 to 1985
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Repeated per-function imports of process_inbound_telegram_message

Each of the three command handlers (handle_status, handle_new, handle_help) independently issues from praisonai.bots.telegram import process_inbound_telegram_message at call-time. Python module caching makes this functionally correct but it is redundant and inconsistent with the rest of the file's import style. A single import at the outer function scope (or at module level alongside the other praisonai.bots imports) would be cleaner.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

# Register handlers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -65,8 +65,12 @@ def create_test_bot(allowed_users=None, allowed_channels=None, group_policy="men
is_bot=True,
)

# Mock the fire_message_received method
# Mock the fire_message_received method and other required attributes
bot.fire_message_received = MagicMock()
bot._started_at = 1234567890.0
bot._agent = MagicMock()
bot._command_handlers = {}
bot._session = MagicMock()

return bot

Expand Down Expand Up @@ -239,6 +243,85 @@ def test_security_pipeline_exists():
assert callable(process_inbound_telegram_message), "Security pipeline function should be callable"


@pytest.mark.asyncio
@patch.object(UnknownUserHandler, 'handle')
async def test_command_handlers_respect_user_allowlist(mock_unknown_handler):
"""Built-in commands must pass the same security pipeline as text messages."""
mock_unknown_handler.return_value = False

bot = create_test_bot(allowed_users=["42"])

# Mock the reply_text method to track if command handlers were called
reply_mock = AsyncMock()

# Test that disallowed users are blocked by command handlers
for command in ("help", "status", "new"):
update = create_mock_telegram_update(
user_id="99",
text=f"/{command}",
chat_type="private",
)
update.message.reply_text = reply_mock
reply_mock.reset_mock()

# Get the registered handler for this command from the bot's handlers
# We need to simulate how the telegram bot framework would call the handler
if command == "help":
from praisonai.bots.telegram import TelegramBot
# Create a handler like the bot does
async def test_handle_help(update, context):
if not update.message:
return
if not await process_inbound_telegram_message(update, bot):
return
await update.message.reply_text(bot._format_help())
await test_handle_help(update, None)
elif command == "status":
async def test_handle_status(update, context):
if not update.message:
return
if not await process_inbound_telegram_message(update, bot):
return
await update.message.reply_text(bot._format_status())
await test_handle_status(update, None)
elif command == "new":
async def test_handle_new(update, context):
if not update.message:
return
message = await process_inbound_telegram_message(update, bot)
if not message:
return
user_id = message.sender.user_id if message.sender else "unknown"
bot._session.reset(user_id)
await update.message.reply_text("Session reset. Starting fresh conversation.")
# Mock session reset
bot._session = MagicMock()
bot._session.reset = MagicMock()
await test_handle_new(update, None)

# Assert the command handler did not reply (because security blocked it)
reply_mock.assert_not_called(), f"/{command} from disallowed user should not reply"

# Test that allowed users can use commands
allowed_update = create_mock_telegram_update(
user_id="42",
text="/help",
chat_type="private",
)
allowed_update.message.reply_text = reply_mock
reply_mock.reset_mock()

async def test_handle_help_allowed(update, context):
if not update.message:
return
if not await process_inbound_telegram_message(update, bot):
return
await update.message.reply_text(bot._format_help())

await test_handle_help_allowed(allowed_update, None)
reply_mock.assert_called_once(), "Commands from allowed users should reply"

Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P2 Test assertions silently discard custom error messages

reply_mock.assert_not_called(), f"/{command} from disallowed user should not reply" evaluates as a two-element tuple expression. assert_not_called() raises AssertionError with no message if the mock was called — the f-string is never passed to it. The same pattern appears on line 327 (assert_called_once()). When the assertion fails in CI, the output will show a raw AssertionError with no context about which command or user was involved.


@pytest.mark.asyncio
async def test_shared_pipeline_consistency():
"""Test that the shared pipeline provides consistent results."""
Expand Down
Loading